From 9e728a86f5241da24afb6665812e81b64788d63c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 29 May 2025 21:04:48 -0400 Subject: [PATCH 01/13] raising re --- enzyme/Enzyme/Enzyme.cpp | 3576 +------------------------------------- 1 file changed, 38 insertions(+), 3538 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index d61a8430a8f4..3b99de1665f2 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -73,3191 +73,40 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "ActivityAnalysis.h" -#include "DiffeGradientUtils.h" -#include "EnzymeLogic.h" -#include "GradientUtils.h" -#include "TraceInterface.h" -#include "TraceUtils.h" -#include "Utils.h" - -#include "InstructionBatcher.h" - -#include "llvm/Transforms/Utils.h" - -#include "llvm/Transforms/IPO/Attributor.h" -#include "llvm/Transforms/IPO/OpenMPOpt.h" -#include "llvm/Transforms/Utils/Mem2Reg.h" - -#include "BlasAttributor.inc" - -#include "CApi.h" -using namespace llvm; -#ifdef DEBUG_TYPE -#undef DEBUG_TYPE -#endif -#define DEBUG_TYPE "lower-enzyme-intrinsic" - -llvm::cl::opt EnzymeEnable("enzyme-enable", cl::init(true), cl::Hidden, - cl::desc("Run the Enzyme pass")); - -llvm::cl::opt - EnzymePostOpt("enzyme-postopt", cl::init(false), cl::Hidden, - cl::desc("Run enzymepostprocessing optimizations")); - -llvm::cl::opt EnzymeAttributor("enzyme-attributor", cl::init(false), - cl::Hidden, - cl::desc("Run attributor post Enzyme")); - -llvm::cl::opt EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, - cl::desc("Whether to enable openmp opt")); - -llvm::cl::opt EnzymeTruncateAll( - "enzyme-truncate-all", cl::init(""), cl::Hidden, - cl::desc( - "Truncate all floating point operations. " - "E.g. 
\"64to32\" or \"64to-\".")); - -#define addAttribute addAttributeAtIndex -#define getAttribute getAttributeAtIndex -bool attributeKnownFunctions(llvm::Function &F) { - bool changed = false; - if (F.getName() == "fprintf") { - for (auto &arg : F.args()) { - if (arg.getType()->isPointerTy()) { - addFunctionNoCapture(&F, arg.getArgNo()); - changed = true; - } - } - } - if (F.getName().contains("__enzyme_float") || - F.getName().contains("__enzyme_double") || - F.getName().contains("__enzyme_integer") || - F.getName().contains("__enzyme_pointer") || - F.getName().contains("__enzyme_todense") || - F.getName().contains("__enzyme_iter") || - F.getName().contains("__enzyme_virtualreverse")) { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyReadsMemory(); - F.setOnlyWritesMemory(); -#else - F.addFnAttr(Attribute::ReadNone); -#endif - if (!F.getName().contains("__enzyme_todense")) - for (auto &arg : F.args()) { - if (arg.getType()->isPointerTy()) { - arg.addAttr(Attribute::ReadNone); - addFunctionNoCapture(&F, arg.getArgNo()); - } - } - } - if (F.getName() == "memcmp") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyAccessesArgMemory(); - F.setOnlyReadsMemory(); -#else - F.addFnAttr(Attribute::ArgMemOnly); - F.addFnAttr(Attribute::ReadOnly); -#endif - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); - for (int i = 0; i < 2; i++) - if (F.getFunctionType()->getParamType(i)->isPointerTy()) { - addFunctionNoCapture(&F, i); - F.addParamAttr(i, Attribute::ReadOnly); - } - } - - if (F.getName() == - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_createERmm") { - changed = true; - F.addFnAttr(Attribute::NoFree); - } - if (F.getName() == "MPI_Irecv" || F.getName() == "PMPI_Irecv") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyAccessesInaccessibleMemOrArgMem(); -#else - F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); -#endif - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); - F.addParamAttr(0, Attribute::WriteOnly); - if (F.getFunctionType()->getParamType(2)->isPointerTy()) { - addFunctionNoCapture(&F, 2); - F.addParamAttr(2, Attribute::WriteOnly); - } - F.addParamAttr(6, Attribute::WriteOnly); - } - if (F.getName() == "MPI_Isend" || F.getName() == "PMPI_Isend") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyAccessesInaccessibleMemOrArgMem(); -#else - F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); -#endif - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); - F.addParamAttr(0, Attribute::ReadOnly); - if (F.getFunctionType()->getParamType(2)->isPointerTy()) { - addFunctionNoCapture(&F, 2); - F.addParamAttr(2, Attribute::ReadOnly); - } - F.addParamAttr(6, Attribute::WriteOnly); - } - if (F.getName() == "MPI_Comm_rank" || F.getName() == "PMPI_Comm_rank" || - F.getName() == "MPI_Comm_size" || F.getName() == "PMPI_Comm_size") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyAccessesInaccessibleMemOrArgMem(); -#else - F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); -#endif - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); 
- - if (F.getFunctionType()->getParamType(0)->isPointerTy()) { - addFunctionNoCapture(&F, 0); - F.addParamAttr(0, Attribute::ReadOnly); - } - if (F.getFunctionType()->getParamType(1)->isPointerTy()) { - F.addParamAttr(1, Attribute::WriteOnly); - addFunctionNoCapture(&F, 1); - } - } - if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") { - changed = true; - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); - addFunctionNoCapture(&F, 0); - F.addParamAttr(1, Attribute::WriteOnly); - addFunctionNoCapture(&F, 1); - } - if (F.getName() == "MPI_Waitall" || F.getName() == "PMPI_Waitall") { - changed = true; - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); - addFunctionNoCapture(&F, 1); - F.addParamAttr(2, Attribute::WriteOnly); - addFunctionNoCapture(&F, 2); - } - // Map of MPI function name to the arg index of its type argument - std::map MPI_TYPE_ARGS = { - {"MPI_Send", 2}, {"MPI_Ssend", 2}, {"MPI_Bsend", 2}, - {"MPI_Recv", 2}, {"MPI_Brecv", 2}, {"PMPI_Send", 2}, - {"PMPI_Ssend", 2}, {"PMPI_Bsend", 2}, {"PMPI_Recv", 2}, - {"PMPI_Brecv", 2}, - - {"MPI_Isend", 2}, {"MPI_Irecv", 2}, {"PMPI_Isend", 2}, - {"PMPI_Irecv", 2}, - - {"MPI_Reduce", 3}, {"PMPI_Reduce", 3}, - - {"MPI_Allreduce", 3}, {"PMPI_Allreduce", 3}}; - { - auto found = MPI_TYPE_ARGS.find(F.getName().str()); - if (found != MPI_TYPE_ARGS.end()) { - for (auto user : F.users()) { - if (auto CI = dyn_cast(user)) - if (CI->getCalledFunction() == &F) { - if (Constant *C = - dyn_cast(CI->getArgOperand(found->second))) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_cxx_bool") { - changed = true; - CI->addAttribute( - AttributeList::FunctionIndex, - Attribute::get(CI->getContext(), "enzyme_inactive")); - } - } - } - } - } - } - } - - if (F.getName() == "omp_get_max_threads" || - F.getName() == "omp_get_thread_num") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyAccessesInaccessibleMemory(); - F.setOnlyReadsMemory(); -#else - F.addFnAttr(Attribute::InaccessibleMemOnly); - F.addFnAttr(Attribute::ReadOnly); -#endif - } - if (F.getName() == "frexp" || F.getName() == "frexpf" || - F.getName() == "frexpl") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyAccessesArgMemory(); -#else - F.addFnAttr(Attribute::ArgMemOnly); -#endif - F.addParamAttr(1, Attribute::WriteOnly); - } - if (F.getName() == "__fd_sincos_1" || F.getName() == "__fd_cos_1" || - F.getName() == "__mth_i_ipowi") { - changed = true; -#if LLVM_VERSION_MAJOR >= 16 - F.setOnlyReadsMemory(); - F.setOnlyWritesMemory(); -#else - F.addFnAttr(Attribute::ReadNone); -#endif - } - auto name = F.getName(); - - const char *NonEscapingFns[] = { - "julia.ptls_states", - "julia.get_pgcstack", - "lgamma_r", - "memcmp", - "_ZNSt6chrono3_V212steady_clock3nowEv", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_" - "createERmm", - "_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm", - "fprintf", - "fwrite", - "fputc", - "strtol", - "getenv", - "memchr", - "cublasSetMathMode", - "cublasSetStream_v2", - "cuMemPoolTrimTo", - "cuDeviceGetMemPool", - "cuStreamSynchronize", - "cuStreamDestroy", - "cuStreamQuery", - "cuCtxGetCurrent", - "cuDeviceGet", - "cuDeviceGetName", - "cuDriverGetVersion", - 
"cudaRuntimeGetVersion", - "cuDeviceGetCount", - "cuMemPoolGetAttribute", - "cuMemGetInfo_v2", - "cuDeviceGetAttribute", - "cuDevicePrimaryCtxRetain", - }; - for (auto fname : NonEscapingFns) - if (name == fname) { - changed = true; - F.addAttribute( - AttributeList::FunctionIndex, - Attribute::get(F.getContext(), "enzyme_no_escaping_allocation")); - } - changed |= attributeTablegen(F); - return changed; -} - -namespace { -static Value * -castToDiffeFunctionArgType(IRBuilder<> &Builder, llvm::CallInst *CI, - llvm::FunctionType *FT, llvm::Type *destType, - unsigned int i, DerivativeMode mode, - llvm::Value *value, unsigned int truei) { - auto res = value; - if (auto ptr = dyn_cast(res->getType())) { - if (auto PT = dyn_cast(destType)) { - if (ptr->getAddressSpace() != PT->getAddressSpace()) { -#if LLVM_VERSION_MAJOR < 17 - if (CI->getContext().supportsTypedPointers()) { - res = Builder.CreateAddrSpaceCast( - res, PointerType::get(ptr->getPointerElementType(), - PT->getAddressSpace())); - } else { - res = Builder.CreateAddrSpaceCast(res, PT); - } -#else - res = Builder.CreateAddrSpaceCast(res, PT); -#endif - assert(value); - assert(destType); - assert(FT); - llvm::errs() << "Warning cast(2) __enzyme_autodiff argument " << i - << " " << *res << "|" << *res->getType() << " to argument " - << truei << " " << *destType << "\n" - << "orig: " << *FT << "\n"; - return res; - } - } - } - - if (!res->getType()->canLosslesslyBitCastTo(destType)) { - assert(value); - assert(value->getType()); - assert(destType); - assert(FT); - auto loc = CI->getDebugLoc(); - if (auto arg = dyn_cast(res)) { - loc = arg->getDebugLoc(); - } - EmitFailure("IllegalArgCast", loc, CI, - "Cannot cast __enzyme_autodiff shadow argument ", i, ", found ", - *res, ", type ", *res->getType(), " - to arg ", truei, " ", - *destType); - return nullptr; - } - return Builder.CreateBitCast(value, destType); -} - -#if LLVM_VERSION_MAJOR > 16 -static std::optional getMetadataName(llvm::Value *res); -#else -static Optional getMetadataName(llvm::Value *res); -#endif - -// if all phi arms are (recursively) based on the same metaString, use that -#if LLVM_VERSION_MAJOR > 16 -static std::optional recursePhiReads(PHINode *val) -#else -static Optional recursePhiReads(PHINode *val) -#endif -{ -#if LLVM_VERSION_MAJOR > 16 - std::optional finalMetadata; -#else - Optional finalMetadata; -#endif - SmallVector todo = {val}; - SmallSet done; - while (todo.size()) { - auto phiInst = todo.back(); - todo.pop_back(); - if (done.count(phiInst)) - continue; - done.insert(phiInst); - for (unsigned j = 0; j < phiInst->getNumIncomingValues(); ++j) { - auto newVal = phiInst->getIncomingValue(j); - if (auto phi = dyn_cast(newVal)) { - todo.push_back(phi); - } else { - auto metaString = getMetadataName(newVal); - if (metaString) { - if (!finalMetadata) { - finalMetadata = metaString; - } else if (finalMetadata != metaString) { - return {}; - } - } - } - } - } - return finalMetadata; -} - -#if LLVM_VERSION_MAJOR > 16 -std::optional getMetadataName(llvm::Value *res) -#else -Optional getMetadataName(llvm::Value *res) -#endif -{ - if (auto S = simplifyLoad(res)) - return getMetadataName(S); - - if (auto av = dyn_cast(res)) { - return cast(av->getMetadata())->getString(); - } else if ((isa(res) || isa(res)) && - isa(cast(res)->getOperand(0))) { - GlobalVariable *gv = - cast(cast(res)->getOperand(0)); - return gv->getName(); - } else if (isa(res) && - isa(cast(res)->getOperand(0)) && - cast(cast(res)->getOperand(0))->isCast() && - isa( - cast(cast(res)->getOperand(0)) - 
->getOperand(0))) { - auto gv = cast( - cast(cast(res)->getOperand(0))->getOperand(0)); - return gv->getName(); - } else if (auto gv = dyn_cast(res)) { - return gv->getName(); - } else if (isa(res) && cast(res)->isCast() && - isa(cast(res)->getOperand(0))) { - auto gv = cast(cast(res)->getOperand(0)); - return gv->getName(); - } else if (isa(res) && cast(res) && - isa(cast(res)->getOperand(0))) { - auto gv = cast(cast(res)->getOperand(0)); - return gv->getName(); - } else if (auto gv = dyn_cast(res)) { - return gv->getName(); - } else if (isa(res)) { - return recursePhiReads(cast(res)); - } - - return {}; -} - -static Value *adaptReturnedVector(Value *ret, Value *diffret, - IRBuilder<> &Builder, unsigned width) { - Type *returnType = ret->getType(); - - if (StructType *sty = dyn_cast(returnType)) { - Value *agg = ConstantAggregateZero::get(sty); - - for (unsigned int i = 0; i < width; i++) { - Value *elem = Builder.CreateExtractValue(diffret, {i}); - if (auto vty = dyn_cast(elem->getType())) { - for (unsigned j = 0; j < vty->getNumElements(); ++j) { - Value *vecelem = Builder.CreateExtractElement(elem, j); - agg = Builder.CreateInsertValue(agg, vecelem, {i * j}); - } - } else { - agg = Builder.CreateInsertValue(agg, elem, {i}); - } - } - diffret = agg; - } - return diffret; -} - -static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret, - Type *retElemType, Value *diffret, - Instruction *CI, DerivativeMode mode) { - Type *retType = ret->getType(); - Type *diffretType = diffret->getType(); - auto &DL = CI->getModule()->getDataLayout(); - - if (diffretType->isEmptyTy() || diffretType->isVoidTy() || - retType->isEmptyTy() || retType->isVoidTy()) { - CI->replaceAllUsesWith(UndefValue::get(CI->getType())); - CI->eraseFromParent(); - return true; - } - - if (retType == diffretType) { - CI->replaceAllUsesWith(diffret); - CI->eraseFromParent(); - return true; - } - - if (auto sretType = dyn_cast(retType), - diffsretType = dyn_cast(diffretType); - sretType && diffsretType && sretType->isLayoutIdentical(diffsretType)) { - Value *newStruct = UndefValue::get(sretType); - for (unsigned int i = 0; i < sretType->getStructNumElements(); i++) { - Value *elem = Builder.CreateExtractValue(diffret, {i}); - newStruct = Builder.CreateInsertValue(newStruct, elem, {i}); - } - CI->replaceAllUsesWith(newStruct); - CI->eraseFromParent(); - return true; - } - - if (isa(retType)) { - retType = retElemType; - - if (auto sretType = dyn_cast(retType), - diffsretType = dyn_cast(diffretType); - sretType && diffsretType && sretType->isLayoutIdentical(diffsretType)) { - for (unsigned int i = 0; i < sretType->getStructNumElements(); i++) { - Value *sgep = Builder.CreateStructGEP(retType, ret, i); - Builder.CreateStore(Builder.CreateExtractValue(diffret, {i}), sgep); - } - CI->eraseFromParent(); - return true; - } - - if (DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) { - Builder.CreateStore( - diffret, - Builder.CreatePointerCast(ret, PointerType::getUnqual(diffretType))); - CI->eraseFromParent(); - return true; - } - } - - if ((mode == DerivativeMode::ReverseModePrimal && - DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) || - ((mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError) && - DL.getTypeSizeInBits(retType) == DL.getTypeSizeInBits(diffretType))) { - IRBuilder<> EB(CI->getFunction()->getEntryBlock().getFirstNonPHI()); - auto AL = EB.CreateAlloca(retType); - Builder.CreateStore(diffret, Builder.CreatePointerCast( - AL, 
PointerType::getUnqual(diffretType))); - Value *cload = Builder.CreateLoad(retType, AL); - CI->replaceAllUsesWith(cload); - CI->eraseFromParent(); - return true; - } - - if (mode != DerivativeMode::ReverseModePrimal && - diffret->getType()->isAggregateType()) { - auto diffreti = Builder.CreateExtractValue(diffret, {0}); - if (diffreti->getType() == retType) { - CI->replaceAllUsesWith(diffreti); - CI->eraseFromParent(); - return true; - } else if (diffretType == retType) { - CI->replaceAllUsesWith(diffret); - CI->eraseFromParent(); - return true; - } - } - - auto diffretsize = DL.getTypeSizeInBits(diffretType); - auto retsize = DL.getTypeSizeInBits(retType); - EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI, - "Cannot cast return type of gradient ", *diffretType, *diffret, - " of size ", diffretsize, " bits ", ", to desired type ", - *retType, " of size ", retsize, " bits"); - return false; -} - -class EnzymeBase { -public: - EnzymeLogic Logic; - EnzymeBase(bool PostOpt) - : Logic(EnzymePostOpt.getNumOccurrences() ? EnzymePostOpt : PostOpt) { - // initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry()); - } - - Function *parseFunctionParameter(CallInst *CI) { - Value *fn = CI->getArgOperand(0); - - // determine function to differentiate - if (CI->hasStructRetAttr()) { - fn = CI->getArgOperand(1); - } - - Value *ofn = fn; - fn = GetFunctionFromValue(fn); - - if (!fn || !isa(fn)) { - assert(ofn); - EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI, - "failed to find fn to differentiate", *CI, " - found - ", - *ofn); - return nullptr; - } - if (cast(fn)->empty()) { - EmitFailure("EmptyFunctionToDifferentiate", CI->getDebugLoc(), CI, - "failed to find fn to differentiate", *CI, " - found - ", - *fn); - return nullptr; - } - - return cast(fn); - } - -#if LLVM_VERSION_MAJOR > 16 - static std::optional parseWidthParameter(CallInst *CI) -#else - static Optional parseWidthParameter(CallInst *CI) -#endif - { - unsigned width = 1; - - for (auto [i, found] = std::tuple{0u, false}; i < CI->arg_size(); ++i) { - Value *arg = CI->getArgOperand(i); - - if (auto MDName = getMetadataName(arg)) { - if (*MDName == "enzyme_width") { - if (found) { - EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, - "vector width declared more than once", - *CI->getArgOperand(i), " in", *CI); - return {}; - } - - if (i + 1 >= CI->arg_size()) { - EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI, - "constant integer followong enzyme_width is missing", - *CI->getArgOperand(i), " in", *CI); - return {}; - } - - Value *width_arg = CI->getArgOperand(i + 1); - if (auto cint = dyn_cast(width_arg)) { - width = cint->getZExtValue(); - found = true; - } else { - EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, - "enzyme_width must be a constant integer", - *CI->getArgOperand(i), " in", *CI); - return {}; - } - - if (!found) { - EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, - "illegal enzyme vector argument width ", - *CI->getArgOperand(i), " in", *CI); - return {}; - } - } - } - } - return width; - } - - struct Options { - Value *differet; - Value *tape; - Value *dynamic_interface; - Value *trace; - Value *observations; - Value *likelihood; - Value *diffeLikelihood; - unsigned width; - int allocatedTapeSize; - bool freeMemory; - bool returnUsed; - bool tapeIsPointer; - bool differentialReturn; - bool diffeTrace; - DIFFE_TYPE retType; - bool primalReturn; - StringSet<> ActiveRandomVariables; - std::vector overwritten_args; - bool runtimeActivity; - bool 
subsequent_calls_may_write; - }; - -#if LLVM_VERSION_MAJOR > 16 - static std::optional - handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn, - DerivativeMode mode, bool sizeOnly, - std::vector &constants, - SmallVectorImpl &args, std::map &byVal) -#else - static Optional - handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn, - DerivativeMode mode, bool sizeOnly, - std::vector &constants, - SmallVectorImpl &args, std::map &byVal) -#endif - { - FunctionType *FT = fn->getFunctionType(); - - Value *differet = nullptr; - Value *tape = nullptr; - Value *dynamic_interface = nullptr; - Value *trace = nullptr; - Value *observations = nullptr; - Value *likelihood = nullptr; - Value *diffeLikelihood = nullptr; - unsigned width = 1; - int allocatedTapeSize = -1; - bool freeMemory = true; - bool tapeIsPointer = false; - bool diffeTrace = false; - unsigned truei = 0; - unsigned byRefSize = 0; - bool primalReturn = false; - bool runtimeActivity = false; - bool subsequent_calls_may_write = - mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError && - mode != DerivativeMode::ReverseModeCombined; - StringSet<> ActiveRandomVariables; - - DIFFE_TYPE retType = whatType(fn->getReturnType(), mode); - - if (fn->hasParamAttribute(0, Attribute::StructRet)) { - Type *Ty = nullptr; - Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType(); - if (whatType(Ty, mode) != DIFFE_TYPE::CONSTANT) { - retType = DIFFE_TYPE::DUP_ARG; - } - } - - bool returnUsed = - !fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy(); - - bool sret = CI->hasStructRetAttr() || - fn->hasParamAttribute(0, Attribute::StructRet); - - std::vector overwritten_args( - fn->getFunctionType()->getNumParams(), - !(mode == DerivativeMode::ReverseModeCombined)); - - for (unsigned i = 1 + sret; i < CI->arg_size(); ++i) { - Value *res = CI->getArgOperand(i); - auto metaString = getMetadataName(res); - // handle metadata - if (metaString && startsWith(*metaString, "enzyme_")) { - if (*metaString == "enzyme_const_return") { - retType = DIFFE_TYPE::CONSTANT; - continue; - } else if (*metaString == "enzyme_active_return") { - retType = DIFFE_TYPE::OUT_DIFF; - continue; - } else if (*metaString == "enzyme_dup_return") { - retType = DIFFE_TYPE::DUP_ARG; - continue; - } else if (*metaString == "enzyme_noret") { - returnUsed = false; - continue; - } else if (*metaString == "enzyme_primal_return") { - primalReturn = true; - continue; - } - } - } - bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ReverseModeGradient) && - (retType == DIFFE_TYPE::OUT_DIFF); - - // find and handle enzyme_width - if (auto parsedWidth = parseWidthParameter(CI)) { - width = *parsedWidth; - } else { - return {}; - } - - // handle different argument order for struct return. - if (fn->hasParamAttribute(0, Attribute::StructRet)) { - truei = 1; - - const DataLayout &DL = CI->getParent()->getModule()->getDataLayout(); - Type *Ty = nullptr; - Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType(); - Type *CTy = nullptr; - CTy = CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet) - .getValueAsType(); - auto FnSize = (DL.getTypeSizeInBits(Ty) / 8); - auto CSize = CTy ? 
(DL.getTypeSizeInBits(CTy) / 8) : 0; - auto count = ((mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError) && - (retType == DIFFE_TYPE::DUP_ARG || - retType == DIFFE_TYPE::DUP_NONEED)) * - width + - primalReturn; - if (CSize < count * FnSize) { - EmitFailure( - "IllegalByRefSize", CI->getDebugLoc(), CI, "Struct return type ", - *CTy, " (", CSize, " bytes), not large enough to store ", count, - " returns of type ", *Ty, " (", FnSize, " bytes), width=", width, - " primal requested=", primalReturn); - } - Value *primal = nullptr; - if (primalReturn) { - Value *sretPt = CI->getArgOperand(0); - PointerType *pty = cast(sretPt->getType()); - primal = Builder.CreatePointerCast( - sretPt, PointerType::get(Ty, pty->getAddressSpace())); - } else { - AllocaInst *primalA = new AllocaInst(Ty, DL.getAllocaAddrSpace(), - nullptr, DL.getPrefTypeAlign(Ty)); - primalA->insertBefore(CI); - primal = primalA; - } - - Value *shadow = nullptr; - switch (mode) { - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: { - if (retType != DIFFE_TYPE::CONSTANT) { - Value *sretPt = CI->getArgOperand(0); - PointerType *pty = cast(sretPt->getType()); - auto shadowPtr = Builder.CreatePointerCast( - sretPt, PointerType::get(Ty, pty->getAddressSpace())); - if (width == 1) { - if (primalReturn) - shadowPtr = Builder.CreateConstGEP1_64(Ty, shadowPtr, 1); - shadow = shadowPtr; - } else { - Value *acc = UndefValue::get(ArrayType::get( - PointerType::get(Ty, pty->getAddressSpace()), width)); - for (size_t i = 0; i < width; ++i) { - Value *elem = - Builder.CreateConstGEP1_64(Ty, shadowPtr, i + primalReturn); - acc = Builder.CreateInsertValue(acc, elem, i); - } - shadow = acc; - } - } - break; - } - case DerivativeMode::ReverseModePrimal: - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: { - if (retType != DIFFE_TYPE::CONSTANT) - shadow = CI->getArgOperand(1); - sret = true; - break; - } - } - - args.push_back(primal); - if (retType != DIFFE_TYPE::CONSTANT) - args.push_back(shadow); - if (retType == DIFFE_TYPE::DUP_ARG && !primalReturn && isWriteOnly(fn, 0)) - retType = DIFFE_TYPE::DUP_NONEED; - constants.push_back(retType); - retType = DIFFE_TYPE::CONSTANT; - primalReturn = false; - } - - ssize_t interleaved = -1; - - size_t maxsize; - maxsize = CI->arg_size(); - size_t num_args = maxsize; - for (unsigned i = 1 + sret; i < maxsize; ++i) { - Value *res = CI->getArgOperand(i); - auto metaString = getMetadataName(res); - if (metaString && startsWith(*metaString, "enzyme_")) { - if (*metaString == "enzyme_interleave") { - maxsize = i; - interleaved = i + 1; - break; - } - } - } - - DIFFE_TYPE last_ty = DIFFE_TYPE::DUP_ARG; - - for (ssize_t i = 1 + sret; (size_t)i < maxsize; ++i) { - Value *res = CI->getArgOperand(i); - auto metaString = getMetadataName(res); -#if LLVM_VERSION_MAJOR > 16 - std::optional batchOffset; - std::optional opt_ty; -#else - Optional batchOffset; - Optional opt_ty; -#endif - - bool overwritten = !(mode == DerivativeMode::ReverseModeCombined); - - bool skipArg = false; - - // handle metadata - while (metaString && startsWith(*metaString, "enzyme_")) { - if (*metaString == "enzyme_not_overwritten") { - overwritten = false; - } else if (*metaString == "enzyme_byref") { - ++i; - if (!isa(CI->getArgOperand(i))) { - EmitFailure("IllegalAllocatedSize", CI->getDebugLoc(), CI, - "illegal enzyme byref size ", *CI->getArgOperand(i), - "in", *CI); - 
return {}; - } - byRefSize = cast(CI->getArgOperand(i))->getZExtValue(); - assert(byRefSize > 0); - skipArg = true; - break; - } else if (*metaString == "enzyme_dup") { - opt_ty = DIFFE_TYPE::DUP_ARG; - } else if (*metaString == "enzyme_dupv") { - opt_ty = DIFFE_TYPE::DUP_ARG; - ++i; - Value *offset_arg = CI->getArgOperand(i); - if (offset_arg->getType()->isIntegerTy()) { - batchOffset = offset_arg; - } else { - EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI, - "enzyme_batch must be followd by an integer " - "offset.", - *CI->getArgOperand(i), " in", *CI); - return {}; - } - } else if (*metaString == "enzyme_dupnoneed") { - opt_ty = DIFFE_TYPE::DUP_NONEED; - } else if (*metaString == "enzyme_dupnoneedv") { - opt_ty = DIFFE_TYPE::DUP_NONEED; - ++i; - Value *offset_arg = CI->getArgOperand(i); - if (offset_arg->getType()->isIntegerTy()) { - batchOffset = offset_arg; - } else { - EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI, - "enzyme_batch must be followd by an integer " - "offset.", - *CI->getArgOperand(i), " in", *CI); - return {}; - } - } else if (*metaString == "enzyme_out") { - opt_ty = DIFFE_TYPE::OUT_DIFF; - } else if (*metaString == "enzyme_const") { - opt_ty = DIFFE_TYPE::CONSTANT; - } else if (*metaString == "enzyme_noret") { - skipArg = true; - break; - } else if (*metaString == "enzyme_allocated") { - assert(!sizeOnly); - ++i; - if (!isa(CI->getArgOperand(i))) { - EmitFailure("IllegalAllocatedSize", CI->getDebugLoc(), CI, - "illegal enzyme allocated size ", *CI->getArgOperand(i), - "in", *CI); - return {}; - } - allocatedTapeSize = - cast(CI->getArgOperand(i))->getZExtValue(); - skipArg = true; - break; - } else if (*metaString == "enzyme_tape") { - assert(!sizeOnly); - ++i; - tape = CI->getArgOperand(i); - tapeIsPointer = true; - skipArg = true; - break; - } else if (*metaString == "enzyme_nofree") { - assert(!sizeOnly); - freeMemory = false; - skipArg = true; - break; - } else if (*metaString == "enzyme_runtime_activity") { - runtimeActivity = true; - skipArg = true; - break; - } else if (*metaString == "enzyme_primal_return") { - skipArg = true; - break; - } else if (*metaString == "enzyme_const_return") { - skipArg = true; - break; - } else if (*metaString == "enzyme_active_return") { - skipArg = true; - break; - } else if (*metaString == "enzyme_dup_return") { - skipArg = true; - break; - } else if (*metaString == "enzyme_width") { - ++i; - skipArg = true; - break; - } else if (*metaString == "enzyme_interface") { - ++i; - dynamic_interface = CI->getArgOperand(i); - skipArg = true; - break; - } else if (*metaString == "enzyme_trace") { - trace = CI->getArgOperand(++i); - opt_ty = DIFFE_TYPE::CONSTANT; - skipArg = true; - break; - } else if (*metaString == "enzyme_duptrace") { - trace = CI->getArgOperand(++i); - diffeTrace = true; - opt_ty = DIFFE_TYPE::CONSTANT; - skipArg = true; - break; - } else if (*metaString == "enzyme_likelihood") { - likelihood = CI->getArgOperand(++i); - opt_ty = DIFFE_TYPE::CONSTANT; - skipArg = true; - break; - } else if (*metaString == "enzyme_duplikelihood") { - likelihood = CI->getArgOperand(++i); - diffeLikelihood = CI->getArgOperand(++i); - opt_ty = DIFFE_TYPE::DUP_ARG; - skipArg = true; - break; - } else if (*metaString == "enzyme_observations") { - observations = CI->getArgOperand(++i); - opt_ty = DIFFE_TYPE::CONSTANT; - skipArg = true; - break; - } else if (*metaString == "enzyme_active_rand_var") { - Value *string = CI->getArgOperand(++i); - StringRef const_string; - if (getConstantStringInfo(string, const_string)) { - 
ActiveRandomVariables.insert(const_string); - } else { - EmitFailure( - "IllegalStringType", CI->getDebugLoc(), CI, - "active variable address must be a compile-time constant", *CI, - *metaString); - } - skipArg = true; - break; - } else { - EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI, - "illegal enzyme metadata classification ", *CI, - *metaString); - return {}; - } - if (sizeOnly) { - assert(opt_ty); - constants.push_back(*opt_ty); - truei++; - skipArg = true; - break; - } - ++i; - if (i == CI->arg_size()) { - EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI, - "Too few arguments to Enzyme call ", *CI); - return {}; - } - res = CI->getArgOperand(i); - metaString = getMetadataName(res); - } - - if (skipArg) - continue; - - if (byRefSize) { - Type *subTy = nullptr; - if (truei < FT->getNumParams()) { - subTy = FT->getParamType(i); - } else if ((mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ForwardModeSplit)) { - if (differentialReturn && differet == nullptr) { - subTy = FT->getReturnType(); - } - } - - if (!subTy) { - EmitFailure("IllegalByVal", CI->getDebugLoc(), CI, - "illegal enzyme byval arg", truei, " ", *res); - return {}; - } - - auto &DL = fn->getParent()->getDataLayout(); - auto BitSize = DL.getTypeSizeInBits(subTy); - if (BitSize / 8 != byRefSize) { - EmitFailure("IllegalByRefSize", CI->getDebugLoc(), CI, - "illegal enzyme pointer type size ", *res, " expected ", - byRefSize, " (bytes) actual size ", BitSize, - " (bits) in ", *CI); - } - res = Builder.CreateBitCast( - res, - PointerType::get( - subTy, cast(res->getType())->getAddressSpace())); - res = Builder.CreateLoad(subTy, res); - byRefSize = 0; - } - - if (truei >= FT->getNumParams()) { - if (!isa(res) && - (mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ForwardModeSplit)) { - if (differentialReturn && differet == nullptr) { - differet = res; - if (CI->paramHasAttr(i, Attribute::ByVal)) { - Type *T = nullptr; - T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType(); - differet = Builder.CreateLoad(T, differet); - } - if (differet->getType() != fn->getReturnType()) - if (auto ST0 = dyn_cast(differet->getType())) - if (auto ST1 = dyn_cast(fn->getReturnType())) - if (ST0->isLayoutIdentical(ST1)) { - IRBuilder<> B(&Builder.GetInsertBlock() - ->getParent() - ->getEntryBlock() - .front()); - auto AI = B.CreateAlloca(ST1); - Builder.CreateStore(differet, - Builder.CreatePointerCast( - AI, PointerType::getUnqual(ST0))); - differet = Builder.CreateLoad(ST1, AI); - } - - if (differet->getType() != - GradientUtils::getShadowType(fn->getReturnType(), width)) { - EmitFailure("BadDiffRet", CI->getDebugLoc(), CI, - "Bad DiffRet type ", *differet, " expected ", - *fn->getReturnType()); - return {}; - } - continue; - } else if (tape == nullptr) { - tape = res; - if (CI->paramHasAttr(i, Attribute::ByVal)) { - Type *T = nullptr; - T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType(); - tape = Builder.CreateLoad(T, tape); - } - continue; - } - } - EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, - "Had too many arguments to __enzyme_autodiff", *CI, - " - extra arg - ", *res); - return {}; - } - assert(truei < FT->getNumParams()); - overwritten_args[truei] = overwritten; - - auto PTy = FT->getParamType(truei); - DIFFE_TYPE ty = - opt_ty ? *opt_ty - : ((interleaved == -1) ? 
whatType(PTy, mode) : last_ty); - last_ty = ty; - - constants.push_back(ty); - - assert(truei < FT->getNumParams()); - // cast primal - if (PTy != res->getType()) { - if (auto ptr = dyn_cast(res->getType())) { - if (auto PT = dyn_cast(PTy)) { - if (ptr->getAddressSpace() != PT->getAddressSpace()) { -#if LLVM_VERSION_MAJOR < 17 - if (CI->getContext().supportsTypedPointers()) { - res = Builder.CreateAddrSpaceCast( - res, PointerType::get(ptr->getPointerElementType(), - PT->getAddressSpace())); - } else { - res = Builder.CreateAddrSpaceCast(res, PT); - } -#else - res = Builder.CreateAddrSpaceCast(res, PT); -#endif - assert(res); - assert(PTy); - assert(FT); - llvm::errs() << "Warning cast(1) __enzyme_autodiff argument " << i - << " " << *res << "|" << *res->getType() - << " to argument " << truei << " " << *PTy << "\n" - << "orig: " << *FT << "\n"; - } - } - } - if (res->getType()->canLosslesslyBitCastTo(PTy)) { - res = Builder.CreateBitCast(res, PTy); - } - if (res->getType() != PTy && res->getType()->isIntegerTy() && - PTy->isIntegerTy(1)) { - res = Builder.CreateTrunc(res, PTy); - } - if (res->getType() != PTy) { - auto loc = CI->getDebugLoc(); - if (auto arg = dyn_cast(res)) { - loc = arg->getDebugLoc(); - } - auto S = simplifyLoad(res); - if (!S) - S = res; - EmitFailure("IllegalArgCast", loc, CI, - "Cannot cast __enzyme_autodiff primal argument ", i, - ", found ", *res, ", type ", *res->getType(), - " (simplified to ", *S, " ) ", " - to arg ", truei, ", ", - *PTy); - return {}; - } - } - if (CI->isByValArgument(i)) { - byVal[args.size()] = CI->getParamByValType(i); - } - - args.push_back(res); - if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) { - if (interleaved == -1) - ++i; - - Value *res = nullptr; -#if LLVM_VERSION_MAJOR >= 16 - bool batch = batchOffset.has_value(); -#else - bool batch = batchOffset.hasValue(); -#endif - - for (unsigned v = 0; v < width; ++v) { - if ((size_t)((interleaved == -1) ? i : interleaved) >= num_args) { - EmitFailure("MissingArgShadow", CI->getDebugLoc(), CI, - "__enzyme_autodiff missing argument shadow at index ", - *((interleaved == -1) ? &i : &interleaved), - ", need shadow of type ", *PTy, - " to shadow primal argument ", *args.back(), - " at call ", *CI); - return {}; - } - - // cast diffe - Value *element = - CI->getArgOperand((interleaved == -1) ? i : interleaved); - if (batch) { - if (auto elementPtrTy = dyn_cast(element->getType())) { - element = Builder.CreateBitCast( - element, PointerType::get(Type::getInt8Ty(CI->getContext()), - elementPtrTy->getAddressSpace())); - element = Builder.CreateGEP( - Type::getInt8Ty(CI->getContext()), element, - Builder.CreateMul( - *batchOffset, - ConstantInt::get((*batchOffset)->getType(), v))); - element = Builder.CreateBitCast(element, elementPtrTy); - } else { - EmitFailure( - "NonPointerBatch", CI->getDebugLoc(), CI, - "Batched argument at index ", - *((interleaved == -1) ? &i : &interleaved), - " must be of pointer type, found: ", *element->getType()); - return {}; - } - } - if (PTy != element->getType()) { - element = castToDiffeFunctionArgType( - Builder, CI, FT, PTy, (interleaved == -1) ? i : interleaved, - mode, element, truei); - if (!element) { - return {}; - } - } - - if (width > 1) { - res = - res ? 
Builder.CreateInsertValue(res, element, {v}) - : Builder.CreateInsertValue(UndefValue::get(ArrayType::get( - element->getType(), width)), - element, {v}); - - if (v < width - 1 && !batch && (interleaved == -1)) { - ++i; - } - - } else { - res = element; - } - - if (interleaved != -1) - interleaved++; - } - - args.push_back(res); - } - - ++truei; - } - if (truei < FT->getNumParams()) { - auto numParams = FT->getNumParams(); - EmitFailure( - "EnzymeInsufficientArgs", CI->getDebugLoc(), CI, - "Insufficient number of args passed to derivative call required ", - numParams, " primal args, found ", truei); - return {}; - } - - return Options({differet, - tape, - dynamic_interface, - trace, - observations, - likelihood, - diffeLikelihood, - width, - allocatedTapeSize, - freeMemory, - returnUsed, - tapeIsPointer, - differentialReturn, - diffeTrace, - retType, - primalReturn, - ActiveRandomVariables, - overwritten_args, - runtimeActivity, - subsequent_calls_may_write}); - } - - static FnTypeInfo populate_type_args(TypeAnalysis &TA, llvm::Function *fn, - DerivativeMode mode) { - FnTypeInfo type_args(fn); - for (auto &a : type_args.Function->args()) { - TypeTree dt; - if (a.getType()->isFPOrFPVectorTy()) { - dt = ConcreteType(a.getType()->getScalarType()); - } else if (a.getType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR < 17 - if (a.getContext().supportsTypedPointers()) { - auto et = a.getType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); - } - } -#endif - dt.insert({}, BaseType::Pointer); - } else if (a.getType()->isIntOrIntVectorTy()) { - dt = ConcreteType(BaseType::Integer); - } - type_args.Arguments.insert( - std::pair(&a, dt.Only(-1, nullptr))); - // TODO note that here we do NOT propagate constants in type info (and - // should consider whether we should) - type_args.KnownValues.insert( - std::pair>(&a, {})); - } - TypeTree dt; - if (fn->getReturnType()->isFPOrFPVectorTy()) { - dt = ConcreteType(fn->getReturnType()->getScalarType()); - } - type_args.Return = dt.Only(-1, nullptr); - - type_args = TA.analyzeFunction(type_args).getAnalyzedTypeInfo(); - return type_args; - } - - static FloatRepresentation getDefaultFloatRepr(unsigned width) { - switch (width) { - case 16: - return FloatRepresentation(5, 10); - case 32: - return FloatRepresentation(8, 23); - case 64: - return FloatRepresentation(11, 52); - default: - llvm_unreachable("Invalid float width"); - } - }; - - bool HandleTruncateFunc(CallInst *CI, TruncateMode mode) { - IRBuilder<> Builder(CI); - Function *F = parseFunctionParameter(CI); - if (!F) - return false; - unsigned ArgSize = CI->arg_size(); - if (ArgSize != 4 && ArgSize != 3) { - EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, - "Had incorrect number of args to __enzyme_truncate_func", *CI, - " - expected 3 or 4"); - return false; - } - FloatTruncation truncation = [&]() -> FloatTruncation { - if (ArgSize == 3) { - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto = cast(CI->getArgOperand(2)); - assert(Cto); - return FloatTruncation( - getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), - mode); - } else if (ArgSize == 4) { - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto_exponent = cast(CI->getArgOperand(2)); - assert(Cto_exponent); - auto Cto_significand = 
cast(CI->getArgOperand(3)); - assert(Cto_significand); - return FloatTruncation( - getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - FloatRepresentation( - (unsigned)Cto_exponent->getValue().getZExtValue(), - (unsigned)Cto_significand->getValue().getZExtValue()), - mode); - } - llvm_unreachable("??"); - }(); - - RequestContext context(CI, &Builder); - llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode); - if (!res) - return false; - res = Builder.CreatePointerCast(res, CI->getType()); - CI->replaceAllUsesWith(res); - CI->eraseFromParent(); - return true; - } - - bool HandleTruncateValue(CallInst *CI, bool isTruncate) { - IRBuilder<> Builder(CI); - if (CI->arg_size() != 3) { - EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, - "Had incorrect number of args to __enzyme_truncate_value", - *CI, " - expected 3"); - return false; - } - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto = cast(CI->getArgOperand(2)); - assert(Cto); - auto Addr = CI->getArgOperand(0); - RequestContext context(CI, &Builder); - bool res = Logic.CreateTruncateValue( - context, Addr, - getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), - isTruncate); - if (!res) - return false; - return true; - } - - bool HandleBatch(CallInst *CI) { - unsigned width = 1; - unsigned truei = 0; - std::map batchOffset; - SmallVector args; - SmallVector arg_types; - IRBuilder<> Builder(CI); - Function *F = parseFunctionParameter(CI); - if (!F) - return false; - - assert(F); - FunctionType *FT = F->getFunctionType(); - - // find and handle enzyme_width - if (auto parsedWidth = parseWidthParameter(CI)) { - width = *parsedWidth; - } else { - return false; - } - - // handle different argument order for struct return. - bool sret = - CI->hasStructRetAttr() || F->hasParamAttribute(0, Attribute::StructRet); - - if (F->hasParamAttribute(0, Attribute::StructRet)) { - truei = 1; - Value *sretPt = CI->getArgOperand(0); - - args.push_back(sretPt); - arg_types.push_back(BATCH_TYPE::VECTOR); - } - - for (unsigned i = 1 + sret; i < CI->arg_size(); ++i) { - Value *res = CI->getArgOperand(i); - - if (truei >= FT->getNumParams()) { - EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, - "Had too many arguments to __enzyme_batch", *CI, - " - extra arg - ", *res); - return false; - } - assert(truei < FT->getNumParams()); - auto PTy = FT->getParamType(truei); - - BATCH_TYPE ty = width == 1 ? 
BATCH_TYPE::SCALAR : BATCH_TYPE::VECTOR; - auto metaString = getMetadataName(res); - - // handle metadata - if (metaString && startsWith(*metaString, "enzyme_")) { - if (*metaString == "enzyme_scalar") { - ty = BATCH_TYPE::SCALAR; - } else if (*metaString == "enzyme_vector") { - ty = BATCH_TYPE::VECTOR; - } else if (*metaString == "enzyme_buffer") { - ty = BATCH_TYPE::VECTOR; - ++i; - Value *offset_arg = CI->getArgOperand(i); - if (offset_arg->getType()->isIntegerTy()) { - batchOffset[i + 1] = offset_arg; - } else { - EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI, - "enzyme_batch must be followd by an integer " - "offset.", - *CI->getArgOperand(i), " in", *CI); - return false; - } - continue; - } else if (*metaString == "enzyme_width") { - ++i; - continue; - } else { - EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI, - "illegal enzyme metadata classification ", *CI, - *metaString); - return false; - } - ++i; - res = CI->getArgOperand(i); - } - - arg_types.push_back(ty); - - // wrap vector - if (ty == BATCH_TYPE::VECTOR) { - Value *res = nullptr; - bool batch = batchOffset.count(i - 1) != 0; - - for (unsigned v = 0; v < width; ++v) { - if (i >= CI->arg_size()) { - EmitFailure("MissingVectorArg", CI->getDebugLoc(), CI, - "__enzyme_batch missing vector argument at index ", i, - ", need argument of type ", *PTy, " at call ", *CI); - return false; - } - - // vectorize pointer - Value *element = CI->getArgOperand(i); - if (batch) { - if (auto elementPtrTy = dyn_cast(element->getType())) { - element = Builder.CreateBitCast( - element, PointerType::get(Type::getInt8Ty(CI->getContext()), - elementPtrTy->getAddressSpace())); - element = Builder.CreateGEP( - Type::getInt8Ty(CI->getContext()), element, - Builder.CreateMul( - batchOffset[i - 1], - ConstantInt::get(batchOffset[i - 1]->getType(), v))); - element = Builder.CreateBitCast(element, elementPtrTy); - } else { - return false; - } - } - - if (width > 1) { - res = - res ? Builder.CreateInsertValue(res, element, {v}) - : Builder.CreateInsertValue(UndefValue::get(ArrayType::get( - element->getType(), width)), - element, {v}); - - if (v < width - 1 && !batch) { - ++i; - } - - } else { - res = element; - } - } - - args.push_back(res); - - } else if (ty == BATCH_TYPE::SCALAR) { - args.push_back(res); - } - - truei++; - } - - BATCH_TYPE ret_type = (F->getReturnType()->isVoidTy() || width == 1) - ? 
BATCH_TYPE::SCALAR - : BATCH_TYPE::VECTOR; - - auto newFunc = Logic.CreateBatch(RequestContext(CI, &Builder), F, width, - arg_types, ret_type); - - if (!newFunc) - return false; - - Value *batch = - Builder.CreateCall(newFunc->getFunctionType(), newFunc, args); - - batch = adaptReturnedVector(CI, batch, Builder, width); - - Value *ret = CI; - Type *retElemType = nullptr; - if (CI->hasStructRetAttr()) { - ret = CI->getArgOperand(0); - retElemType = - CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet) - .getValueAsType(); - } - ReplaceOriginalCall(Builder, ret, retElemType, batch, CI, - DerivativeMode::ForwardMode); - - return true; - } - - bool HandleAutoDiff(Instruction *CI, CallingConv::ID CallingConv, Value *ret, - Type *retElemType, SmallVectorImpl &args, - const std::map &byVal, - const std::vector &constants, Function *fn, - DerivativeMode mode, Options &options, bool sizeOnly, - SmallVectorImpl &calls) { - auto &differet = options.differet; - auto &tape = options.tape; - auto &width = options.width; - auto &allocatedTapeSize = options.allocatedTapeSize; - auto &freeMemory = options.freeMemory; - auto &returnUsed = options.returnUsed; - auto &tapeIsPointer = options.tapeIsPointer; - auto &differentialReturn = options.differentialReturn; - auto &retType = options.retType; - auto &overwritten_args = options.overwritten_args; - auto primalReturn = options.primalReturn; - auto subsequent_calls_may_write = options.subsequent_calls_may_write; - - auto Arch = Triple(CI->getModule()->getTargetTriple()).getArch(); - bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 || - Arch == Triple::amdgcn; - - TypeAnalysis TA(Logic.PPC.FAM); - FnTypeInfo type_args = populate_type_args(TA, fn, mode); - - IRBuilder Builder(CI); - RequestContext context(CI, &Builder); - - // differentiate fn - Function *newFunc = nullptr; - Type *tapeType = nullptr; - const AugmentedReturn *aug; - switch (mode) { - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardMode: - if (primalReturn && fn->getReturnType()->isVoidTy()) { - auto fnname = fn->getName(); - EmitFailure("PrimalRetOfVoid", CI->getDebugLoc(), CI, - "Requested primal result of void-returning function type ", - *fn->getFunctionType(), " ", fnname, " ", *CI); - } else - newFunc = Logic.CreateForwardDiff( - context, fn, retType, constants, TA, - /*should return*/ primalReturn, mode, freeMemory, - options.runtimeActivity, width, - /*addedType*/ nullptr, type_args, subsequent_calls_may_write, - overwritten_args, - /*augmented*/ nullptr); - break; - case DerivativeMode::ForwardModeSplit: { - bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1; - aug = &Logic.CreateAugmentedPrimal( - context, fn, retType, constants, TA, - /*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args, - subsequent_calls_may_write, overwritten_args, forceAnonymousTape, - options.runtimeActivity, width, - /*atomicAdd*/ AtomicAdd); - auto &DL = fn->getParent()->getDataLayout(); - if (!forceAnonymousTape) { - assert(!aug->tapeType); - if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) { - auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second; - tapeType = (tapeIdx == -1) - ? 
aug->fn->getReturnType() - : cast(aug->fn->getReturnType()) - ->getElementType(tapeIdx); - } else { - if (sizeOnly) { - CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false)); - CI->eraseFromParent(); - return true; - } - } - if (sizeOnly) { - auto size = DL.getTypeSizeInBits(tapeType) / 8; - CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false)); - CI->eraseFromParent(); - return true; - } - if (tapeType && - DL.getTypeSizeInBits(tapeType) > 8 * (size_t)allocatedTapeSize) { - auto bytes = DL.getTypeSizeInBits(tapeType) / 8; - EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(), - CI, "need ", bytes, " bytes have ", allocatedTapeSize, - " bytes"); - } - } else { - tapeType = getInt8PtrTy(fn->getContext()); - } - newFunc = Logic.CreateForwardDiff( - context, fn, retType, constants, TA, - /*should return*/ primalReturn, mode, freeMemory, - options.runtimeActivity, width, - /*addedType*/ tapeType, type_args, subsequent_calls_may_write, - overwritten_args, aug); - break; - } - case DerivativeMode::ReverseModeCombined: - assert(freeMemory); - newFunc = Logic.CreatePrimalAndGradient( - context, - (ReverseCacheKey){.todiff = fn, - .retType = retType, - .constant_args = constants, - .subsequent_calls_may_write = - subsequent_calls_may_write, - .overwritten_args = overwritten_args, - .returnUsed = primalReturn, - .shadowReturnUsed = false, - .mode = mode, - .width = width, - .freeMemory = freeMemory, - .AtomicAdd = AtomicAdd, - .additionalType = nullptr, - .forceAnonymousTape = false, - .typeInfo = type_args, - .runtimeActivity = options.runtimeActivity}, - TA, /*augmented*/ nullptr); - break; - case DerivativeMode::ReverseModePrimal: - case DerivativeMode::ReverseModeGradient: { - if (primalReturn) { - EmitFailure( - "SplitPrimalRet", CI->getDebugLoc(), CI, - "Option enzyme_primal_return not available in reverse split mode"); - } - bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1; - bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || - retType == DIFFE_TYPE::DUP_NONEED); - aug = &Logic.CreateAugmentedPrimal( - context, fn, retType, constants, TA, returnUsed, shadowReturnUsed, - type_args, subsequent_calls_may_write, overwritten_args, - forceAnonymousTape, options.runtimeActivity, width, - /*atomicAdd*/ AtomicAdd); - auto &DL = fn->getParent()->getDataLayout(); - if (!forceAnonymousTape) { - assert(!aug->tapeType); - if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) { - auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second; - tapeType = (tapeIdx == -1) - ? 
aug->fn->getReturnType() - : cast(aug->fn->getReturnType()) - ->getElementType(tapeIdx); - } else { - if (sizeOnly) { - CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false)); - CI->eraseFromParent(); - return true; - } - } - if (sizeOnly) { - auto size = DL.getTypeSizeInBits(tapeType) / 8; - CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false)); - CI->eraseFromParent(); - return true; - } - if (tapeType && - DL.getTypeSizeInBits(tapeType) > 8 * (size_t)allocatedTapeSize) { - auto bytes = DL.getTypeSizeInBits(tapeType) / 8; - EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(), - CI, "need ", bytes, " bytes have ", allocatedTapeSize, - " bytes"); - } - } else { - tapeType = getInt8PtrTy(fn->getContext()); - } - if (mode == DerivativeMode::ReverseModePrimal) - newFunc = aug->fn; - else - newFunc = Logic.CreatePrimalAndGradient( - context, - (ReverseCacheKey){.todiff = fn, - .retType = retType, - .constant_args = constants, - .subsequent_calls_may_write = - subsequent_calls_may_write, - .overwritten_args = overwritten_args, - .returnUsed = false, - .shadowReturnUsed = false, - .mode = mode, - .width = width, - .freeMemory = freeMemory, - .AtomicAdd = AtomicAdd, - .additionalType = tapeType, - .forceAnonymousTape = forceAnonymousTape, - .typeInfo = type_args, - .runtimeActivity = options.runtimeActivity}, - TA, aug); - } - } - - if (!newFunc) { - StringRef n = fn->getName(); - EmitFailure("FailedToDifferentiate", fn->getSubprogram(), - &*fn->getEntryBlock().begin(), - "Could not generate derivative function of ", n); - return false; - } - - if (differentialReturn) { - if (differet) - args.push_back(differet); - else if (fn->getReturnType()->isFPOrFPVectorTy()) { - Constant *seed = ConstantFP::get(fn->getReturnType(), 1.0); - if (width == 1) { - args.push_back(seed); - } else { - ArrayType *arrayType = ArrayType::get(fn->getReturnType(), width); - args.push_back(ConstantArray::get( - arrayType, SmallVector(width, seed))); - } - } else if (auto ST = dyn_cast(fn->getReturnType())) { - SmallVector csts; - for (auto e : ST->elements()) { - csts.push_back(ConstantFP::get(e, 1.0)); - } - args.push_back(ConstantStruct::get(ST, csts)); - } else if (auto AT = dyn_cast(fn->getReturnType())) { - SmallVector csts( - AT->getNumElements(), ConstantFP::get(AT->getElementType(), 1.0)); - args.push_back(ConstantArray::get(AT, csts)); - } else { - auto RT = fn->getReturnType(); - EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI, - "Differential return required for call ", *CI, - " but one of type ", *RT, " could not be auto deduced"); - return false; - } - } - - if ((mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ForwardModeSplit) && - tape && tapeType) { - auto &DL = fn->getParent()->getDataLayout(); - if (tapeIsPointer) { - tape = Builder.CreateBitCast( - tape, PointerType::get( - tapeType, - cast(tape->getType())->getAddressSpace())); - tape = Builder.CreateLoad(tapeType, tape); - } else if (tapeType != tape->getType() && - DL.getTypeSizeInBits(tapeType) <= - DL.getTypeSizeInBits(tape->getType())) { - IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front()); - auto AL = EB.CreateAlloca(tape->getType()); - Builder.CreateStore(tape, AL); - tape = Builder.CreateLoad( - tapeType, - Builder.CreatePointerCast(AL, PointerType::getUnqual(tapeType))); - } - assert(tape->getType() == tapeType); - args.push_back(tape); - } - - if (EnzymePrint) { - llvm::errs() << "postfn:\n" << *newFunc << "\n"; - } - 
Builder.setFastMathFlags(getFast()); - - // call newFunc with the provided arguments. - if (args.size() != newFunc->getFunctionType()->getNumParams()) { - llvm::errs() << *CI << "\n"; - llvm::errs() << *newFunc << "\n"; - for (auto arg : args) { - llvm::errs() << " + " << *arg << "\n"; - } - auto modestr = to_string(mode); - EmitFailure( - "TooFewArguments", CI->getDebugLoc(), CI, - "Too few arguments passed to __enzyme_autodiff mode=", modestr); - return false; - } - assert(args.size() == newFunc->getFunctionType()->getNumParams()); - for (size_t i = 0; i < args.size(); i++) { - if (args[i]->getType() != newFunc->getFunctionType()->getParamType(i)) { - llvm::errs() << *CI << "\n"; - llvm::errs() << *newFunc << "\n"; - for (auto arg : args) { - llvm::errs() << " + " << *arg << "\n"; - } - auto modestr = to_string(mode); - EmitFailure("BadArgumentType", CI->getDebugLoc(), CI, - "Incorrect argument type passed to __enzyme_autodiff mode=", - modestr, " at index ", i, " expected ", - *newFunc->getFunctionType()->getParamType(i), " found ", - *args[i]->getType()); - return false; - } - } - CallInst *diffretc = cast(Builder.CreateCall(newFunc, args)); - diffretc->setCallingConv(CallingConv); - diffretc->setDebugLoc(CI->getDebugLoc()); - - for (auto &&[attr, ty] : byVal) { - diffretc->addParamAttr( - attr, Attribute::getWithByValType(diffretc->getContext(), ty)); - } - - Value *diffret = diffretc; - if (mode == DerivativeMode::ReverseModePrimal && tape) { - if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) { - auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second; - tapeType = (tapeIdx == -1) ? aug->fn->getReturnType() - : cast(aug->fn->getReturnType()) - ->getElementType(tapeIdx); - unsigned idxs[] = {(unsigned)tapeIdx}; - Value *tapeRes = (tapeIdx == -1) - ? diffret - : Builder.CreateExtractValue(diffret, idxs); - Builder.CreateStore( - tapeRes, - Builder.CreateBitCast( - tape, - PointerType::get( - tapeRes->getType(), - cast(tape->getType())->getAddressSpace()))); - if (tapeIdx != -1) { - auto ST = cast(diffret->getType()); - SmallVector tys(ST->elements().begin(), - ST->elements().end()); - tys.erase(tys.begin()); - auto ST0 = StructType::get(ST->getContext(), tys); - Value *out = UndefValue::get(ST0); - for (unsigned i = 0; i < tys.size(); i++) { - out = Builder.CreateInsertValue( - out, Builder.CreateExtractValue(diffret, {i + 1}), {i}); - } - diffret = out; - } else { - auto ST0 = StructType::get(tape->getContext(), {}); - diffret = UndefValue::get(ST0); - } - } - } - - // Adapt the returned vector type to the struct type expected by our calling - // convention. 
- if (width > 1 && !diffret->getType()->isEmptyTy() && - !diffret->getType()->isVoidTy() && - (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit)) { - - diffret = adaptReturnedVector(ret, diffret, Builder, width); - } - - ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode); - calls.push_back(diffretc); - return diffret; - } - - /// Return whether successful - bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, bool sizeOnly, - SmallVectorImpl &calls) { - - // determine function to differentiate - Function *fn = parseFunctionParameter(CI); - if (!fn) - return false; - - IRBuilder<> Builder(CI); - - if (EnzymePrint) - llvm::errs() << "prefn:\n" << *fn << "\n"; - - std::map byVal; - std::vector constants; - SmallVector args; - - auto options = handleArguments(Builder, CI, fn, mode, sizeOnly, constants, - args, byVal); - - if (!options) { - return false; - } - - Value *ret = CI; - Type *retElemType = nullptr; - if (CI->hasStructRetAttr()) { - ret = CI->getArgOperand(0); - retElemType = - CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet) - .getValueAsType(); - } - - return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, *options, sizeOnly, - calls); - } - - bool HandleProbProg(CallInst *CI, ProbProgMode mode, - SmallVectorImpl &calls) { - IRBuilder<> Builder(CI); - Function *F = parseFunctionParameter(CI); - if (!F) - return false; - - assert(F); - - std::vector constants; - std::map byVal; - SmallVector args; - - auto diffeMode = DerivativeMode::ReverseModeCombined; - - auto opt = handleArguments(Builder, CI, F, diffeMode, false, constants, - args, byVal); - - SmallVector dargs(args.begin(), args.end()); - -#if LLVM_VERSION_MAJOR >= 16 - if (!opt.has_value()) - return false; -#else - if (!opt.hasValue()) - return false; -#endif - - auto dynamic_interface = opt->dynamic_interface; - auto trace = opt->trace; - auto dtrace = opt->diffeTrace; - auto observations = opt->observations; - auto likelihood = opt->likelihood; - auto dlikelihood = opt->diffeLikelihood; - - // Interface - bool has_dynamic_interface = dynamic_interface != nullptr; - bool needs_interface = - mode == ProbProgMode::Trace || mode == ProbProgMode::Condition; - std::unique_ptr interface; - if (has_dynamic_interface) { - interface = std::make_unique(dynamic_interface, - CI->getFunction()); - } else if (needs_interface) { - interface = std::make_unique(F->getParent()); - } - - // Find sample function - SmallPtrSet sampleFunctions; - SmallPtrSet observeFunctions; - for (auto &func : F->getParent()->functions()) { - if (func.getName().contains("__enzyme_sample")) { - assert(func.getFunctionType()->getNumParams() >= 3); - sampleFunctions.insert(&func); - } else if (func.getName().contains("__enzyme_observe")) { - assert(func.getFunctionType()->getNumParams() >= 3); - observeFunctions.insert(&func); - } - } - - assert(!sampleFunctions.empty() || !observeFunctions.empty()); - - bool autodiff = dtrace || dlikelihood; - IRBuilder<> AllocaBuilder(CI->getParent()->getFirstNonPHI()); - - if (!likelihood) { - likelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(), - nullptr, "likelihood"); - Builder.CreateStore(ConstantFP::getNullValue(Builder.getDoubleTy()), - likelihood); - } - args.push_back(likelihood); - - if (autodiff && !dlikelihood) { - dlikelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(), - nullptr, "dlikelihood"); - Builder.CreateStore(ConstantFP::get(Builder.getDoubleTy(), 1.0), - 
dlikelihood); - } - - if (autodiff) { - dargs.push_back(likelihood); - dargs.push_back(dlikelihood); - constants.push_back(DIFFE_TYPE::DUP_ARG); - opt->overwritten_args.push_back(false); - } else { - constants.push_back(DIFFE_TYPE::CONSTANT); - opt->overwritten_args.push_back(false); - } - - if (mode == ProbProgMode::Condition) { - opt->overwritten_args.push_back(false); - args.push_back(observations); - dargs.push_back(observations); - constants.push_back(DIFFE_TYPE::CONSTANT); - } - - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { - opt->overwritten_args.push_back(false); - args.push_back(trace); - dargs.push_back(trace); - constants.push_back(DIFFE_TYPE::CONSTANT); - } - - auto newFunc = Logic.CreateTrace( - RequestContext(CI, &Builder), F, sampleFunctions, observeFunctions, - opt->ActiveRandomVariables, mode, autodiff, interface.get()); - - if (!autodiff) { - auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args); - ReplaceInstWithInst(CI, call); - return true; - } - - Value *ret = CI; - Type *retElemType = nullptr; - if (CI->hasStructRetAttr()) { - ret = CI->getArgOperand(0); - retElemType = - CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet) - .getValueAsType(); - } - - bool status = HandleAutoDiff( - CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, - newFunc, DerivativeMode::ReverseModeCombined, *opt, false, calls); - - return status; - } - - bool handleFullModuleTrunc(Function &F) { - if (startsWith(F.getName(), EnzymeFPRTPrefix)) - return false; - typedef std::vector TruncationsTy; - static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { - StringRef ConfigStr(EnzymeTruncateAll); - auto Invalid = [=]() { - // TODO emit better diagnostic - llvm::report_fatal_error("error: invalid format for truncation config"); - }; - - // "64" or "11-52" - auto parseFloatRepr = [&]() -> std::optional { - unsigned Tmp = 0; - if (ConfigStr.consumeInteger(10, Tmp)) - return {}; - if (ConfigStr.consume_front("-")) { - unsigned Tmp2 = 0; - if (ConfigStr.consumeInteger(10, Tmp2)) - Invalid(); - return FloatRepresentation(Tmp, Tmp2); - } - return getDefaultFloatRepr(Tmp); - }; - - // Parse "64to32;32to16;5-10to4-9" - TruncationsTy Tmp; - while (true) { - auto From = parseFloatRepr(); - if (!From && !ConfigStr.empty()) - Invalid(); - if (!From) - break; - if (!ConfigStr.consume_front("to")) - Invalid(); - auto To = parseFloatRepr(); - if (!To) - Invalid(); - Tmp.push_back({*From, *To, TruncOpFullModuleMode}); - ConfigStr.consume_front(";"); - } - return Tmp; - }(); - - if (FullModuleTruncs.empty()) - return false; - - // TODO sort truncations (64to32, then 32to16 will make everything 16) - for (auto Truncation : FullModuleTruncs) { - IRBuilder<> Builder(F.getContext()); - RequestContext context(&*F.getEntryBlock().begin(), &Builder); - Function *TruncatedFunc = Logic.CreateTruncateFunc( - context, &F, Truncation, TruncOpFullModuleMode); - - ValueToValueMapTy Mapping; - for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) - Mapping[&TArg] = &Arg; - - // Move the truncated body into the original function - F.deleteBody(); -#if LLVM_VERSION_MAJOR >= 16 - F.splice(F.begin(), TruncatedFunc); -#else - F.getBasicBlockList().splice(F.begin(), - TruncatedFunc->getBasicBlockList()); -#endif - RemapFunction(F, Mapping, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - TruncatedFunc->deleteBody(); - } - return true; - } - - bool lowerEnzymeCalls(Function &F, std::set &done) { - if (done.count(&F)) - return 
false; - done.insert(&F); - - if (F.empty()) - return false; - - if (handleFullModuleTrunc(F)) - return true; - - bool Changed = false; - - for (BasicBlock &BB : F) - if (InvokeInst *II = dyn_cast(BB.getTerminator())) { - - Function *Fn = II->getCalledFunction(); - - if (auto castinst = dyn_cast(II->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) - Fn = fn; - } - if (!Fn) - continue; - - if (!(Fn->getName().contains("__enzyme_float") || - Fn->getName().contains("__enzyme_double") || - Fn->getName().contains("__enzyme_integer") || - Fn->getName().contains("__enzyme_pointer") || - Fn->getName().contains("__enzyme_virtualreverse") || - Fn->getName().contains("__enzyme_call_inactive") || - Fn->getName().contains("__enzyme_autodiff") || - Fn->getName().contains("__enzyme_fwddiff") || - Fn->getName().contains("__enzyme_fwdsplit") || - Fn->getName().contains("__enzyme_augmentfwd") || - Fn->getName().contains("__enzyme_augmentsize") || - Fn->getName().contains("__enzyme_reverse") || - Fn->getName().contains("__enzyme_truncate") || - Fn->getName().contains("__enzyme_batch") || - Fn->getName().contains("__enzyme_error_estimate") || - Fn->getName().contains("__enzyme_trace") || - Fn->getName().contains("__enzyme_condition"))) - continue; - - SmallVector CallArgs(II->arg_begin(), II->arg_end()); - SmallVector OpBundles; - II->getOperandBundlesAsDefs(OpBundles); - // Insert a normal call instruction... - CallInst *NewCall = - CallInst::Create(II->getFunctionType(), II->getCalledOperand(), - CallArgs, OpBundles, "", II); - NewCall->takeName(II); - NewCall->setCallingConv(II->getCallingConv()); - NewCall->setAttributes(II->getAttributes()); - NewCall->setDebugLoc(II->getDebugLoc()); - II->replaceAllUsesWith(NewCall); - - // Insert an unconditional branch to the normal destination. - BranchInst::Create(II->getNormalDest(), II); - - // Remove any PHI node entries from the exception destination. 
- II->getUnwindDest()->removePredecessor(&BB); - - II->eraseFromParent(); - Changed = true; - } - - MapVector toLower; - MapVector toVirtual; - MapVector toSize; - SmallVector toBatch; - SmallVector toTruncateFuncMem; - SmallVector toTruncateFuncOp; - SmallVector toTruncateValue; - SmallVector toExpandValue; - MapVector toProbProg; - SetVector InactiveCalls; - SetVector IterCalls; - retry:; - for (BasicBlock &BB : F) { - for (Instruction &I : BB) { - CallInst *CI = dyn_cast(&I); - - if (!CI) - continue; - - Function *Fn = nullptr; - - Value *FnOp = CI->getCalledOperand(); - while (true) { - if ((Fn = dyn_cast(FnOp))) - break; - if (auto castinst = dyn_cast(FnOp)) { - if (castinst->isCast()) { - FnOp = castinst->getOperand(0); - continue; - } - } - break; - } - - if (!Fn) - continue; - - size_t num_args = CI->arg_size(); - - if (Fn->getName().contains("__enzyme_todense")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - } - if (Fn->getName().contains("__enzyme_float")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - for (size_t i = 0; i < num_args; ++i) { - if (CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadNone); - addCallSiteNoCapture(CI, i); - } - } - } - if (Fn->getName().contains("__enzyme_integer")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - for (size_t i = 0; i < num_args; ++i) { - if (CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadNone); - addCallSiteNoCapture(CI, i); - } - } - } - if (Fn->getName().contains("__enzyme_double")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - for (size_t i = 0; i < num_args; ++i) { - if (CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadNone); - addCallSiteNoCapture(CI, i); - } - } - } - if (Fn->getName().contains("__enzyme_pointer")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - for (size_t i = 0; i < num_args; ++i) { - if (CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadNone); - addCallSiteNoCapture(CI, i); - } - } - } - if (Fn->getName().contains("__enzyme_virtualreverse")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - } - if (Fn->getName().contains("__enzyme_iter")) { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - } - if (Fn->getName().contains("__enzyme_call_inactive")) { - InactiveCalls.insert(CI); - } - if (Fn->getName() == "omp_get_max_threads" || - Fn->getName() == "omp_get_thread_num") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemory(); - CI->setOnlyAccessesInaccessibleMemory(); - Fn->setOnlyReadsMemory(); - CI->setOnlyReadsMemory(); -#else - 
Fn->addFnAttr(Attribute::InaccessibleMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOnly); - Fn->addFnAttr(Attribute::ReadOnly); - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); -#endif - } - if ((Fn->getName() == "cblas_ddot" || Fn->getName() == "cblas_sdot") && - Fn->isDeclaration()) { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesArgMemory(); - Fn->setOnlyReadsMemory(); - CI->setOnlyReadsMemory(); -#else - Fn->addFnAttr(Attribute::ArgMemOnly); - Fn->addFnAttr(Attribute::ReadOnly); - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); -#endif - CI->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(CI, 1); - CI->addParamAttr(3, Attribute::ReadOnly); - addCallSiteNoCapture(CI, 3); - } - if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" || - Fn->getName() == "frexpl") { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyAccessesArgMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly); -#endif - CI->addParamAttr(1, Attribute::WriteOnly); - } - if (Fn->getName() == "__fd_sincos_1" || Fn->getName() == "__fd_cos_1" || - Fn->getName() == "__mth_i_ipowi") { -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyReadsMemory(); - CI->setOnlyWritesMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); -#endif - } - if (getFuncName(Fn) == "strcmp") { - Fn->addParamAttr(0, Attribute::ReadOnly); - Fn->addParamAttr(1, Attribute::ReadOnly); -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyReadsMemory(); - CI->setOnlyReadsMemory(); -#else - Fn->addFnAttr(Attribute::ReadOnly); - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); -#endif - } - if (Fn->getName() == "f90io_fmtw_end" || - Fn->getName() == "f90io_unf_end") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemory(); - CI->setOnlyAccessesInaccessibleMemory(); -#else - Fn->addFnAttr(Attribute::InaccessibleMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOnly); -#endif - } - if (Fn->getName() == "f90io_open2003a") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemOrArgMem(); - CI->setOnlyAccessesInaccessibleMemOrArgMem(); -#else - Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOrArgMemOnly); -#endif - for (size_t i : {0, 1, 2, 3, 4, 5, 6, 7, /*8, */ 9, 10, 11, 12, 13}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadOnly); - } - } - // todo more - for (size_t i : {0, 1}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - addCallSiteNoCapture(CI, i); - } - } - } - if (Fn->getName() == "f90io_fmtw_inita") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemOrArgMem(); - CI->setOnlyAccessesInaccessibleMemOrArgMem(); -#else - Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOrArgMemOnly); -#endif - // todo more - for (size_t i : {0, 2}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadOnly); - } - } - - // todo more - for (size_t i : {0, 2}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - addCallSiteNoCapture(CI, i); - } - } - } - - if (Fn->getName() == "f90io_unf_init") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemOrArgMem(); - 
CI->setOnlyAccessesInaccessibleMemOrArgMem(); -#else - Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOrArgMemOnly); -#endif - // todo more - for (size_t i : {0, 1, 2, 3}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadOnly); - } - } - - // todo more - for (size_t i : {0, 1, 2, 3}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - addCallSiteNoCapture(CI, i); - } - } - } - - if (Fn->getName() == "f90io_src_info03a") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemOrArgMem(); - CI->setOnlyAccessesInaccessibleMemOrArgMem(); -#else - Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOrArgMemOnly); -#endif - // todo more - for (size_t i : {0, 1}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadOnly); - } - } - - // todo more - for (size_t i : {0}) { - if (i < num_args && - CI->getArgOperand(i)->getType()->isPointerTy()) { - addCallSiteNoCapture(CI, i); - } - } - } - if (Fn->getName() == "f90io_sc_d_fmt_write" || - Fn->getName() == "f90io_sc_i_fmt_write" || - Fn->getName() == "ftnio_fmt_write64" || - Fn->getName() == "f90io_fmt_write64_aa" || - Fn->getName() == "f90io_fmt_writea" || - Fn->getName() == "f90io_unf_writea" || - Fn->getName() == "f90_pausea") { -#if LLVM_VERSION_MAJOR >= 16 - Fn->setOnlyAccessesInaccessibleMemOrArgMem(); - CI->setOnlyAccessesInaccessibleMemOrArgMem(); -#else - Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOrArgMemOnly); -#endif - for (size_t i = 0; i < num_args; ++i) { - if (CI->getArgOperand(i)->getType()->isPointerTy()) { - CI->addParamAttr(i, Attribute::ReadOnly); - addCallSiteNoCapture(CI, i); - } - } - } - - bool enableEnzyme = false; - bool virtualCall = false; - bool sizeOnly = false; - bool batch = false; - bool truncateFuncOp = false; - bool truncateFuncMem = false; - bool truncateValue = false; - bool expandValue = false; - bool probProg = false; - DerivativeMode derivativeMode; - ProbProgMode probProgMode; - if (Fn->getName().contains("__enzyme_autodiff")) { - enableEnzyme = true; - derivativeMode = DerivativeMode::ReverseModeCombined; - } else if (Fn->getName().contains("__enzyme_fwddiff")) { - enableEnzyme = true; - derivativeMode = DerivativeMode::ForwardMode; - } else if (Fn->getName().contains("__enzyme_error_estimate")) { - enableEnzyme = true; - derivativeMode = DerivativeMode::ForwardModeError; - } else if (Fn->getName().contains("__enzyme_fwdsplit")) { - enableEnzyme = true; - derivativeMode = DerivativeMode::ForwardModeSplit; - } else if (Fn->getName().contains("__enzyme_augmentfwd")) { - enableEnzyme = true; - derivativeMode = DerivativeMode::ReverseModePrimal; - } else if (Fn->getName().contains("__enzyme_augmentsize")) { - enableEnzyme = true; - sizeOnly = true; - derivativeMode = DerivativeMode::ReverseModePrimal; - } else if (Fn->getName().contains("__enzyme_reverse")) { - enableEnzyme = true; - derivativeMode = DerivativeMode::ReverseModeGradient; - } else if (Fn->getName().contains("__enzyme_virtualreverse")) { - enableEnzyme = true; - virtualCall = true; - derivativeMode = DerivativeMode::ReverseModeCombined; - } else if (Fn->getName().contains("__enzyme_batch")) { - enableEnzyme = true; - batch = true; - } else 
if (Fn->getName().contains("__enzyme_truncate_mem_func")) { - enableEnzyme = true; - truncateFuncMem = true; - } else if (Fn->getName().contains("__enzyme_truncate_op_func")) { - enableEnzyme = true; - truncateFuncOp = true; - } else if (Fn->getName().contains("__enzyme_truncate_mem_value")) { - enableEnzyme = true; - truncateValue = true; - } else if (Fn->getName().contains("__enzyme_expand_mem_value")) { - enableEnzyme = true; - expandValue = true; - } else if (Fn->getName().contains("__enzyme_likelihood")) { - enableEnzyme = true; - probProgMode = ProbProgMode::Likelihood; - probProg = true; - } else if (Fn->getName().contains("__enzyme_trace")) { - enableEnzyme = true; - probProgMode = ProbProgMode::Trace; - probProg = true; - } else if (Fn->getName().contains("__enzyme_condition")) { - enableEnzyme = true; - probProgMode = ProbProgMode::Condition; - probProg = true; - } - - if (enableEnzyme) { - - Value *fn = CI->getArgOperand(0); - while (auto ci = dyn_cast(fn)) { - fn = ci->getOperand(0); - } - while (auto ci = dyn_cast(fn)) { - fn = ci->getFunction(); - } - while (auto ci = dyn_cast(fn)) { - fn = ci->getOperand(0); - } - if (auto si = dyn_cast(fn)) { - BasicBlock *post = BB.splitBasicBlock(CI); - BasicBlock *sel1 = BasicBlock::Create(BB.getContext(), "sel1", &F); - BasicBlock *sel2 = BasicBlock::Create(BB.getContext(), "sel2", &F); - BB.getTerminator()->eraseFromParent(); - IRBuilder<> PB(&BB); - PB.CreateCondBr(si->getCondition(), sel1, sel2); - IRBuilder<> S1(sel1); - auto B1 = S1.CreateBr(post); - CallInst *cloned = cast(CI->clone()); - cloned->insertBefore(B1); - cloned->setOperand(0, si->getTrueValue()); - IRBuilder<> S2(sel2); - auto B2 = S2.CreateBr(post); - CI->moveBefore(B2); - CI->setOperand(0, si->getFalseValue()); - if (CI->getNumUses() != 0) { - IRBuilder<> P(post->getFirstNonPHI()); - auto merge = P.CreatePHI(CI->getType(), 2); - merge->addIncoming(cloned, sel1); - merge->addIncoming(CI, sel2); - CI->replaceAllUsesWith(merge); - } - goto retry; - } - if (virtualCall) - toVirtual[CI] = derivativeMode; - else if (sizeOnly) - toSize[CI] = derivativeMode; - else if (batch) - toBatch.push_back(CI); - else if (truncateFuncOp) - toTruncateFuncOp.push_back(CI); - else if (truncateFuncMem) - toTruncateFuncMem.push_back(CI); - else if (truncateValue) - toTruncateValue.push_back(CI); - else if (expandValue) - toExpandValue.push_back(CI); - else if (probProg) { - toProbProg[CI] = probProgMode; - } else - toLower[CI] = derivativeMode; - - if (auto dc = dyn_cast(fn)) { - // Force postopt on any inner functions in the nested - // AD case. 
- bool tmp = Logic.PostOpt; - Logic.PostOpt = true; - Changed |= lowerEnzymeCalls(*dc, done); - Logic.PostOpt = tmp; - } - } - } - } - - for (auto CI : InactiveCalls) { - IRBuilder<> B(CI); - Value *fn = CI->getArgOperand(0); - SmallVector Args; - SmallVector ArgTypes; - for (size_t i = 1; i < CI->arg_size(); ++i) { - Args.push_back(CI->getArgOperand(i)); - ArgTypes.push_back(CI->getArgOperand(i)->getType()); - } - auto FT = FunctionType::get(CI->getType(), ArgTypes, /*varargs*/ false); - if (fn->getType() != FT) { - fn = B.CreatePointerCast(fn, PointerType::getUnqual(FT)); - } - auto Rep = B.CreateCall(FT, fn, Args); - Rep->addAttribute(AttributeList::FunctionIndex, - Attribute::get(Rep->getContext(), "enzyme_inactive")); - CI->replaceAllUsesWith(Rep); - CI->eraseFromParent(); - Changed = true; - } - - SmallVector calls; - - // Perform all the size replacements first to create constants - for (auto pair : toSize) { - bool successful = HandleAutoDiffArguments(pair.first, pair.second, - /*sizeOnly*/ true, calls); - Changed = true; - if (!successful) - break; - } - for (auto pair : toLower) { - bool successful = HandleAutoDiffArguments(pair.first, pair.second, - /*sizeOnly*/ false, calls); - Changed = true; - if (!successful) - break; - } - - for (auto pair : toVirtual) { - auto CI = pair.first; - Constant *fn = dyn_cast(CI->getArgOperand(0)); - if (!fn) { - EmitFailure("IllegalVirtual", CI->getDebugLoc(), CI, - "Cannot create virtual version of non-constant value ", *CI, - *CI->getArgOperand(0)); - return false; - } - TypeAnalysis TA(Logic.PPC.FAM); - - auto Arch = - llvm::Triple( - CI->getParent()->getParent()->getParent()->getTargetTriple()) - .getArch(); - - bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 || - Arch == Triple::amdgcn; - - IRBuilder<> Builder(CI); - auto val = GradientUtils::GetOrCreateShadowConstant( - RequestContext(CI, &Builder), Logic, - Logic.PPC.FAM.getResult(F), TA, fn, - pair.second, /*runtimeActivity*/ false, /*width*/ 1, AtomicAdd); - CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType())); - CI->eraseFromParent(); - Changed = true; - } - - for (auto call : toBatch) { - HandleBatch(call); - } - for (auto call : toTruncateFuncMem) { - HandleTruncateFunc(call, TruncMemMode); - } - for (auto call : toTruncateFuncOp) { - HandleTruncateFunc(call, TruncOpMode); - } - for (auto call : toTruncateValue) { - HandleTruncateValue(call, true); - } - for (auto call : toExpandValue) { - HandleTruncateValue(call, false); - } - - for (auto &&[call, mode] : toProbProg) { - HandleProbProg(call, mode, calls); - } - - if (Logic.PostOpt) { - auto Params = llvm::getInlineParams(); - - llvm::SetVector Q; - for (auto call : calls) - Q.insert(call); - while (Q.size()) { - auto cur = *Q.begin(); - Function *outerFunc = cur->getParent()->getParent(); - llvm::OptimizationRemarkEmitter ORE(outerFunc); - Q.erase(Q.begin()); - if (auto F = cur->getCalledFunction()) { - if (!F->empty()) { - // Garbage collect AC's created - SmallVector, 2> ACAlloc; - auto getAC = [&](Function &F) -> llvm::AssumptionCache & { - auto AC = std::make_unique(F); - ACAlloc.push_back(std::move(AC)); - return *ACAlloc.back(); - }; - auto GetTLI = - [&](llvm::Function &F) -> const llvm::TargetLibraryInfo & { - return Logic.PPC.FAM.getResult(F); - }; - - TargetTransformInfo TTI(F->getParent()->getDataLayout()); - auto GetInlineCost = [&](CallBase &CB) { - auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI); - return cst; - }; -#if LLVM_VERSION_MAJOR >= 20 - if 
(llvm::shouldInline(*cur, TTI, GetInlineCost, ORE)) -#else - if (llvm::shouldInline(*cur, GetInlineCost, ORE)) -#endif - { - InlineFunctionInfo IFI; - InlineResult IR = InlineFunction(*cur, IFI); - if (IR.isSuccess()) { - LowerSparsification(outerFunc, /*replaceAll*/ false); - for (auto U : outerFunc->users()) { - if (auto CI = dyn_cast(U)) { - if (CI->getCalledFunction() == outerFunc) { - Q.insert(CI); - } - } - } - } - } - } - } - } - } - - if (Changed && EnzymeAttributor) { - // TODO consider enabling when attributor does not delete - // dead internal functions, which invalidates Enzyme's cache - // code left here to re-enable upon Attributor patch - -#if !defined(FLANG) && !defined(ROCM) - - AnalysisGetter AG(Logic.PPC.FAM); - SetVector Functions; - for (Function &F2 : *F.getParent()) { - Functions.insert(&F2); - } - - CallGraphUpdater CGUpdater; - BumpPtrAllocator Allocator; - InformationCache InfoCache(*F.getParent(), AG, Allocator, - /* CGSCC */ nullptr); - - DenseSet Allowed = { - &AAHeapToStack::ID, - &AANoCapture::ID, - - &AAMemoryBehavior::ID, - &AAMemoryLocation::ID, - &AANoUnwind::ID, - &AANoSync::ID, - &AANoRecurse::ID, - &AAWillReturn::ID, - &AANoReturn::ID, - &AANonNull::ID, - &AANoAlias::ID, - &AADereferenceable::ID, - &AAAlign::ID, -#if LLVM_VERSION_MAJOR < 17 - &AAReturnedValues::ID, +using namespace llvm; +#ifdef DEBUG_TYPE +#undef DEBUG_TYPE #endif - &AANoFree::ID, - &AANoUndef::ID, +#define DEBUG_TYPE "lower-reactant-intrinsic" - //&AAValueSimplify::ID, - //&AAReachability::ID, - //&AAValueConstantRange::ID, - //&AAUndefinedBehavior::ID, - //&AAPotentialValues::ID, - }; - - AttributorConfig aconfig(CGUpdater); - aconfig.Allowed = &Allowed; - aconfig.DeleteFns = false; - Attributor A(Functions, InfoCache, aconfig); - for (Function *F : Functions) { - // Populate the Attributor with abstract attribute opportunities in - // the function and the information cache with IR information. - A.identifyDefaultAbstractAttributes(*F); - } - A.run(); -#endif - } - return Changed; +class ReactantBase { +public: + ReactantBase(bool PostOpt) { } bool run(Module &M) { - Logic.clear(); + bool changed = true; for (Function &F : make_early_inc_range(M)) { - attributeKnownFunctions(F); - } - - bool changed = false; - for (Function &F : M) { - if (F.empty()) - continue; - for (BasicBlock &BB : F) { - for (Instruction &I : make_early_inc_range(BB)) { - if (auto CI = dyn_cast(&I)) { - Function *F = CI->getCalledFunction(); - if (auto castinst = - dyn_cast(CI->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) { - F = fn; - } - } - if (F && F->getName() == "f90_mzero8") { - IRBuilder<> B(CI); - - Value *args[3]; - args[0] = CI->getArgOperand(0); - args[1] = ConstantInt::get(Type::getInt8Ty(M.getContext()), 0); - args[2] = B.CreateMul( - CI->getArgOperand(1), - ConstantInt::get(CI->getArgOperand(1)->getType(), 8)); - B.CreateMemSet(args[0], args[1], args[2], MaybeAlign()); - - CI->eraseFromParent(); - } - } - } - } - } - - if (Logic.PostOpt && EnzymeOMPOpt) { - OpenMPOptPass().run(M, Logic.PPC.MAM); - /// Attributor is run second time for promoted args to get attributes. 
- AttributorPass().run(M, Logic.PPC.MAM); - for (auto &F : M) - if (!F.empty()) - PromotePass().run(F, Logic.PPC.FAM); - changed = true; - } - - std::set done; - for (Function &F : M) { - if (F.empty()) - continue; - - changed |= lowerEnzymeCalls(F, done); - } - - for (Function &F : M) { - if (F.empty()) - continue; - - for (BasicBlock &BB : F) { - for (Instruction &I : make_early_inc_range(BB)) { - if (auto CI = dyn_cast(&I)) { - Function *F = CI->getCalledFunction(); - if (auto castinst = - dyn_cast(CI->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) { - F = fn; - } - } - if (F) { - if (F->getName().contains("__enzyme_float") || - F->getName().contains("__enzyme_double") || - F->getName().contains("__enzyme_integer") || - F->getName().contains("__enzyme_pointer")) { - CI->eraseFromParent(); - changed = true; - } - if (F->getName() == "__enzyme_iter") { - CI->replaceAllUsesWith(CI->getArgOperand(0)); - CI->eraseFromParent(); - changed = true; - } - } - } - } - } - } - - SmallPtrSet sample_calls; - SmallPtrSet observe_calls; - for (auto &&func : M) { - for (auto &&BB : func) { - for (auto &&Inst : BB) { - if (auto CI = dyn_cast(&Inst)) { - Function *fun = CI->getCalledFunction(); - if (!fun) - continue; - - if (fun->getName().contains("__enzyme_sample")) { - if (CI->getNumOperands() < 3) { - EmitFailure( - "IllegalNumberOfArguments", CI->getDebugLoc(), CI, - "Not enough arguments passed to call to __enzyme_sample"); - } - Function *samplefn = GetFunctionFromValue(CI->getOperand(0)); - unsigned expected = - samplefn->getFunctionType()->getNumParams() + 3; - unsigned actual = CI->arg_size(); - if (actual - 3 != samplefn->getFunctionType()->getNumParams()) { - EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI, - "Illegal number of arguments passed to call to " - "__enzyme_sample.", - " Expected: ", expected, " got: ", actual); - } - Function *pdf = GetFunctionFromValue(CI->getArgOperand(1)); - - for (unsigned i = 0; - i < samplefn->getFunctionType()->getNumParams(); ++i) { - Value *ci_arg = CI->getArgOperand(i + 3); - Value *sample_arg = samplefn->arg_begin() + i; - Value *pdf_arg = pdf->arg_begin() + i; - - if (ci_arg->getType() != sample_arg->getType()) { - EmitFailure( - "IllegalSampleType", CI->getDebugLoc(), CI, - "Type of: ", *ci_arg, " (", *ci_arg->getType(), ")", - " does not match the argument type of the sample " - "function: ", - *samplefn, " at: ", i, " (", *sample_arg->getType(), ")"); - } - if (ci_arg->getType() != pdf_arg->getType()) { - EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI, - "Type of: ", *ci_arg, " (", *ci_arg->getType(), - ")", - " does not match the argument type of the " - "density function: ", - *pdf, " at: ", i, " (", *pdf_arg->getType(), ")"); - } - } + if (!F.empty()) continue; + if (F.getName() == "cudaMalloc") { + auto entry = BasicBlock::Create(F.getContext(), "entry", &F); + IRBuilder<> B(entry); - if ((pdf->arg_end() - 1)->getType() != - samplefn->getReturnType()) { - EmitFailure( - "IllegalSampleType", CI->getDebugLoc(), CI, - "Return type of ", *samplefn, " (", - *samplefn->getReturnType(), ")", - " does not match the last argument type of the density " - "function: ", - *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")"); - } - sample_calls.insert(CI); - - } else if (fun->getName().contains("__enzyme_observe")) { - if (CI->getNumOperands() < 3) { - EmitFailure( - "IllegalNumberOfArguments", CI->getDebugLoc(), CI, - "Not enough arguments passed to call to __enzyme_sample"); - } - 
Value *observed = CI->getOperand(0); - Function *pdf = GetFunctionFromValue(CI->getArgOperand(1)); - unsigned expected = pdf->getFunctionType()->getNumParams() - 1; - - unsigned actual = CI->arg_size(); - if (actual - 3 != expected) { - EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI, - "Illegal number of arguments passed to call to " - "__enzyme_observe.", - " Expected: ", expected, " got: ", actual); - } - - for (unsigned i = 0; - i < pdf->getFunctionType()->getNumParams() - 1; ++i) { - Value *ci_arg = CI->getArgOperand(i + 3); - Value *pdf_arg = pdf->arg_begin() + i; - - if (ci_arg->getType() != pdf_arg->getType()) { - EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI, - "Type of: ", *ci_arg, " (", *ci_arg->getType(), - ")", - " does not match the argument type of the " - "density function: ", - *pdf, " at: ", i, " (", *pdf_arg->getType(), ")"); - } - } - - if ((pdf->arg_end() - 1)->getType() != observed->getType()) { - EmitFailure( - "IllegalSampleType", CI->getDebugLoc(), CI, - "Return type of ", *observed, " (", *observed->getType(), - ")", - " does not match the last argument type of the density " - "function: ", - *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")"); - } - observe_calls.insert(CI); - } - } - } - } - } - - // Replace calls to __enzyme_sample with the actual sample calls after - // running prob prog - for (auto call : sample_calls) { - Function *samplefn = GetFunctionFromValue(call->getArgOperand(0)); - - SmallVector args; - for (auto it = call->arg_begin() + 3; it != call->arg_end(); it++) { - args.push_back(*it); - } - CallInst *choice = - CallInst::Create(samplefn->getFunctionType(), samplefn, args); - - ReplaceInstWithInst(call, choice); - } - - for (auto call : observe_calls) { - Value *observed = call->getArgOperand(0); - - if (!call->getType()->isVoidTy()) - call->replaceAllUsesWith(observed); - call->eraseFromParent(); - } - - for (const auto &pair : Logic.PPC.cache) - pair.second->eraseFromParent(); - Logic.clear(); - - if (changed && Logic.PostOpt) { - TimeTraceScope timeScope("Enzyme PostOpt", M.getName()); - - PassBuilder PB; - LoopAnalysisManager LAM; - FunctionAnalysisManager FAM; - CGSCCAnalysisManager CGAM; - ModuleAnalysisManager MAM; - PB.registerModuleAnalyses(MAM); - PB.registerFunctionAnalyses(FAM); - PB.registerLoopAnalyses(LAM); - PB.registerCGSCCAnalyses(CGAM); - PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); - auto PM = PB.buildModuleSimplificationPipeline(OptimizationLevel::O2, - ThinOrFullLTOPhase::None); - PM.run(M, MAM); - if (EnzymeOMPOpt) { - OpenMPOptPass().run(M, MAM); - /// Attributor is run second time for promoted args to get attributes. 
-      AttributorPass().run(M, MAM);
-      for (auto &F : M)
-        if (!F.empty())
-          PromotePass().run(F, FAM);
+        auto entry = new BasicBlock()
+        F.ad
       }
     }
-    for (auto &F : M) {
-      if (!F.empty())
-        changed |= LowerSparsification(&F);
-    }
     return changed;
   }
 };
-class EnzymeOldPM : public EnzymeBase, public ModulePass {
+class ReactantOldPM : public ReactantBase, public ModulePass {
 public:
   static char ID;
-  EnzymeOldPM(bool PostOpt = false) : EnzymeBase(PostOpt), ModulePass(ID) {}
+  ReactantOldPM(bool PostOpt = false) : ReactantBase(PostOpt), ModulePass(ID) {}
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired();
@@ -3274,36 +123,36 @@ class EnzymeOldPM : public EnzymeBase, public ModulePass {
 } // namespace
-char EnzymeOldPM::ID = 0;
+char ReactantOldPM::ID = 0;
-static RegisterPass<EnzymeOldPM> X("enzyme", "Enzyme Pass");
+static RegisterPass<ReactantOldPM> X("enzyme", "Enzyme Pass");
-ModulePass *createEnzymePass(bool PostOpt) { return new EnzymeOldPM(PostOpt); }
+ModulePass *createReactantPass(bool PostOpt) { return new ReactantOldPM(PostOpt); }
 #include <llvm-c/Core.h>
 #include <llvm-c/Types.h>
 #include "llvm/IR/LegacyPassManager.h"
-extern "C" void AddEnzymePass(LLVMPassManagerRef PM) {
-  unwrap(PM)->add(createEnzymePass(/*PostOpt*/ false));
+extern "C" void AddReactantPass(LLVMPassManagerRef PM) {
+  unwrap(PM)->add(createReactantPass(/*PostOpt*/ false));
 }
 #include "llvm/Passes/PassPlugin.h"
-class EnzymeNewPM final : public EnzymeBase,
-                          public AnalysisInfoMixin<EnzymeNewPM> {
-  friend struct llvm::AnalysisInfoMixin<EnzymeNewPM>;
+class ReactantNewPM final : public ReactantBase,
+                            public AnalysisInfoMixin<ReactantNewPM> {
+  friend struct llvm::AnalysisInfoMixin<ReactantNewPM>;
 private:
   static llvm::AnalysisKey Key;
 public:
   using Result = llvm::PreservedAnalyses;
-  EnzymeNewPM(bool PostOpt = false) : EnzymeBase(PostOpt) {}
+  ReactantNewPM(bool PostOpt = false) : ReactantBase(PostOpt) {}
   Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
-    return EnzymeBase::run(M) ? PreservedAnalyses::none()
+    return ReactantBase::run(M) ? PreservedAnalyses::none()
                                  : PreservedAnalyses::all();
   }
@@ -3311,7 +160,7 @@ class EnzymeNewPM final : public EnzymeBase,
 };
 #undef DEBUG_TYPE
-AnalysisKey EnzymeNewPM::Key;
+AnalysisKey ReactantNewPM::Key;
 #include "ActivityAnalysisPrinter.h"
 #include "JLInstSimplify.h"
@@ -3355,55 +204,7 @@ AnalysisKey EnzymeNewPM::Key;
 #include "llvm/Transforms/Scalar/LoopFlatten.h"
 #include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
-static InlineParams getInlineParamsFromOptLevel(OptimizationLevel Level) {
-  return getInlineParams(Level.getSpeedupLevel(), Level.getSizeLevel());
-}
-
-#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
-#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
-namespace llvm {
-extern cl::opt<unsigned> SetLicmMssaNoAccForPromotionCap;
-extern cl::opt<unsigned> SetLicmMssaOptCap;
-#define EnableLoopFlatten false
-#define EagerlyInvalidateAnalyses false
-#define RunNewGVN false
-#define EnableConstraintElimination true
-#define UseInlineAdvisor InliningAdvisorMode::Default
-#define EnableMemProfContextDisambiguation false
-// extern cl::opt<bool> EnableMatrix;
-#define EnableMatrix false
-#define EnableModuleInliner false
-} // namespace llvm
-
 void augmentPassBuilder(llvm::PassBuilder &PB) {
-
-  auto prePass = [](ModulePassManager &MPM, OptimizationLevel Level) {
-    FunctionPassManager OptimizePM;
-    OptimizePM.addPass(Float2IntPass());
-    OptimizePM.addPass(LowerConstantIntrinsicsPass());
-
-    if (EnableMatrix) {
-      OptimizePM.addPass(LowerMatrixIntrinsicsPass());
-      OptimizePM.addPass(EarlyCSEPass());
-    }
-
-    LoopPassManager LPM;
-    bool LTOPreLink = false;
-    // First rotate loops that may have been un-rotated by prior passes.
-    // Disable header duplication at -Oz.
-    LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink));
-    // Some loops may have become dead by now. Try to delete them.
-    // FIXME: see discussion in https://reviews.llvm.org/D112851,
-    // this may need to be revisited once we run GVN before
-    // loop deletion in the simplification pipeline.
- LPM.addPass(LoopDeletionPass()); - - LPM.addPass(llvm::LoopFullUnrollPass()); - OptimizePM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM))); - - MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM))); - }; - #if LLVM_VERSION_MAJOR >= 20 auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level, ThinOrFullLTOPhase) @@ -3411,288 +212,14 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level) #endif { - MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); - - if (!EnzymeEnable) - return; - - if (Level != OptimizationLevel::O0) - prePass(MPM, Level); - MPM.addPass(llvm::AlwaysInlinerPass()); - FunctionPassManager OptimizerPM; - FunctionPassManager OptimizerPM2; -#if LLVM_VERSION_MAJOR >= 16 - OptimizerPM.addPass(llvm::GVNPass()); - OptimizerPM.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG)); -#else - OptimizerPM.addPass(llvm::GVNPass()); - OptimizerPM.addPass(llvm::SROAPass()); -#endif - MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM))); - MPM.addPass(EnzymeNewPM(/*PostOpt=*/true)); - MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false)); -#if LLVM_VERSION_MAJOR >= 16 - OptimizerPM2.addPass(llvm::GVNPass()); - OptimizerPM2.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG)); -#else - OptimizerPM2.addPass(llvm::GVNPass()); - OptimizerPM2.addPass(llvm::SROAPass()); -#endif - - LoopPassManager LPM1; - LPM1.addPass(LoopDeletionPass()); - OptimizerPM2.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1))); - - MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM2))); - MPM.addPass(GlobalOptPass()); + MPM.addPass(ReactantNewPM()); }; + // TODO need for perf reasons to move Enzyme pass to the pre vectorization. PB.registerOptimizerEarlyEPCallback(loadPass); - auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel) { - MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); - }; - - // We should register at vectorizer start for consistency, however, - // that requires a functionpass, and we have a modulepass. - // PB.registerVectorizerStartEPCallback(loadPass); - PB.registerPipelineStartEPCallback(loadNVVM); - PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM); - - auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) { - // Create a function that performs CFI checks for cross-DSO calls with - // targets in the current module. - MPM.addPass(CrossDSOCFIPass()); - - if (Level == OptimizationLevel::O0) { - return; - } - - // Try to run OpenMP optimizations, quick no-op if no OpenMP metadata - // present. -#if LLVM_VERSION_MAJOR >= 16 - MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink)); -#else - MPM.addPass(OpenMPOptPass()); -#endif - - // Remove unused virtual tables to improve the quality of code - // generated by whole-program devirtualization and bitset lowering. - MPM.addPass(GlobalDCEPass()); - - // Do basic inference of function attributes from known properties of - // system libraries and other oracles. - MPM.addPass(InferFunctionAttrsPass()); - - if (Level.getSpeedupLevel() > 1) { - MPM.addPass(createModuleToFunctionPassAdaptor(CallSiteSplittingPass(), - EagerlyInvalidateAnalyses)); - - // Indirect call promotion. This should promote all the targets that - // are left by the earlier promotion pass that promotes intra-module - // targets. This two-step promotion is to save the compile time. For - // LTO, it should produce the same result as if we only do promotion - // here. 
- // MPM.addPass(PGOIndirectCallPromotion( - // true /* InLTO */, PGOOpt && PGOOpt->Action == - // PGOOptions::SampleUse)); - - // Propagate constants at call sites into the functions they call. - // This opens opportunities for globalopt (and inlining) by - // substituting function pointers passed as arguments to direct uses - // of functions. -#if LLVM_VERSION_MAJOR >= 16 - MPM.addPass(IPSCCPPass(IPSCCPOptions(/*AllowFuncSpec=*/ - Level != OptimizationLevel::Os && - Level != OptimizationLevel::Oz))); -#else - MPM.addPass(IPSCCPPass()); -#endif - - // Attach metadata to indirect call sites indicating the set of - // functions they may target at run-time. This should follow IPSCCP. - MPM.addPass(CalledValuePropagationPass()); - } - - // Now deduce any function attributes based in the current code. - MPM.addPass( - createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass())); - - // Do RPO function attribute inference across the module to - // forward-propagate attributes where applicable. - // FIXME: Is this really an optimization rather than a - // canonicalization? - MPM.addPass(ReversePostOrderFunctionAttrsPass()); - - // Use in-range annotations on GEP indices to split globals where - // beneficial. - MPM.addPass(GlobalSplitPass()); - - // Run whole program optimization of virtual call when the list of - // callees is fixed. MPM.addPass(WholeProgramDevirtPass(ExportSummary, - // nullptr)); - - // Stop here at -O1. - if (Level == OptimizationLevel::O1) { - return; - } - - // Optimize globals to try and fold them into constants. - MPM.addPass(GlobalOptPass()); - - // Promote any localized globals to SSA registers. - MPM.addPass(createModuleToFunctionPassAdaptor(PromotePass())); - - // Linking modules together can lead to duplicate global constant, - // only keep one copy of each constant. - MPM.addPass(ConstantMergePass()); - - // Remove unused arguments from functions. - MPM.addPass(DeadArgumentEliminationPass()); - - // Reduce the code after globalopt and ipsccp. Both can open up - // significant simplification opportunities, and both can propagate - // functions through function pointers. When this happens, we often - // have to resolve varargs calls, etc, so let instcombine do this. - FunctionPassManager PeepholeFPM; - PeepholeFPM.addPass(InstCombinePass()); - if (Level.getSpeedupLevel() > 1) - PeepholeFPM.addPass(AggressiveInstCombinePass()); - - MPM.addPass(createModuleToFunctionPassAdaptor(std::move(PeepholeFPM), - EagerlyInvalidateAnalyses)); - - // Note: historically, the PruneEH pass was run first to deduce - // nounwind and generally clean up exception handling overhead. It - // isn't clear this is valuable as the inliner doesn't currently care - // whether it is inlining an invoke or a call. Run the inliner now. - if (EnableModuleInliner) { - MPM.addPass(ModuleInlinerPass(getInlineParamsFromOptLevel(Level), - UseInlineAdvisor, - ThinOrFullLTOPhase::FullLTOPostLink)); - } else { - MPM.addPass(ModuleInlinerWrapperPass( - getInlineParamsFromOptLevel(Level), - /* MandatoryFirst */ true, - InlineContext{ThinOrFullLTOPhase::FullLTOPostLink, - InlinePass::CGSCCInliner})); - } - - // Perform context disambiguation after inlining, since that would - // reduce the amount of additional cloning required to distinguish the - // allocation contexts. if (EnableMemProfContextDisambiguation) - // MPM.addPass(MemProfContextDisambiguation()); - - // Optimize globals again after we ran the inliner. 
- MPM.addPass(GlobalOptPass()); - - // Run the OpenMPOpt pass again after global optimizations. -#if LLVM_VERSION_MAJOR >= 16 - MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink)); -#else - MPM.addPass(OpenMPOptPass()); -#endif - - // Garbage collect dead functions. - MPM.addPass(GlobalDCEPass()); - - // If we didn't decide to inline a function, check to see if we can - // transform it to pass arguments by value instead of by reference. - MPM.addPass( - createModuleToPostOrderCGSCCPassAdaptor(ArgumentPromotionPass())); - - FunctionPassManager FPM; - // The IPO Passes may leave cruft around. Clean up after them. - FPM.addPass(InstCombinePass()); - - if (EnableConstraintElimination) - FPM.addPass(ConstraintEliminationPass()); - - FPM.addPass(JumpThreadingPass()); - - // Do a post inline PGO instrumentation and use pass. This is a context - // sensitive PGO pass. -#if 0 - if (PGOOpt) { - if (PGOOpt->CSAction == PGOOptions::CSIRInstr) - addPGOInstrPasses(MPM, Level, /* RunProfileGen */ true, - /* IsCS */ true, PGOOpt->CSProfileGenFile, - PGOOpt->ProfileRemappingFile, - ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS); - else if (PGOOpt->CSAction == PGOOptions::CSIRUse) - addPGOInstrPasses(MPM, Level, /* RunProfileGen */ false, - /* IsCS */ true, PGOOpt->ProfileFile, - PGOOpt->ProfileRemappingFile, - ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS); - } -#endif - - // Break up allocas -#if LLVM_VERSION_MAJOR >= 16 - FPM.addPass(SROAPass(SROAOptions::ModifyCFG)); -#else - FPM.addPass(SROAPass()); -#endif - - // LTO provides additional opportunities for tailcall elimination due - // to link-time inlining, and visibility of nocapture attribute. - FPM.addPass(TailCallElimPass()); - - // Run a few AA driver optimizations here and now to cleanup the code. - MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM), - EagerlyInvalidateAnalyses)); - - MPM.addPass( - createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass())); - - // Require the GlobalsAA analysis for the module so we can query it - // within MainFPM. - MPM.addPass(RequireAnalysisPass()); - }; - auto loadLTO = [preLTOPass, loadPass](ModulePassManager &MPM, OptimizationLevel Level) { - preLTOPass(MPM, Level); - MPM.addPass( - createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass())); - - // Require the GlobalsAA analysis for the module so we can query it - // within MainFPM. - MPM.addPass(RequireAnalysisPass()); - - // Invalidate AAManager so it can be recreated and pick up the newly - // available GlobalsAA. - MPM.addPass( - createModuleToFunctionPassAdaptor(InvalidateAnalysisPass())); - - FunctionPassManager MainFPM; - MainFPM.addPass(createFunctionToLoopPassAdaptor( - LICMPass(SetLicmMssaOptCap, SetLicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/true), - /*USeMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false)); - - if (RunNewGVN) - MainFPM.addPass(NewGVNPass()); - else - MainFPM.addPass(GVNPass()); - - // Remove dead memcpy()'s. - MainFPM.addPass(MemCpyOptPass()); - - // Nuke dead stores. - MainFPM.addPass(DSEPass()); -#if LLVM_VERSION_MAJOR >= 17 - MainFPM.addPass(MoveAutoInitPass()); -#endif - MainFPM.addPass(MergedLoadStoreMotionPass()); - - LoopPassManager LPM; - if (EnableLoopFlatten && Level.getSpeedupLevel() > 1) - LPM.addPass(LoopFlattenPass()); - LPM.addPass(IndVarSimplifyPass()); - LPM.addPass(LoopDeletionPass()); - // FIXME: Add loop interchange. 
- #if LLVM_VERSION_MAJOR >= 20 loadPass(MPM, Level, ThinOrFullLTOPhase::None); #else @@ -3702,52 +229,25 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO); } -extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, +extern "C" void registerReactantAndPassPipeline(llvm::PassBuilder &PB, bool augment = false) { - if (augment) { - augmentPassBuilder(PB); - } +} + +extern "C" void registerReactant(llvm::PassBuilder &PB) { + PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { - if (Name == "enzyme") { - MPM.addPass(EnzymeNewPM()); - return true; - } - if (Name == "preserve-nvvm") { - MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); - return true; - } - if (Name == "print-type-analysis") { - MPM.addPass(TypeAnalysisPrinterNewPM()); - return true; - } - return false; - }); - PB.registerPipelineParsingCallback( - [](llvm::StringRef Name, llvm::FunctionPassManager &FPM, - llvm::ArrayRef) { - if (Name == "print-activity-analysis") { - FPM.addPass(ActivityAnalysisPrinterNewPM()); - return true; - } - if (Name == "jl-inst-simplify") { - FPM.addPass(JLInstSimplifyNewPM()); + if (Name == "reactant") { + MPM.addPass(ReactantNewPM()); return true; } return false; }); -} - -extern "C" void registerEnzyme(llvm::PassBuilder &PB) { -#ifdef ENZYME_RUNPASS - registerEnzymeAndPassPipeline(PB, /*augment*/ true); -#else - registerEnzymeAndPassPipeline(PB, /*augment*/ false); -#endif + registerReactantAndPassPipeline(PB, /*augment*/ false); } extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK llvmGetPassPluginInfo() { - return {LLVM_PLUGIN_API_VERSION, "EnzymeNewPM", "v0.1", registerEnzyme}; + return {LLVM_PLUGIN_API_VERSION, "ReactantNewPM", "v0.1", registerReactant}; } From 745f4d40806f8457b61bd83abc8366f794ba0fdd Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 9 Jun 2025 22:04:07 -0500 Subject: [PATCH 02/13] Raising to unified gpu --- enzyme/Enzyme/ActivityAnalysis.cpp | 3497 ------ enzyme/Enzyme/ActivityAnalysis.h | 280 - enzyme/Enzyme/ActivityAnalysisPrinter.cpp | 214 - enzyme/Enzyme/ActivityAnalysisPrinter.h | 54 - enzyme/Enzyme/AdjointGenerator.h | 6523 ----------- enzyme/Enzyme/CApi.cpp | 2079 ---- enzyme/Enzyme/CApi.h | 237 - enzyme/Enzyme/CMakeLists.txt | 3 +- enzyme/Enzyme/CacheUtility.cpp | 1627 --- enzyme/Enzyme/CacheUtility.h | 433 - enzyme/Enzyme/CallDerivatives.cpp | 4245 ------- enzyme/Enzyme/Clang/EnzymeClang.cpp | 108 +- enzyme/Enzyme/DiffeGradientUtils.cpp | 1213 --- enzyme/Enzyme/DiffeGradientUtils.h | 129 - enzyme/Enzyme/DifferentialUseAnalysis.cpp | 1160 -- enzyme/Enzyme/DifferentialUseAnalysis.h | 547 - enzyme/Enzyme/Enzyme.cpp | 492 +- enzyme/Enzyme/EnzymeLogic.cpp | 6610 ----------- enzyme/Enzyme/EnzymeLogic.h | 781 -- enzyme/Enzyme/FunctionUtils.cpp | 8167 -------------- enzyme/Enzyme/FunctionUtils.h | 406 - enzyme/Enzyme/GradientUtils.cpp | 9704 ----------------- enzyme/Enzyme/GradientUtils.h | 661 -- enzyme/Enzyme/InstructionBatcher.cpp | 283 - enzyme/Enzyme/InstructionBatcher.h | 86 - enzyme/Enzyme/MustExitScalarEvolution.cpp | 1318 --- enzyme/Enzyme/MustExitScalarEvolution.h | 89 - enzyme/Enzyme/TraceGenerator.cpp | 424 - enzyme/Enzyme/TraceGenerator.h | 67 - enzyme/Enzyme/TraceInterface.cpp | 449 - enzyme/Enzyme/TraceInterface.h | 197 - enzyme/Enzyme/TraceUtils.cpp | 526 - enzyme/Enzyme/TraceUtils.h | 157 - enzyme/Enzyme/TypeAnalysis/BaseType.h | 78 - enzyme/Enzyme/TypeAnalysis/ConcreteType.h | 518 - enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp | 181 - enzyme/Enzyme/TypeAnalysis/RustDebugInfo.h | 38 - enzyme/Enzyme/TypeAnalysis/TBAA.h | 519 - enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 6467 ----------- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 415 - .../TypeAnalysis/TypeAnalysisPrinter.cpp | 192 - .../Enzyme/TypeAnalysis/TypeAnalysisPrinter.h | 52 - enzyme/Enzyme/TypeAnalysis/TypeTree.cpp | 48 - enzyme/Enzyme/TypeAnalysis/TypeTree.h | 1450 --- enzyme/Enzyme/Utils.cpp | 4044 ------- enzyme/Enzyme/Utils.h | 136 - 46 files changed, 435 insertions(+), 66469 deletions(-) delete mode 100644 enzyme/Enzyme/ActivityAnalysis.cpp delete mode 100644 enzyme/Enzyme/ActivityAnalysis.h delete mode 100644 enzyme/Enzyme/ActivityAnalysisPrinter.cpp delete mode 100644 enzyme/Enzyme/ActivityAnalysisPrinter.h delete mode 100644 enzyme/Enzyme/AdjointGenerator.h delete mode 100644 enzyme/Enzyme/CApi.cpp delete mode 100644 enzyme/Enzyme/CApi.h delete mode 100644 enzyme/Enzyme/CacheUtility.cpp delete mode 100644 enzyme/Enzyme/CacheUtility.h delete mode 100644 enzyme/Enzyme/CallDerivatives.cpp delete mode 100644 enzyme/Enzyme/DiffeGradientUtils.cpp delete mode 100644 enzyme/Enzyme/DiffeGradientUtils.h delete mode 100644 enzyme/Enzyme/DifferentialUseAnalysis.cpp delete mode 100644 enzyme/Enzyme/DifferentialUseAnalysis.h delete mode 100644 enzyme/Enzyme/EnzymeLogic.cpp delete mode 100644 enzyme/Enzyme/EnzymeLogic.h delete mode 100644 enzyme/Enzyme/FunctionUtils.cpp delete mode 100644 enzyme/Enzyme/FunctionUtils.h delete mode 100644 enzyme/Enzyme/GradientUtils.cpp delete mode 100644 enzyme/Enzyme/GradientUtils.h delete mode 100644 enzyme/Enzyme/InstructionBatcher.cpp delete mode 100644 enzyme/Enzyme/InstructionBatcher.h delete mode 100644 enzyme/Enzyme/MustExitScalarEvolution.cpp delete mode 100644 enzyme/Enzyme/MustExitScalarEvolution.h delete mode 100644 enzyme/Enzyme/TraceGenerator.cpp delete mode 100644 
enzyme/Enzyme/TraceGenerator.h delete mode 100644 enzyme/Enzyme/TraceInterface.cpp delete mode 100644 enzyme/Enzyme/TraceInterface.h delete mode 100644 enzyme/Enzyme/TraceUtils.cpp delete mode 100644 enzyme/Enzyme/TraceUtils.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/BaseType.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/ConcreteType.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp delete mode 100644 enzyme/Enzyme/TypeAnalysis/RustDebugInfo.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/TBAA.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp delete mode 100644 enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp delete mode 100644 enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.h delete mode 100644 enzyme/Enzyme/TypeAnalysis/TypeTree.cpp delete mode 100644 enzyme/Enzyme/TypeAnalysis/TypeTree.h delete mode 100644 enzyme/Enzyme/Utils.cpp diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp deleted file mode 100644 index 010b79768e0b..000000000000 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ /dev/null @@ -1,3497 +0,0 @@ -//===- ActivityAnalysis.cpp - Implementation of Activity Analysis ---------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation of Activity Analysis -- an AD-specific -// analysis that deduces if a given instruction or value can impact the -// calculation of a derivative. This file consists of two mutually recursive -// functions that compute this for values and instructions, respectively. 
-// -//===----------------------------------------------------------------------===// -#include -#include - -#include -#include - -#include "llvm/ADT/ImmutableSet.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" - -#include "llvm/ADT/STLExtras.h" - -#include "llvm/IR/Constants.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/IR/InstIterator.h" - -#include "llvm/Support/TimeProfiler.h" -#include "llvm/Support/raw_ostream.h" - -#include "llvm/IR/InlineAsm.h" - -#include "ActivityAnalysis.h" -#include "Utils.h" - -#include "llvm/Demangle/Demangle.h" - -#include "FunctionUtils.h" -#include "LibraryFuncs.h" -#include "TypeAnalysis/TBAA.h" - -#include "llvm/Analysis/ValueTracking.h" - -using namespace llvm; - -#define addAttribute addAttributeAtIndex -#define removeAttribute removeAttributeAtIndex -#define getAttribute getAttributeAtIndex -#define hasAttribute hasAttributeAtIndex - -extern "C" { -cl::opt - EnzymePrintActivity("enzyme-print-activity", cl::init(false), cl::Hidden, - cl::desc("Print activity analysis algorithm")); - -cl::opt EnzymeNonmarkedGlobalsInactive( - "enzyme-globals-default-inactive", cl::init(false), cl::Hidden, - cl::desc("Consider all nonmarked globals to be inactive")); - -cl::opt - EnzymeEmptyFnInactive("enzyme-emptyfn-inactive", cl::init(false), - cl::Hidden, - cl::desc("Empty functions are considered inactive")); - -cl::opt - EnzymeGlobalActivity("enzyme-global-activity", cl::init(false), cl::Hidden, - cl::desc("Enable correct global activity analysis")); - -cl::opt - EnzymeDisableActivityAnalysis("enzyme-disable-activity-analysis", - cl::init(false), cl::Hidden, - cl::desc("Disable activity analysis")); - -cl::opt EnzymeEnableRecursiveHypotheses( - "enzyme-enable-recursive-activity", cl::init(true), cl::Hidden, - cl::desc("Enable re-evaluation of activity analysis from updated results")); -} - -#include "llvm/IR/InstIterator.h" -#include -#include -#include - -// clang-format off -static const StringSet<> InactiveGlobals = { - "small_typeof", - "jl_small_typeof", - "ompi_request_null", - "ompi_mpi_double", - "ompi_mpi_comm_world", - "stderr", - "stdout", - "stdin", - "_ZSt3cin", - "_ZSt4cout", - "_ZNSt3__u4coutE", - "_ZNSt3__u5wcoutE", - "_ZNSt3__14coutE", - "_ZNSt3__15wcoutE", - "_ZNSt3__113basic_ostreamIcNS_11char_traitsIcEEE6sentryC1ERS3_", - "_ZSt5wcout", - "_ZSt4cerr", - "_ZNSt3__14cerrE", - "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", - "_ZTVSt15basic_streambufIcSt11char_traitsIcEE", - "_ZTVSt9basic_iosIcSt11char_traitsIcEE", - // istream - "_ZTVNSt7__cxx1119basic_istringstreamIcSt11char_traitsIcESaIcEEE", - "_ZTTNSt7__cxx1119basic_istringstreamIcSt11char_traitsIcESaIcEEE", - // ostream - "_ZTVNSt7__cxx1119basic_ostringstreamIcSt11char_traitsIcESaIcEEE", - "_ZTTNSt7__cxx1119basic_ostringstreamIcSt11char_traitsIcESaIcEEE", - // stringstream - "_ZTVNSt7__cxx1118basic_stringstreamIcSt11char_traitsIcESaIcEEE", - "_ZTTNSt7__cxx1118basic_stringstreamIcSt11char_traitsIcESaIcEEE", - // ifstream - "_ZTTSt14basic_ifstreamIcSt11char_traitsIcEE", - // ofstream - "_ZTTSt14basic_ofstreamIcSt11char_traitsIcEE", - // vtable for __cxxabiv1::__si_class_type_info - "_ZTVN10__cxxabiv120__si_class_type_infoE", - "_ZTVN10__cxxabiv117__class_type_infoE", - "_ZTVN10__cxxabiv121__vmi_class_type_infoE" -}; - -const llvm::StringMap MPIInactiveCommAllocators = 
{ - {"MPI_Graph_create", 5}, - {"MPI_Comm_split", 2}, - {"MPI_Intercomm_create", 6}, - {"MPI_Comm_spawn", 6}, - {"MPI_Comm_spawn_multiple", 7}, - {"MPI_Comm_accept", 4}, - {"MPI_Comm_connect", 4}, - {"MPI_Comm_create", 2}, - {"MPI_Comm_create_group", 3}, - {"MPI_Comm_dup", 1}, - {"MPI_Comm_dup", 2}, - {"MPI_Comm_idup", 1}, - {"MPI_Comm_join", 1}, -}; -// clang-format on - -/// Return whether the call is always inactive by definition. -bool isInactiveCall(CallBase &CI) { - - // clang-format off -const char *KnownInactiveFunctionsStartingWith[] = { - "f90io", - "$ss5print", - "strcpy", - "_ZTv0_n24_NSoD", //"1Ev, 0Ev - "_ZNSt16allocator_traitsISaIdEE10deallocate", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", -}; - -const char *KnownInactiveFunctionsContains[] = { - "__enzyme_float", "__enzyme_double", "__enzyme_integer", - "__enzyme_pointer"}; - -const StringSet<> KnownInactiveFunctions = { - "mpfr_greater_p", - "__nv_isnand", - "__nv_isnanf", - "__nv_isinfd", - "__nv_isinff", - "__nv_isfinitel", - "__nv_isfinited", - "cublasCreate_v2", - "cublasSetMathMode", - "cublasSetStream_v2", - "cuMemPoolTrimTo", - "cuDeviceGetMemPool", - "cuStreamCreate", - "cuStreamSynchronize", - "cuStreamDestroy", - "cuStreamQuery", - "cuCtxGetCurrent", - "enzyme_zerotype", - "abort", - "time", - "memcmp", - "memchr", - "gettimeofday", - "stat", - "mkdir", - "compress2", - "__assert_fail", - "__cxa_atexit", - "__cxa_guard_acquire", - "__cxa_guard_release", - "__cxa_guard_abort", - "getenv", - "strtol", - "fwrite", - "snprintf", - "sprintf", - "printf", - "putchar", - "fprintf", - "vprintf", - "vsnprintf", - "puts", - "fputc", - "fflush", - "__kmpc_for_static_init_4", - "__kmpc_for_static_init_4u", - "__kmpc_for_static_init_8", - "__kmpc_for_static_init_8u", - "__kmpc_for_static_fini", - "__kmpc_dispatch_init_4", - "__kmpc_dispatch_init_4u", - "__kmpc_dispatch_init_8", - "__kmpc_dispatch_init_8u", - "__kmpc_dispatch_next_4", - "__kmpc_dispatch_next_4u", - "__kmpc_dispatch_next_8", - "__kmpc_dispatch_next_8u", - "__kmpc_dispatch_fini_4", - "__kmpc_dispatch_fini_4u", - "__kmpc_dispatch_fini_8", - "__kmpc_dispatch_fini_8u", - "__kmpc_barrier", - "__kmpc_barrier_master", - "__kmpc_barrier_master_nowait", - "__kmpc_barrier_end_barrier_master", - "__kmpc_global_thread_num", - "omp_get_max_threads", - "malloc_usable_size", - "malloc_size", - "MPI_Init", - "MPI_Comm_size", - "PMPI_Comm_size", - "MPI_Comm_rank", - "PMPI_Comm_rank", - "MPI_Get_processor_name", - "MPI_Finalize", - "MPI_Test", - "MPI_Probe", // double check potential syncronization - "MPI_Barrier", - "MPI_Abort", - "MPI_Get_count", - "MPI_Comm_free", - "MPI_Comm_get_parent", - "MPI_Comm_get_name", - "MPI_Comm_get_info", - "MPI_Comm_remote_size", - "MPI_Comm_set_info", - "MPI_Comm_set_name", - "MPI_Comm_compare", - "MPI_Comm_call_errhandler", - "MPI_Comm_create_errhandler", - "MPI_Comm_disconnect", - "MPI_Wtime", - "_msize", - "ftnio_fmt_write64", - "f90_strcmp_klen", - "__swift_instantiateConcreteTypeFromMangledName", - "logb", - "logbf", - "logbl", - "cuCtxGetCurrent", - "cuDeviceGet", - "cuDeviceGetName", - "cuDriverGetVersion", - "cudaRuntimeGetVersion", - "cuDeviceGetCount", - "cuMemPoolGetAttribute", - "cuMemGetInfo_v2", - "cuDeviceGetAttribute", - "cuDevicePrimaryCtxRetain", - "floor", - "floorf", - "floorl", - "\01_fopen", - "fopen", - "fclose", -}; - -const std::set KnownInactiveIntrinsics = { - Intrinsic::experimental_noalias_scope_decl, - Intrinsic::objectsize, - Intrinsic::floor, - Intrinsic::ceil, - Intrinsic::trunc, - Intrinsic::rint, - Intrinsic::lrint, - 
Intrinsic::llrint, - Intrinsic::nearbyint, - Intrinsic::round, - Intrinsic::roundeven, - Intrinsic::lround, - Intrinsic::llround, - Intrinsic::nvvm_barrier0, - Intrinsic::nvvm_barrier0_popc, - Intrinsic::nvvm_barrier0_and, - Intrinsic::nvvm_barrier0_or, - Intrinsic::nvvm_membar_cta, - Intrinsic::nvvm_membar_gl, - Intrinsic::nvvm_membar_sys, - Intrinsic::amdgcn_s_barrier, - Intrinsic::assume, - Intrinsic::stacksave, - Intrinsic::stackrestore, - Intrinsic::lifetime_start, - Intrinsic::lifetime_end, -#if LLVM_VERSION_MAJOR <= 16 - Intrinsic::dbg_addr, -#endif - - Intrinsic::dbg_declare, - Intrinsic::dbg_value, - Intrinsic::dbg_label, - Intrinsic::invariant_start, - Intrinsic::invariant_end, - Intrinsic::var_annotation, - Intrinsic::ptr_annotation, - Intrinsic::annotation, - Intrinsic::codeview_annotation, - Intrinsic::expect, - Intrinsic::type_test, - Intrinsic::donothing, - Intrinsic::prefetch, - Intrinsic::trap, - Intrinsic::is_constant, - Intrinsic::memset}; - -const char *DemangledKnownInactiveFunctionsStartingWith[] = { - // TODO this returns allocated memory and thus can be an active value - // "std::allocator" - "std::__u::basic_streambuf", - "std::__u::basic_iostream", - "std::__u::basic_ios", - "std::__u::basic_istream", - "std::__u::basic_string", - "std::__u::basic_filebuf", - "std::__u::locale", - "std::__u::ios_base", - "std::__u::basic_ostream", - "absl::log_internal::LogMessage", - "std::chrono::_V2::steady_clock::now", - "std::string", - "std::cerr", - "std::istream", - "std::ostream", - "std::ios_base", - "std::locale", - "std::ctype", - "std::__basic_file", - "std::__ioinit", - "std::__basic_file", - "std::hash", - "std::_Hash_bytes", - - // __cxx11 - "std::__cxx11::basic_string", - "std::__cxx11::basic_ios", - "std::__cxx11::basic_ostringstream", - "std::__cxx11::basic_istringstream", - "std::__cxx11::basic_istream", - "std::__cxx11::basic_ostream", - "std::__cxx11::basic_ifstream", - "std::__cxx11::basic_ofstream", - "std::__cxx11::basic_stringbuf", - "std::__cxx11::basic_filebuf", - "std::__cxx11::basic_streambuf", - - // non __cxx11 - "std::basic_string", - "std::to_string", - "std::basic_ios", - "std::basic_ostringstream", - "std::basic_istringstream", - "std::basic_istream", - "std::basic_ostream", - "std::basic_ifstream", - "std::basic_ofstream", - "std::basic_stringbuf", - "std::basic_filebuf", - "std::basic_streambuf", - "std::random_device", - "std::mersenne_twister_engine", - "std::linear_congruential_engine", - "std::subtract_with_carry_engine", - "std::discard_block_engine", - "std::independent_bits_engine", - "std::shuffle_order_engine", - - - // libc++ - "std::__1::locale", - "std::__1::ios_base", - "std::__1::basic_string", - "std::__1::__do_string_hash", - "std::__1::hash", - "std::__1::__unordered_map_hasher", - "std::__1::to_string", - "std::__1::basic_ostream", - "std::__1::cout", - "std::__1::random_device", - "std::__1::mersenne_twister_engine", - "std::__1::linear_congruential_engine", - "std::__1::subtract_with_carry_engine", - "std::__1::discard_block_engine", - "std::__1::independent_bits_engine", - "std::__1::shuffle_order_engine", - "std::__1::basic_streambuf", - "std::__1::basic_stringbuf", - "std::__1::basic_istream", - "std::__1::basic_filebuf", - "std::__1::basic_iostream", - "std::__1::basic_ios", - "virtual thunk to std::__1::basic_istream", - "virtual thunk to std::__1::basic_ostream", - - "std::__detail::_Prime_rehash_policy", - "std::__detail::_Hash_code_base", - - // Rust - "std::io::stdio::_eprint", - -}; - // clang-format on - - if 
(CI.hasFnAttr("enzyme_inactive")) - return true; - - if (auto iasm = dyn_cast(CI.getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("exit") || - StringRef(iasm->getAsmString()).contains("cpuid")) - return true; - } - - if (auto F = getFunctionFromCall(&CI)) { - if (F->hasFnAttribute("enzyme_inactive")) { - return true; - } - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return true; - } - } - - auto Name = getFuncNameFromCall(&CI); - - std::string demangledName = llvm::demangle(Name.str()); - auto dName = StringRef(demangledName); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(Name, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (Name.contains(FuncName)) { - return true; - } - } - if (KnownInactiveFunctions.count(Name)) { - return true; - } - - if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { - return true; - } - Intrinsic::ID ID; - if (isMemFreeLibMFunction(Name, &ID)) - if (KnownInactiveIntrinsics.count(ID)) { - return true; - } - - // Copies of size 1 are inactive [cannot move differentiable data in one byte] - if (auto MTI = dyn_cast(&CI)) { - if (auto sz = dyn_cast(MTI->getOperand(2))) { - if (sz->getValue() == 1) - return true; - } - } - - return false; -} - -bool isInactiveCallInst(CallBase &CB, llvm::TargetLibraryInfo &TLI) { - // clang-format off -// Instructions which themselves are inactive -// the returned value, however, may still be active -static const StringSet<> KnownInactiveFunctionInsts = { - "__dynamic_cast", - "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", - "jl_ptr_to_array", - "jl_ptr_to_array_1d"}; - // clang-format on - if (isInactiveCall(CB)) { - return true; - } - if (CB.hasFnAttr("enzyme_inactive_inst")) { - return true; - } - auto called = getFunctionFromCall(&CB); - - if (called) { - if (called->hasFnAttribute("enzyme_inactive_inst")) { - return true; - } - } - - auto funcName = getFuncNameFromCall(&CB); - if (KnownInactiveFunctionInsts.count(funcName)) { - return true; - } - - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return true; - } - - return false; -} - -/// Is the use of value val as an argument of call CI known to be inactive -/// This tool can only be used when in DOWN mode -bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { - assert(directions & DOWN); - if (isInactiveCall(*CI)) - return true; - - auto F = getFunctionFromCall(CI); - - bool all_inactive = val != CI->getCalledOperand(); - - for (size_t i = 0; i < CI->arg_size(); i++) { - if (val == CI->getArgOperand(i)) { - if (!CI->getAttributes().hasParamAttr(i, "enzyme_inactive") && - !(F && F->getCallingConv() == CI->getCallingConv() && - F->getAttributes().hasParamAttr(i, "enzyme_inactive"))) { - all_inactive = false; - break; - } - } - } - - if (all_inactive) - return true; - - // Indirect function calls may actively use the argument - if (F == nullptr) - return false; - - auto Name = getFuncNameFromCall(CI); - - // Only the 1-th arg impacts activity - if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") - return val != CI->getArgOperand(1); - - // Only the 0-th arg impacts 
activity - if (Name == "jl_genericmemory_copy_slice" || - Name == "ijl_genericmemory_copy_slice") - return val != CI->getArgOperand(0); - - // Allocations, deallocations, and c++ guards don't impact the activity - // of arguments - if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI)) - return true; - - /// Only the first argument (magnitude) of copysign is active - if (F->getIntrinsicID() == Intrinsic::copysign && - CI->getArgOperand(0) != val) { - return true; - } - - if (F->getIntrinsicID() == Intrinsic::memcpy && CI->getArgOperand(0) != val && - CI->getArgOperand(1) != val) - return true; - if (F->getIntrinsicID() == Intrinsic::memmove && - CI->getArgOperand(0) != val && CI->getArgOperand(1) != val) - return true; - - // only the buffer is active for mpi send/recv - if (Name == "MPI_Recv" || Name == "PMPI_Recv" || Name == "MPI_Send" || - Name == "PMPI_Send") { - return val != CI->getOperand(0); - } - // only the recv buffer and request is active for mpi isend/irecv - if (Name == "MPI_Irecv" || Name == "MPI_Isend") { - return val != CI->getOperand(0) && val != CI->getOperand(6); - } - - // only request is active - if (Name == "MPI_Wait" || Name == "PMPI_Wait") - return val != CI->getOperand(0); - - if (Name == "MPI_Waitall" || Name == "PMPI_Waitall") - return val != CI->getOperand(1); - - if (Name == "julia.gc_loaded") - return val != CI->getOperand(1); - - // TODO interprocedural detection - // Before potential introprocedural detection, any function without definition - // may to be assumed to have an active use - if (F->empty()) - return false; - - // With all other options exhausted we have to assume this function could - // actively use the value - return false; -} - -/// Call the function propagateFromOperand on all operands of CI -/// that could impact the activity of the call instruction -static inline void propagateArgumentInformation( - TargetLibraryInfo &TLI, CallInst &CI, - llvm::function_ref propagateFromOperand) { - if (isInactiveCall(CI)) - return; - - // These functions are known to only have the first argument impact - // the activity of the call instruction - auto Name = getFuncNameFromCall(&CI); - if (Name == "lgamma" || Name == "lgammaf" || Name == "lgammal" || - Name == "lgamma_r" || Name == "lgammaf_r" || Name == "lgammal_r" || - Name == "__lgamma_r_finite" || Name == "__lgammaf_r_finite" || - Name == "__lgammal_r_finite") { - - propagateFromOperand(CI.getArgOperand(0)); - return; - } - - // Only the 1-st arg impacts activity - if (Name == "julia.gc_loaded") { - propagateFromOperand(CI.getArgOperand(1)); - return; - } - - if (Name == "julia.call" || Name == "julia.call2") { - for (size_t i = 1; i < CI.arg_size(); i++) { - propagateFromOperand(CI.getOperand(i)); - } - return; - } - - // Only the 0-th arg impacts activity - if (Name == "jl_genericmemory_copy_slice" || - Name == "ijl_genericmemory_copy_slice") { - propagateFromOperand(CI.getArgOperand(0)); - return; - } - - // Only the 1-th arg impacts activity - if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") { - propagateFromOperand(CI.getArgOperand(1)); - return; - } - - // Allocations, deallocations, and c++ guards are fully inactive - if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI) || - Name == "__cxa_guard_acquire" || Name == "__cxa_guard_release" || - Name == "__cxa_guard_abort") - return; - - auto F = getFunctionFromCall(&CI); - - if (F) { - - /// Only the first argument (magnitude) of copysign is active - if (F->getIntrinsicID() == 
Intrinsic::copysign) { - propagateFromOperand(CI.getOperand(0)); - return; - } - - if (F->getIntrinsicID() == Intrinsic::memcpy || - F->getIntrinsicID() == Intrinsic::memmove) { - propagateFromOperand(CI.getOperand(0)); - propagateFromOperand(CI.getOperand(1)); - return; - } - } - - // For other calls, check all operands of the instruction - // as conservatively they may impact the activity of the call - size_t i = 0; - for (auto &a : CI.args()) { - - if (CI.getAttributes().hasParamAttr(i, "enzyme_inactive") || - (F && F->getCallingConv() == CI.getCallingConv() && - F->getAttributes().hasParamAttr(i, "enzyme_inactive"))) { - i++; - continue; - } - - if (propagateFromOperand(a)) - break; - i++; - } -} - -/// Return whether this instruction is known not to propagate adjoints -/// Note that instructions could return an active pointer, but -/// do not propagate adjoints themselves -bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, - Instruction *I) { - - TimeTraceScope timeScope("isConstantInstruction", - I->getParent()->getParent()->getName()); - - // This analysis may only be called by instructions corresponding to - // the function analyzed by TypeInfo - assert(I); - assert(TR.getFunction() == I->getParent()->getParent()); - - // The return instruction doesn't impact activity (handled specifically - // during adjoint generation) - if (isa(I)) - return true; - - // Branch, unreachable, and previously computed constants are inactive - if (isa(I) || isa(I) || - (ConstantInstructions.find(I) != ConstantInstructions.end())) { - return true; - } - - /// Previously computed inactives remain inactive - if ((ActiveInstructions.find(I) != ActiveInstructions.end())) { - return false; - } - - /// Overwrite activity using metadata - if (hasMetadata(I, "enzyme_active") || hasMetadata(I, "enzyme_active_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "[activity] forced instruction to be active: " << *I - << "\n"; - return false; - } else if (hasMetadata(I, "enzyme_inactive") || - hasMetadata(I, "enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "[activity] forced value to be constant: " << *I << "\n"; - return true; - } - - if (notForAnalysis.count(I->getParent())) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction as dominates unreachable " << *I - << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - - if (isa(I)) { - if (EnzymePrintActivity) - llvm::errs() << " constant fence instruction " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - - if (auto CI = dyn_cast(I)) { - if (CI->hasFnAttr("enzyme_active") || CI->hasFnAttr("enzyme_active_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced active " << *I << "\n"; - ActiveInstructions.insert(I); - return false; - } - auto called = getFunctionFromCall(CI); - - if (called) { - if (called->hasFnAttribute("enzyme_active") || - called->hasFnAttribute("enzyme_active_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced active " << *I << "\n"; - ActiveInstructions.insert(I); - return false; - } - } - if (isInactiveCallInst(*CI, TLI)) { - if (EnzymePrintActivity) - llvm::errs() << "known inactive instruction from call " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - } - - if (auto II = dyn_cast(I)) { - if (isIntelSubscriptIntrinsic(*II)) { - // The intrinsic "llvm.intel.subscript" does not propogate deriviative - // information directly. But its returned pointer may be active. 
- InsertConstantInstruction(TR, I); - return true; - } - } - - if (EnzymeDisableActivityAnalysis) - return false; - - /// A store into all integral memory is inactive - if (auto SI = dyn_cast(I)) { - auto StoreSize = SI->getParent() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeSizeInBits(SI->getValueOperand()->getType()) / - 8; - - bool AllIntegral = true; - bool SeenInteger = false; - auto q = TR.query(SI->getPointerOperand()).Data0(); - for (int i = -1; i < (int)StoreSize; ++i) { - auto dt = q[{i}]; - if (dt.isIntegral() || dt == BaseType::Anything) { - SeenInteger = true; - if (i == -1) - break; - } else if (dt.isKnown()) { - AllIntegral = false; - break; - } - } - - if (AllIntegral && SeenInteger) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction from TA " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - } - if (auto SI = dyn_cast(I)) { - auto StoreSize = SI->getParent() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeSizeInBits(I->getType()) / - 8; - - bool AllIntegral = true; - bool SeenInteger = false; - auto q = TR.query(SI->getOperand(0)).Data0(); - for (int i = -1; i < (int)StoreSize; ++i) { - auto dt = q[{i}]; - if (dt.isIntegral() || dt == BaseType::Anything) { - SeenInteger = true; - if (i == -1) - break; - } else if (dt.isKnown()) { - AllIntegral = false; - break; - } - } - - if (AllIntegral && SeenInteger) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction from TA " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - } - - if (EnzymePrintActivity) - llvm::errs() << "checking if is constant[" << (int)directions << "] " << *I - << "\n"; - - // Analyzer for inductive assumption where we attempt to prove this is - // inactive from a lack of active users - std::unique_ptr DownHypothesis; - - // If this instruction does not write to memory that outlives itself - // (potentially propagating derivative information), the only way to propagate - // derivative information is through the return value - // TODO the "doesn't write to active memory" can be made more aggressive than - // doesn't write to any memory - bool noActiveWrite = false; - if (!I->mayWriteToMemory()) - noActiveWrite = true; - else if (auto CI = dyn_cast(I)) { - if (AA.onlyReadsMemory(CI) || isReadOnly(CI)) { - noActiveWrite = true; - } else { - StringRef funcName = getFuncNameFromCall(CI); - if (isMemFreeLibMFunction(funcName)) { - noActiveWrite = true; - } else if (funcName == "frexp" || funcName == "frexpf" || - funcName == "frexpl" || funcName == "modf" || - funcName == "modff" || funcName == "modfl") { - noActiveWrite = true; - } - } - } - if (noActiveWrite) { - bool possibleFloat = TR.anyFloat(I); - // Even if returning a pointer, this instruction is considered inactive - // since the instruction doesn't prop gradients. 
Thus, so long as we don't - // return an object containing a float, this instruction is inactive - if (!possibleFloat) { - if (EnzymePrintActivity) - llvm::errs() - << " constant instruction from known non-float non-writing " - "instruction " - << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - - // If the value returned is constant otherwise, the instruction is inactive - if (isConstantValue(TR, I)) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction from known constant non-writing " - "instruction " - << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - - // Even if the return is nonconstant, it's worth checking explicitly the - // users since unlike isConstantValue, returning a pointer does not make the - // instruction active - if (directions & DOWN) { - // We shall now induct on this instruction being inactive and try to prove - // this fact from a lack of active users. - - // If we aren't a phi node (and thus potentially recursive on uses) and - // already equal to the current direction, we don't need to induct, - // reducing runtime. - if (directions == DOWN && !isa(I)) { - if (isValueInactiveFromUsers(TR, I, UseActivity::None)) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction[" << (int)directions - << "] from users instruction " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } - } else { - DownHypothesis = std::unique_ptr( - new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantInstructions.insert(I); - if (DownHypothesis->isValueInactiveFromUsers(TR, I, - UseActivity::None)) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction[" << (int)directions - << "] from users instruction " << *I << "\n"; - InsertConstantInstruction(TR, I); - insertConstantsFrom(TR, *DownHypothesis); - return true; - } - } - } - } - - std::unique_ptr UpHypothesis; - if (directions & UP) { - // If this instruction has no active operands, the instruction - // is active. - // TODO This isn't 100% accurate and will incorrectly mark a no-argument - // function that reads from active memory as constant - // Technically the additional constraint is that this does not read from - // active memory, where we have assumed that the only active memory - // we care about is accessible from arguments passed (and thus not globals) - UpHypothesis = - std::unique_ptr(new ActivityAnalyzer(*this, UP)); - UpHypothesis->ConstantInstructions.insert(I); - assert(directions & UP); - if (UpHypothesis->isInstructionInactiveFromOrigin(TR, I, false)) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction from origin " - "instruction " - << *I << "\n"; - InsertConstantInstruction(TR, I); - insertConstantsFrom(TR, *UpHypothesis); - if (DownHypothesis) - insertConstantsFrom(TR, *DownHypothesis); - return true; - } else if (directions == 3) { - for (auto &op : I->operands()) { - if (!UpHypothesis->isConstantValue(TR, op) && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateInstIfInactiveValue[op].insert(I); - } - } - } - } - - // Otherwise we must fall back and assume this instruction to be active. 
- ActiveInstructions.insert(I); - if (EnzymePrintActivity) - llvm::errs() << "couldnt decide fallback as nonconstant instruction(" - << (int)directions << "):" << *I << "\n"; - if (noActiveWrite && directions == 3 && EnzymeEnableRecursiveHypotheses) - ReEvaluateInstIfInactiveValue[I].insert(I); - return false; -} - -bool isValuePotentiallyUsedAsPointer(llvm::Value *val) { - std::deque todo = {val}; - SmallPtrSet seen; - while (todo.size()) { - auto cur = todo.back(); - todo.pop_back(); - if (seen.count(cur)) - continue; - seen.insert(cur); - for (auto u : cur->users()) { - if (isa(u)) - return true; - if (!cast(u)->mayReadOrWriteMemory()) { - todo.push_back(u); - continue; - } - if (EnzymePrintActivity) - llvm::errs() << " VALUE potentially used as pointer " << *val << " by " - << *u << "\n"; - return true; - } - } - return false; -} - -bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { - // This analysis may only be called by instructions corresponding to - // the function analyzed by TypeInfo -- however if the Value - // was created outside a function (e.g. global, constant), that is allowed - TimeTraceScope timeScope("isConstantValue"); - - assert(Val); - if (auto I = dyn_cast(Val)) { - if (TR.getFunction() != I->getParent()->getParent()) { - llvm::errs() << *TR.getFunction() << "\n"; - llvm::errs() << *I << "\n"; - } - assert(TR.getFunction() == I->getParent()->getParent()); - } -#ifndef NDEBUG - if (auto Arg = dyn_cast(Val)) { - assert(TR.getFunction() == Arg->getParent()); - } -#endif - - // Void values are definitionally inactive - if (Val->getType()->isVoidTy()) - return true; - - // Token values are definitionally inactive - if (Val->getType()->isTokenTy()) - return true; - - // All function pointers are considered active in case an augmented primal - // or reverse is needed - if (isa(Val) || isa(Val)) { - return false; - } - - /// If we've already shown this value to be inactive - if (ConstantValues.find(Val) != ConstantValues.end()) { - return true; - } - - /// If we've already shown this value to be active - if (ActiveValues.find(Val) != ActiveValues.end()) { - return false; - } - - // We do this check down here so we can go past asserted constant values from - // arguments, and also allow void/tokens to be inactive. - if (!EnzymeDisableActivityAnalysis) { - - if (auto CD = dyn_cast(Val)) { - // inductively assume inactive - ConstantValues.insert(CD); - for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - if (!isConstantValue(TR, CD->getElementAsConstant(i))) { - ConstantValues.erase(CD); - ActiveValues.insert(CD); - return false; - } - } - return true; - } - if (auto CD = dyn_cast(Val)) { - // inductively assume inactive - ConstantValues.insert(CD); - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - if (!isConstantValue(TR, CD->getOperand(i))) { - ConstantValues.erase(CD); - ActiveValues.insert(CD); - return false; - } - } - return true; - } - - // Undef, metadata, non-global constants, and blocks are inactive - if (isa(Val) || isa(Val) || - isa(Val) || isa(Val) || - isa(Val)) { - return true; - } - - // All arguments must be marked constant/nonconstant ahead of time - if (isa(Val) && !cast(Val)->hasByValAttr()) { - llvm::errs() << *(cast(Val)->getParent()) << "\n"; - llvm::errs() << *Val << "\n"; - assert(0 && "must've put arguments in constant/nonconstant"); - } - - // This value is certainly an integer (and only and integer, not a pointer - // or float). 
Therefore its value is constant - if (TR.query(Val)[{-1}] == BaseType::Integer) { - if (EnzymePrintActivity) - llvm::errs() << " Value const as integral " << (int)directions << " " - << *Val << " " - << TR.intType(1, Val, /*errIfNotFound*/ false).str() - << "\n"; - InsertConstantValue(TR, Val); - return true; - } - - // Overwrite activity using metadata - if (auto *I = dyn_cast(Val)) { - if (hasMetadata(I, "enzyme_active") || - hasMetadata(I, "enzyme_active_val")) { - if (EnzymePrintActivity) - llvm::errs() << "[activity] forced value to be active: " << *Val - << "\n"; - return false; - } else if (hasMetadata(I, "enzyme_inactive") || - hasMetadata(I, "enzyme_inactive_val")) { - if (EnzymePrintActivity) - llvm::errs() << "[activity] forced value to be constant: " << *Val - << "\n"; - return true; - } - } - - // Overwrite activity using metadata - if (auto *I = dyn_cast(Val)) { - if (hasMetadata(I, "enzyme_active") || - hasMetadata(I, "enzyme_active_val")) { - if (EnzymePrintActivity) - llvm::errs() << "[activity] forced value to be active: " << *Val - << "\n"; - return false; - } else if (hasMetadata(I, "enzyme_inactive") || - hasMetadata(I, "enzyme_inactive_val")) { - if (EnzymePrintActivity) - llvm::errs() << "[activity] forced value to be constant: " << *Val - << "\n"; - return true; - } - } - -#if 0 - // This value is certainly a pointer to an integer (and only and integer, not - // a pointer or float). Therefore its value is constant - // TODO use typeInfo for more aggressive activity analysis - if (val->getType()->isPointerTy() && - cast(val->getType())->isIntOrIntVectorTy() && - TR.firstPointer(1, val, /*errifnotfound*/ false).isIntegral()) { - if (EnzymePrintActivity) - llvm::errs() << " Value const as integral pointer" << (int)directions - << " " << *val << "\n"; - InsertConstantValue(TR, Val); - return true; - } -#endif - - if (auto GA = dyn_cast(Val)) - return isConstantValue(TR, GA->getAliasee()); - - if (auto GI = dyn_cast(Val)) { - // If operating under the assumption globals are inactive unless - // explicitly marked as active, this is inactive - if (!hasMetadata(GI, "enzyme_shadow") && EnzymeNonmarkedGlobalsInactive) { - InsertConstantValue(TR, Val); - return true; - } - if (hasMetadata(GI, "enzyme_inactive")) { - InsertConstantValue(TR, Val); - return true; - } - - if (GI->getName().contains("enzyme_const") || - InactiveGlobals.count(GI->getName())) { - InsertConstantValue(TR, Val); - return true; - } - - // If this global is unchanging and the internal constant data - // is inactive, the global is inactive - if (GI->isConstant() && GI->hasInitializer() && - isConstantValue(TR, GI->getInitializer())) { - InsertConstantValue(TR, Val); - if (EnzymePrintActivity) - llvm::errs() << " VALUE const global " << *Val - << " init: " << *GI->getInitializer() << "\n"; - return true; - } - - // If this global is a pointer to an integer, it is inactive - // TODO note this may need updating to consider the size - // of the global - auto res = TR.query(GI).Data0(); - auto dt = res[{-1}]; - if (dt.isIntegral()) { - if (EnzymePrintActivity) - llvm::errs() << " VALUE const as global int pointer " << *Val - << " type - " << res.str() << "\n"; - InsertConstantValue(TR, Val); - return true; - } - - // If this is a global local to this translation unit with inactive - // initializer and no active uses, it is definitionally inactive - bool usedJustInThisModule = - GI->hasInternalLinkage() || GI->hasPrivateLinkage(); - - if (EnzymePrintActivity) - llvm::errs() << "pre attempting(" << 
(int)directions - << ") just used in module for: " << *GI << " dir" - << (int)directions - << " justusedin:" << usedJustInThisModule << "\n"; - - if (directions == 3 && usedJustInThisModule) { - // TODO this assumes global initializer cannot refer to itself (lest - // infinite loop) - if (!GI->hasInitializer() || - isConstantValue(TR, GI->getInitializer())) { - - if (EnzymePrintActivity) - llvm::errs() << "attempting just used in module for: " << *GI - << "\n"; - // Not looking at users to prove inactive (definition of down) - // If all users are inactive, this is therefore inactive. - // Since we won't look at origins to prove, we can inductively assume - // this is inactive - - // As an optimization if we are going down already - // and we won't use ourselves (done by PHI's), we - // dont need to inductively assume we're true - // and can instead use this object! - // This pointer is inactive if it is either not actively stored to or - // not actively loaded from - // See alloca logic to explain why OnlyStores is insufficient here - if (directions == DOWN) { - if (isValueInactiveFromUsers(TR, Val, UseActivity::OnlyLoads)) { - InsertConstantValue(TR, Val); - return true; - } - } else { - Instruction *LoadReval = nullptr; - Instruction *StoreReval = nullptr; - auto DownHypothesis = std::unique_ptr( - new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(Val); - if (DownHypothesis->isValueInactiveFromUsers( - TR, Val, UseActivity::OnlyLoads, &LoadReval) || - (TR.query(GI)[{-1, -1}].isFloat() && - DownHypothesis->isValueInactiveFromUsers( - TR, Val, UseActivity::OnlyStores, &StoreReval))) { - insertConstantsFrom(TR, *DownHypothesis); - InsertConstantValue(TR, Val); - return true; - } else { - if (LoadReval && EnzymeEnableRecursiveHypotheses) { - if (EnzymePrintActivity) - llvm::errs() << " global activity of " << *Val - << " dependant on " << *LoadReval << "\n"; - ReEvaluateValueIfInactiveInst[LoadReval].insert(Val); - } - if (StoreReval && EnzymeEnableRecursiveHypotheses) - ReEvaluateValueIfInactiveInst[StoreReval].insert(Val); - } - } - } - } - - // Otherwise we have to assume this global is active since it can - // be arbitrarily used in an active way - // TODO we can be more aggressive here in the future - if (EnzymePrintActivity) - llvm::errs() << " VALUE nonconst unknown global " << *Val << " type - " - << res.str() << "\n"; - ActiveValues.insert(Val); - return false; - } - - // ConstantExpr's are inactive if their arguments are inactive - // Note that since there can't be a recursive constant this shouldn't - // infinite loop - if (auto ce = dyn_cast(Val)) { - if (ce->isCast()) { - if (isConstantValue(TR, ce->getOperand(0))) { - if (EnzymePrintActivity) - llvm::errs() << " VALUE const cast from from operand " << *Val - << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - if (ce->getOpcode() == Instruction::GetElementPtr && - llvm::all_of(ce->operand_values(), - [&](Value *v) { return isConstantValue(TR, v); })) { - if (isConstantValue(TR, ce->getOperand(0))) { - if (EnzymePrintActivity) - llvm::errs() << " VALUE const cast from gep operand " << *Val - << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - if (EnzymePrintActivity) - llvm::errs() << " VALUE nonconst unknown expr " << *Val << "\n"; - ActiveValues.insert(Val); - return false; - } - - if (auto I = dyn_cast(Val)) { - if (hasMetadata(I, "enzyme_active") || - hasMetadata(I, "enzyme_active_val")) { - if (EnzymePrintActivity) - llvm::errs() << "forced active val (MD)" << *Val << 
"\n"; - InsertConstantValue(TR, Val); - return true; - } - if (hasMetadata(I, "enzyme_inactive") || - hasMetadata(I, "enzyme_inactive_val")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive val (MD)" << *Val << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - if (auto CI = dyn_cast(Val)) { - if (CI->hasFnAttr("enzyme_active") || - CI->hasFnAttr("enzyme_active_val") || - CI->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, - "enzyme_active")) { - if (EnzymePrintActivity) - llvm::errs() << "forced active val " << *Val << "\n"; - ActiveValues.insert(Val); - return false; - } - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_val") || - CI->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, - "enzyme_inactive")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive val " << *Val << "\n"; - InsertConstantValue(TR, Val); - return true; - } - auto called = getFunctionFromCall(CI); - - if (called) { - if (called->hasFnAttribute("enzyme_active") || - called->hasFnAttribute("enzyme_active_val") || - called->getAttributes().hasAttribute( - llvm::AttributeList::ReturnIndex, "enzyme_active")) { - if (EnzymePrintActivity) - llvm::errs() << "forced active val " << *Val << "\n"; - ActiveValues.insert(Val); - return false; - } - if (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val") || - called->getAttributes().hasAttribute( - llvm::AttributeList::ReturnIndex, "enzyme_inactive")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive val " << *Val << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - if (isInactiveCall(*CI)) { - if (EnzymePrintActivity) - llvm::errs() << "known inactive val from call" << *Val << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - if (auto BO = dyn_cast(Val)) { - // x & 0b100000 is definitionally inactive - // + if floating point, this returns either +/- 0 - // if int/pointer, this contains no info - if (BO->getOpcode() == Instruction::And) { - auto &DL = BO->getParent()->getParent()->getParent()->getDataLayout(); - for (int i = 0; i < 2; ++i) { - auto FT = TR.query(BO->getOperand(1 - i)) - .IsAllFloat( - (DL.getTypeSizeInBits(BO->getType()) + 7) / 8, DL); - // If ^ against 0b10000000000 and a float the result is a float - if (FT) - if (containsOnlyAtMostTopBit(BO->getOperand(i), FT, DL)) { - if (EnzymePrintActivity) - llvm::errs() << " inactive bithack " << *Val << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - } - } - } - - std::shared_ptr UpHypothesis; - - // Handle types that could contain pointers - // Consider all types except - // * floating point types (since those are assumed not pointers) - // * integers that we know are not pointers - bool containsPointer = TR.anyPointer(Val); - - if (containsPointer && Val->getType()->isFPOrFPVectorTy()) - containsPointer = false; - - if (containsPointer && !isValuePotentiallyUsedAsPointer(Val)) { - containsPointer = false; - if (auto Arg = dyn_cast(Val)) { - assert(Arg->hasByValAttr()); - (void)Arg; - InsertConstantValue(TR, Val); - return true; - } - } - - // We do this pointer dance here to ensure that any derived pointers from - // constant arguments are still constant, even id ATA is disabled. 
- if (EnzymeDisableActivityAnalysis) { - if (!containsPointer) - return false; - - auto TmpOrig = getBaseObject(Val); - - if (auto LI = dyn_cast(TmpOrig)) - return isConstantValue(TR, LI->getPointerOperand()); - if (isNVLoad(TmpOrig)) { - return isConstantValue(TR, cast(TmpOrig)->getOperand(0)); - } - - if (TmpOrig == Val) - return false; - return isConstantValue(TR, TmpOrig); - } - - if (containsPointer) { - // This value is certainly an integer (and only and integer, not a pointer - // or float). Therefore its value is constant - if (TR.query(Val)[{-1, -1}] == BaseType::Integer) { - if (EnzymePrintActivity) - llvm::errs() << " Value const as pointer to integer " << (int)directions - << " " << *Val << " " << TR.query(Val).str() << "\n"; - InsertConstantValue(TR, Val); - return true; - } - - auto TmpOrig = getBaseObject(Val); - - // If we know that our origin is inactive from its arguments, - // we are definitionally inactive - if (directions & UP) { - // If we are derived from an argument our activity is equal to the - // activity of the argument by definition - if (auto arg = dyn_cast(TmpOrig)) { - if (!arg->hasByValAttr()) { - bool res = isConstantValue(TR, TmpOrig); - if (res) { - if (EnzymePrintActivity) - llvm::errs() << " arg const from orig val=" << *Val - << " orig=" << *TmpOrig << "\n"; - InsertConstantValue(TR, Val); - } else { - if (EnzymePrintActivity) - llvm::errs() << " arg active from orig val=" << *Val - << " orig=" << *TmpOrig << "\n"; - ActiveValues.insert(Val); - } - return res; - } - } - - UpHypothesis = - std::unique_ptr(new ActivityAnalyzer(*this, UP)); - UpHypothesis->ConstantValues.insert(Val); - - // If our origin is a load of a known inactive (say inactive argument), we - // are also inactive - if (auto PN = dyn_cast(TmpOrig)) { - // Not taking fast path incase phi is recursive. 
- Value *active = nullptr; - for (auto &V : PN->incoming_values()) { - if (!UpHypothesis->isConstantValue(TR, V.get())) { - active = V.get(); - break; - } - } - if (!active) { - InsertConstantValue(TR, Val); - if (TmpOrig != Val) { - InsertConstantValue(TR, TmpOrig); - } - insertConstantsFrom(TR, *UpHypothesis); - return true; - } else if (EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[active].insert(Val); - if (TmpOrig != Val) { - ReEvaluateValueIfInactiveValue[active].insert(TmpOrig); - } - } - } else if (auto LI = dyn_cast(TmpOrig)) { - - if (directions == UP) { - if (isConstantValue(TR, LI->getPointerOperand())) { - InsertConstantValue(TR, Val); - return true; - } - } else { - if (UpHypothesis->isConstantValue(TR, LI->getPointerOperand())) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - if (EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[LI->getPointerOperand()].insert(Val); - if (TmpOrig != Val) { - ReEvaluateValueIfInactiveValue[LI->getPointerOperand()].insert( - TmpOrig); - } - } - } else if (isNVLoad(TmpOrig)) { - auto II = cast(TmpOrig); - if (directions == UP) { - if (isConstantValue(TR, II->getOperand(0))) { - InsertConstantValue(TR, Val); - return true; - } - } else { - if (UpHypothesis->isConstantValue(TR, II->getOperand(0))) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - if (EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(Val); - if (TmpOrig != Val) { - ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(TmpOrig); - } - } - } else if (auto RMW = dyn_cast(TmpOrig)) { - if (directions == UP) { - if (isConstantValue(TR, RMW->getPointerOperand())) { - InsertConstantValue(TR, Val); - return true; - } - } else { - if (UpHypothesis->isConstantValue(TR, RMW->getPointerOperand())) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - if (EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(Val); - if (TmpOrig != Val) { - ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert( - TmpOrig); - } - } - } else if (auto op = dyn_cast(TmpOrig)) { - if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") || - op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, - "enzyme_inactive")) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - auto called = getFunctionFromCall(op); - - StringRef funcName = getFuncNameFromCall(op); - - if (called && - (called->hasFnAttribute("enzyme_inactive_val") || - called->getAttributes().hasAttribute( - llvm::AttributeList::ReturnIndex, "enzyme_inactive"))) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - if (funcName == "free" || funcName == "_ZdlPv" || - funcName == "_ZdlPvm" || funcName == "munmap") { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - - // If requesting empty unknown functions to be considered inactive, - // abide by those rules - if (called && EnzymeEmptyFnInactive && called->empty() && - !hasMetadata(called, "enzyme_gradient") && - !hasMetadata(called, "enzyme_derivative") && - !isAllocationFunction(funcName, TLI) && - !isDeallocationFunction(funcName, TLI) && !isa(op)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - if 
(isAllocationFunction(funcName, TLI)) { - // This pointer is inactive if it is either not actively stored to - // and not actively loaded from. - if (directions == DOWN) { - for (auto UA : - {UseActivity::OnlyLoads, UseActivity::OnlyNonPointerStores, - UseActivity::AllStores, UseActivity::None}) { - Instruction *LoadReval = nullptr; - if (isValueInactiveFromUsers(TR, TmpOrig, UA, &LoadReval)) { - InsertConstantValue(TR, Val); - return true; - } - if (LoadReval && UA != UseActivity::AllStores && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } else if (directions & DOWN) { - auto DownHypothesis = std::shared_ptr( - new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(TmpOrig); - for (auto UA : - {UseActivity::OnlyLoads, UseActivity::OnlyNonPointerStores, - UseActivity::AllStores, UseActivity::None}) { - Instruction *LoadReval = nullptr; - if (DownHypothesis->isValueInactiveFromUsers(TR, TmpOrig, UA, - &LoadReval)) { - insertConstantsFrom(TR, *DownHypothesis); - InsertConstantValue(TR, Val); - return true; - } else { - if (LoadReval && UA != UseActivity::AllStores && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } - } - // If allocation function doesn't initialize inner pointers. - // For example, julia allocations initialize inner pointers, but - // malloc/etc just allocate the immediate memory. - if (directions & DOWN && - (funcName == "malloc" || funcName == "calloc" || - funcName == "_Znwm" || funcName == "julia.gc_alloc_obj" || - funcName == "??2@YAPAXI@Z" || funcName == "??2@YAPEAX_K@Z" || - funcName == "jl_gc_alloc_typed" || - funcName == "ijl_gc_alloc_typed")) { - std::shared_ptr Hypothesis = - std::shared_ptr( - new ActivityAnalyzer(*this, directions)); - Hypothesis->ActiveValues.insert(Val); - Instruction *LoadReval = nullptr; - if (Hypothesis->isValueInactiveFromUsers( - TR, TmpOrig, UseActivity::OnlyStores, &LoadReval)) { - insertConstantsFrom(TR, *Hypothesis); - InsertConstantValue(TR, Val); - return true; - } else { - if (LoadReval && EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } - } - if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" || - funcName == "jl_idtable_rehash" || - funcName == "ijl_idtable_rehash" || - funcName == "jl_genericmemory_copy_slice" || - funcName == "ijl_genericmemory_copy_slice") { - // This pointer is inactive if it is either not actively stored to - // and not actively loaded from and the copied input is inactive. - if (directions & DOWN && directions & UP) { - if (UpHypothesis->isConstantValue(TR, op->getOperand(0))) { - auto DownHypothesis = std::shared_ptr( - new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(TmpOrig); - for (auto UA : - {UseActivity::OnlyLoads, UseActivity::OnlyNonPointerStores, - UseActivity::AllStores, UseActivity::None}) { - Instruction *LoadReval = nullptr; - if (DownHypothesis->isValueInactiveFromUsers(TR, TmpOrig, UA, - &LoadReval)) { - insertConstantsFrom(TR, *DownHypothesis); - InsertConstantValue(TR, Val); - return true; - } else { - if (LoadReval && UA != UseActivity::AllStores && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } - } - } - } - } else if (isa(Val)) { - // This pointer is inactive if it is either not actively stored to or - // not actively loaded from and is nonescaping by definition of being - // alloca. 
- // When presuming the value is constant, - // OnlyStores is insufficient. This is because one could allocate - // memory, assumed inactive by definition since it is only stored - // into the hypothesized inactive alloca. However, one could load - // that pointer, and then use it as an active buffer. - // When presuming the value is active, - // OnlyStores should be fine, since any store will assume that - // its use by storing to the active alloca will be active unless - // the pointer being stored is otherwise guaranteed inactive (e.g. - // from the argument). - if (directions == DOWN) { - for (auto UA : - {UseActivity::OnlyLoads, UseActivity::OnlyNonPointerStores, - UseActivity::AllStores, UseActivity::None}) { - Instruction *LoadReval = nullptr; - if (isValueInactiveFromUsers(TR, TmpOrig, UA, &LoadReval)) { - InsertConstantValue(TR, Val); - return true; - } - if (LoadReval && UA != UseActivity::AllStores && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } else if (directions & DOWN) { - auto DownHypothesis = std::shared_ptr( - new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(TmpOrig); - for (auto UA : - {UseActivity::OnlyLoads, UseActivity::OnlyNonPointerStores, - UseActivity::AllStores, UseActivity::None}) { - Instruction *LoadReval = nullptr; - if (DownHypothesis->isValueInactiveFromUsers(TR, TmpOrig, UA, - &LoadReval)) { - insertConstantsFrom(TR, *DownHypothesis); - InsertConstantValue(TR, Val); - return true; - } else { - if (LoadReval && UA != UseActivity::AllStores && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } - } - - if (directions & DOWN) { - std::shared_ptr Hypothesis = - std::shared_ptr( - new ActivityAnalyzer(*this, directions)); - Hypothesis->ActiveValues.insert(Val); - Instruction *LoadReval = nullptr; - if (Hypothesis->isValueInactiveFromUsers( - TR, TmpOrig, UseActivity::OnlyStores, &LoadReval)) { - insertConstantsFrom(TR, *Hypothesis); - InsertConstantValue(TR, Val); - return true; - } else { - if (LoadReval && EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[LoadReval].insert(TmpOrig); - } - } - } - } - - // otherwise if the origin is a previously derived known inactive value - // assess - // TODO here we would need to potentially consider loading an active - // global as we again assume that active memory is passed explicitly as an - // argument - if (TmpOrig != Val) { - if (isConstantValue(TR, TmpOrig)) { - if (EnzymePrintActivity) - llvm::errs() << " Potential Pointer(" << (int)directions << ") " - << *Val << " inactive from inactive origin " - << *TmpOrig << "\n"; - InsertConstantValue(TR, Val); - return true; - } - } - if (auto inst = dyn_cast(Val)) { - if (!inst->mayReadFromMemory() && !isa(Val)) { - if (directions == UP && !isa(inst)) { - if (isInstructionInactiveFromOrigin(TR, inst, true)) { - InsertConstantValue(TR, Val); - return true; - } - } else { - if (UpHypothesis->isInstructionInactiveFromOrigin(TR, inst, true)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } else if (directions == 3) { - for (auto &op : inst->operands()) { - if (!UpHypothesis->isConstantValue(TR, op) && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[op].insert(Val); - } - } - } - } - } - } - } - - // If not capable of looking at both users and uses, all the ways a pointer - // can be loaded/stored cannot be assesed and therefore we default to 
assume - // it to be active - if (directions != 3) { - if (EnzymePrintActivity) - llvm::errs() << " " << *Val << "\n"; - ActiveValues.insert(Val); - return false; - } - - if (EnzymePrintActivity) - llvm::errs() << " < MEMSEARCH" << (int)directions << ">" << *Val << "\n"; - // A pointer value is active if two things hold: - // an potentially active value is stored into the memory - // memory loaded from the value is used in an active way - Instruction *potentiallyActiveStore = nullptr; - Instruction *potentialStore = nullptr; - Instruction *potentiallyActiveLoad = nullptr; - - // Assume the value (not instruction) is itself active - // In spite of that can we show that there are either no active stores - // or no active loads - auto Hypothesis = std::unique_ptr( - new ActivityAnalyzer(*this, directions)); - Hypothesis->ActiveValues.insert(Val); - if (auto VI = dyn_cast(Val)) { - if (UpHypothesis->isInstructionInactiveFromOrigin(TR, VI, true)) { - Hypothesis->DeducingPointers.insert(Val); - if (EnzymePrintActivity) - llvm::errs() << " constant instruction hypothesis: " << *VI << "\n"; - } else { - if (EnzymePrintActivity) - llvm::errs() << " cannot show constant instruction hypothesis: " - << *VI << "\n"; - if (directions == 3) { - for (auto &op : VI->operands()) { - if (!UpHypothesis->isConstantValue(TR, op) && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[op].insert(Val); - } - } - } - } - } - - auto checkActivity = [&](Instruction *I) { - if (notForAnalysis.count(I->getParent())) - return false; - - if (isa(I)) - return false; - - // If this is a malloc or free, this doesn't impact the activity - if (auto CI = dyn_cast(I)) { - if (isInactiveCallInst(*CI, TLI)) - return false; - - StringRef funcName = getFuncNameFromCall(CI); - if (isMemFreeLibMFunction(funcName)) { - return false; - } - - if (funcName == "__cxa_guard_acquire" || - funcName == "__cxa_guard_release" || - funcName == "__cxa_guard_abort" || funcName == "posix_memalign" || - funcName == "cuMemAllocAsync" || funcName == "cuMemAlloc" || - funcName == "cuMemAlloc_v2" || funcName == "cudaMallocAsync" || - funcName == "cudaMallocHost" || - funcName == "cudaMallocFromPoolAsync") { - return false; - } - } - - Value *memval = Val; - - // BasicAA stupidy assumes that non-pointer's don't alias - // if this is a nonpointer, use something else to force alias - // consideration - if (!memval->getType()->isPointerTy()) { - if (auto ci = dyn_cast(Val)) { - if (ci->getOperand(0)->getType()->isPointerTy()) { - memval = ci->getOperand(0); - } - } - for (auto user : Val->users()) { - if (isa(user) && user->getType()->isPointerTy()) { - memval = user; - break; - } - } - } - - auto AARes = AA.getModRefInfo( - I, MemoryLocation(memval, LocationSize::beforeOrAfterPointer())); - - // Still having failed to replace the location used by AA, fall back to - // getModref against any location. - if (!memval->getType()->isPointerTy()) { - if (auto CB = dyn_cast(I)) { -#if LLVM_VERSION_MAJOR >= 16 - AARes = AA.getMemoryEffects(CB).getModRef(); -#else - AARes = createModRefInfo(AA.getModRefBehavior(CB)); -#endif - } else { - bool mayRead = I->mayReadFromMemory(); - bool mayWrite = I->mayWriteToMemory(); - AARes = mayRead ? (mayWrite ? ModRefInfo::ModRef : ModRefInfo::Ref) - : (mayWrite ? 
ModRefInfo::Mod : ModRefInfo::NoModRef); - } - } - - if (auto CB = dyn_cast(I)) { - if (CB->onlyAccessesInaccessibleMemory()) - AARes = ModRefInfo::NoModRef; - - bool ReadOnly = isReadOnly(CB); - - bool WriteOnly = isWriteOnly(CB); - - if (ReadOnly && WriteOnly) - AARes = ModRefInfo::NoModRef; - else if (WriteOnly) { - if (isRefSet(AARes)) { - AARes = isModSet(AARes) ? ModRefInfo::Mod : ModRefInfo::NoModRef; - } - } else if (ReadOnly) { - if (isModSet(AARes)) { - AARes = isRefSet(AARes) ? ModRefInfo::Ref : ModRefInfo::NoModRef; - } - } - } - - // TODO this aliasing information is too conservative, the question - // isn't merely aliasing but whether there is a path for THIS value to - // eventually be loaded by it not simply because there isnt aliasing - - // If we haven't already shown a potentially active load - // check if this loads the given value and is active - if ((!potentiallyActiveLoad || !potentiallyActiveStore) && - isRefSet(AARes)) { - if (EnzymePrintActivity) - llvm::errs() << "potential active load: " << *I << "\n"; - if (isa(I) || isNVLoad(I) || isa(I)) { - // If the ref'ing value is a load check if the loaded value is - // active - if (!Hypothesis->isConstantValue(TR, I)) { - potentiallyActiveLoad = I; - // returns whether seen - std::function &)> - loadCheck = [&](Value *V, SmallPtrSetImpl &Seen) { - if (Seen.count(V)) - return false; - Seen.insert(V); - if (TR.anyPointer(V)) { - for (auto UU : V->users()) { - auto U = cast(UU); - if (U->mayWriteToMemory()) { - if (!Hypothesis->isConstantInstruction(TR, U)) { - if (EnzymePrintActivity) - llvm::errs() << "potential active store via " - "pointer in load: " - << *I << " of " << *Val << " via " - << *U << "\n"; - potentiallyActiveStore = U; - return true; - } - } - - if (U != Val && !Hypothesis->isConstantValue(TR, U)) { - if (loadCheck(U, Seen)) - return true; - } - } - } - return false; - }; - SmallPtrSet Seen; - loadCheck(I, Seen); - } - } else if (auto MTI = dyn_cast(I)) { - if (!Hypothesis->isConstantValue(TR, MTI->getArgOperand(0))) { - potentiallyActiveLoad = MTI; - if (TR.query(Val)[{-1, -1}].isPossiblePointer()) { - if (EnzymePrintActivity) - llvm::errs() - << "potential active store via pointer in memcpy: " << *I - << " of " << *Val << "\n"; - potentiallyActiveStore = MTI; - } - } - } else { - // Otherwise fallback and check any part of the instruction is - // active - // TODO: note that this can be optimized (especially for function - // calls) - // Notably need both to check the result and instruction since - // A load that has as result an active pointer is not an active - // instruction, but does have an active value - if (!Hypothesis->isConstantInstruction(TR, I) || - (I != Val && !Hypothesis->isConstantValue(TR, I))) { - potentiallyActiveLoad = I; - // If this a potential pointer of pointer AND - // double** Val; - // - if (TR.query(Val)[{-1, -1}].isPossiblePointer()) { - // If this instruction either: - // 1) can actively store into the inner pointer, even - // if it doesn't store into the outer pointer. Actively - // storing into the outer pointer is handled by the isMod - // case. 
- // I(double** readonly Val, double activeX) { - // double* V0 = Val[0] - // V0 = activeX; - // } - // 2) may return an active pointer loaded from Val - // double* I = *Val; - // I[0] = active; - // - if ((I->mayWriteToMemory() && - !Hypothesis->isConstantInstruction(TR, I)) || - (!Hypothesis->DeducingPointers.count(I) && - !Hypothesis->isConstantValue(TR, I) && TR.anyPointer(I))) { - if (EnzymePrintActivity) - llvm::errs() << "potential active store via pointer in " - "unknown inst: " - << *I << " of " << *Val << "\n"; - potentiallyActiveStore = I; - } - } - } - } - } - if ((!potentiallyActiveStore || !potentialStore) && isModSet(AARes)) { - if (EnzymePrintActivity) - llvm::errs() << "potential active store: " << *I << " Val=" << *Val - << "\n"; - if (auto SI = dyn_cast(I)) { - bool cop = !Hypothesis->isConstantValue(TR, SI->getValueOperand()); - // bool cop2 = !Hypothesis->isConstantValue(TR, - // SI->getPointerOperand()); - if (EnzymePrintActivity) - llvm::errs() << " -- store potential activity: " << (int)cop - << " - " << *SI << " of " - << " Val=" << *Val << "\n"; - potentialStore = I; - if (cop) // && cop2) - potentiallyActiveStore = SI; - } else if (auto MTI = dyn_cast(I)) { - bool cop = !Hypothesis->isConstantValue(TR, MTI->getArgOperand(1)); - potentialStore = I; - if (cop) - potentiallyActiveStore = MTI; - } else if (isa(I)) { - potentialStore = I; - } else { - // Otherwise fallback and check if the instruction is active - // TODO: note that this can be optimized (especially for function - // calls) - auto cop = !Hypothesis->isConstantInstruction(TR, I); - if (EnzymePrintActivity) - llvm::errs() << " -- unknown store potential activity: " << (int)cop - << " - " << *I << " of " - << " Val=" << *Val << "\n"; - potentialStore = I; - if (cop) - potentiallyActiveStore = I; - } - } - if (potentiallyActiveStore && potentiallyActiveLoad) - return true; - return false; - }; - - // Search through all the instructions in this function - // for potential loads / stores of this value. - // - // We can choose to only look at potential follower instructions - // if the value is created by the instruction (alloca, noalias) - // since no potentially active store to the same location can occur - // prior to its creation. Otherwise, check all instructions in the - // function as a store to an aliasing location may have occured - // prior to the instruction generating the value. 
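// A standalone illustrative C++ sketch of the pattern the comment above
// describes: even though the outer pointer is only ever read, an active
// value is still stored through the *inner* pointer it holds, so the memory
// behind `val` must be treated as active.
#include <cstdio>

static void store_through_inner(double *const *val, double active_x) {
  double *inner = val[0]; // load of the inner pointer held by `val`
  *inner = active_x;      // active store through the loaded pointer
}

int main() {
  double slot = 0.0;
  double *inner = &slot;
  store_through_inner(&inner, 3.14);
  std::printf("%f\n", slot); // prints 3.140000
  return 0;
}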
- - if (auto VI = dyn_cast(Val)) { - allFollowersOf(VI, checkActivity); - } else if (auto VI = dyn_cast(Val)) { - if (VI->hasRetAttr(Attribute::NoAlias)) - allFollowersOf(VI, checkActivity); - else { - for (BasicBlock &BB : *TR.getFunction()) { - if (notForAnalysis.count(&BB)) - continue; - for (Instruction &I : BB) { - if (checkActivity(&I)) - goto activeLoadAndStore; - } - } - } - } else if (isa(Val) || isa(Val)) { - for (BasicBlock &BB : *TR.getFunction()) { - if (notForAnalysis.count(&BB)) - continue; - for (Instruction &I : BB) { - if (checkActivity(&I)) - goto activeLoadAndStore; - } - } - } else { - llvm::errs() << "unknown pointer value type: " << *Val << "\n"; - assert(0 && "unknown pointer value type"); - llvm_unreachable("unknown pointer value type"); - } - - activeLoadAndStore:; - if (EnzymePrintActivity) { - llvm::errs() << " " << *Val - << " potentiallyActiveLoad="; - if (potentiallyActiveLoad) - llvm::errs() << *potentiallyActiveLoad; - else - llvm::errs() << potentiallyActiveLoad; - llvm::errs() << " potentiallyActiveStore="; - if (potentiallyActiveStore) - llvm::errs() << *potentiallyActiveStore; - else - llvm::errs() << potentiallyActiveStore; - llvm::errs() << " potentialStore="; - if (potentialStore) - llvm::errs() << *potentialStore; - else - llvm::errs() << potentialStore; - llvm::errs() << "\n"; - } - if (potentiallyActiveLoad && potentiallyActiveStore) { - if (EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveInst[potentiallyActiveLoad].insert(Val); - ReEvaluateValueIfInactiveInst[potentiallyActiveStore].insert(Val); - } - insertAllFrom(TR, *Hypothesis, Val, TmpOrig); - if (TmpOrig != Val && EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[TmpOrig].insert(Val); - ReEvaluateValueIfInactiveInst[potentiallyActiveLoad].insert(TmpOrig); - ReEvaluateValueIfInactiveInst[potentiallyActiveStore].insert(TmpOrig); - } - return false; - } else { - // We now know that there isn't a matching active load/store pair in this - // function. Now the only way that this memory can facilitate a transfer - // of active information is if it is done outside of the function - - // This can happen if either: - // a) the memory had an active load or store before this function was - // called b) the memory had an active load or store after this function - // was called - - // Case a) can occur if: - // 1) this memory came from an active global - // 2) this memory came from an active argument - // 3) this memory came from a load from active memory - // In other words, assuming this value is inactive, going up this - // location's argument must be inactive - - assert(UpHypothesis); - // UpHypothesis.ConstantValues.insert(val); - if (DeducingPointers.size() == 0) - UpHypothesis->insertConstantsFrom(TR, *Hypothesis); - assert(directions & UP); - bool ActiveUp = - !isa(Val) && - !UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true); - - // Case b) can occur if: - // 1) this memory is used as part of an active return - // 2) this memory is stored somewhere - - // We never verify that an origin wasn't stored somewhere or returned. 
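// Standalone sketch of "case a)" above: inside `zero_out` there is no pair of
// an active load and an active store, yet its pointer argument cannot be
// declared inactive, because the memory's activity enters from outside the
// function (the caller stores and reloads active data through it).
#include <cstdio>

static void zero_out(double *buf, int n) {
  for (int i = 0; i < n; ++i)
    buf[i] = 0.0; // stores of a constant only; no active load in here
}

static double loss(double x) {
  double buf[2];
  buf[0] = x; // active store happens in the caller
  buf[1] = x;
  zero_out(buf, 1);
  return buf[0] + buf[1]; // active load happens in the caller
}

int main() {
  std::printf("%f\n", loss(2.0));
  return 0;
}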
- // to remedy correctness for now let's do something extremely simple - auto DownHypothesis = - std::unique_ptr(new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(Val); - DownHypothesis->insertConstantsFrom(TR, *Hypothesis); - bool ActiveDown = - DownHypothesis->isValueActivelyStoredOrReturned(TR, Val); - // BEGIN TEMPORARY - - if (!ActiveDown && TmpOrig != Val) { - - if (isa(TmpOrig) || isa(TmpOrig) || - isa(TmpOrig) || isAllocationCall(TmpOrig, TLI)) { - auto DownHypothesis2 = std::unique_ptr( - new ActivityAnalyzer(*DownHypothesis, DOWN)); - DownHypothesis2->ConstantValues.insert(TmpOrig); - if (DownHypothesis2->isValueActivelyStoredOrReturned(TR, TmpOrig)) { - if (EnzymePrintActivity) - llvm::errs() << " active from ivasor: " << *TmpOrig << "\n"; - ActiveDown = true; - } - } else { - // unknown origin that could've been stored/returned/etc - if (EnzymePrintActivity) - llvm::errs() << " active from unknown origin: " << *TmpOrig << "\n"; - ActiveDown = true; - } - } - - // END TEMPORARY - - // We can now consider the three places derivative information can be - // transferred - // Case A) From the origin - // Case B) Though the return - // Case C) Within the function (via either load or store) - - bool ActiveMemory = false; - - // If it is transferred via active origin and return, clearly this is - // active - ActiveMemory |= (ActiveUp && ActiveDown); - - // If we come from an active origin and load, memory is clearly active - ActiveMemory |= (ActiveUp && potentiallyActiveLoad); - - // If we come from an active origin and only store into it, it changes - // future state - ActiveMemory |= (ActiveUp && potentialStore); - - // If we go to an active return and store active memory, this is active - ActiveMemory |= (ActiveDown && potentialStore); - // Actually more generally, if we are ActiveDown (returning memory that is - // used) in active return, we must be active. 
This is necessary to ensure - // mallocs have their differential shadows created when returned [TODO - // investigate more] - ActiveMemory |= ActiveDown; - - // If we go to an active return and only load it, however, that doesnt - // transfer derivatives and we can say this memory is inactive - - if (EnzymePrintActivity) - llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << *Val - << " potentiallyActiveLoad=" << potentiallyActiveLoad - << " potentialStore=" << potentialStore - << " ActiveUp=" << ActiveUp << " ActiveDown=" << ActiveDown - << " ActiveMemory=" << ActiveMemory << "\n"; - - if (ActiveMemory) { - ActiveValues.insert(Val); - assert(Hypothesis->directions == directions); - assert(Hypothesis->ActiveValues.count(Val)); - insertAllFrom(TR, *Hypothesis, Val, TmpOrig); - if (TmpOrig != Val && EnzymeEnableRecursiveHypotheses) - ReEvaluateValueIfInactiveValue[TmpOrig].insert(Val); - return false; - } else { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *Hypothesis); - if (DeducingPointers.size() == 0) - insertConstantsFrom(TR, *UpHypothesis); - insertConstantsFrom(TR, *DownHypothesis); - return true; - } - } - } - - // For all non-pointers, it is now sufficient to simply prove that - // either activity does not flow in, or activity does not flow out - // This alone cuts off the flow (being unable to flow through memory) - - // Not looking at uses to prove inactive (definition of up), if the creator of - // this value is inactive, we are inactive Since we won't look at uses to - // prove, we can inductively assume this is inactive - if (directions & UP) { - if (!UpHypothesis) - UpHypothesis = - std::unique_ptr(new ActivityAnalyzer(*this, UP)); - if (directions == UP && !isa(Val)) { - if (isInstructionInactiveFromOrigin(TR, Val, true)) { - InsertConstantValue(TR, Val); - return true; - } else if (auto I = dyn_cast(Val)) { - if (directions == 3) { - for (auto &op : I->operands()) { - if (!UpHypothesis->isConstantValue(TR, op) && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[op].insert(I); - } - } - } - } - } else { - UpHypothesis->ConstantValues.insert(Val); - if (UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true)) { - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } else if (auto I = dyn_cast(Val)) { - if (directions == 3) { - for (auto &op : I->operands()) { - if (!UpHypothesis->isConstantValue(TR, op) && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[op].insert(I); - } - } - } - } - } - } - - if (directions & DOWN) { - // Not looking at users to prove inactive (definition of down) - // If all users are inactive, this is therefore inactive. - // Since we won't look at origins to prove, we can inductively assume this - // is inactive - - // As an optimization if we are going down already - // and we won't use ourselves (done by PHI's), we - // dont need to inductively assume we're true - // and can instead use this object! 
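// Standalone sketch of the scalar rule stated above: a non-pointer value is
// active only when derivative information both flows in (an active operand)
// and flows out (an active user); cutting either side makes it inactive.
#include <cstdio>

static double f(double x, int n) {
  double t = x * x;        // active: active operand (x), and the result reaches the return
  int iters = n + 1;       // inactive "up": computed purely from the integer n
  double unused = 2.0 * x; // inactive "down": active operand, but no user of the result
  (void)unused;
  double acc = 0.0;
  for (int i = 0; i < iters; ++i)
    acc += t;
  return acc;
}

int main() {
  std::printf("%f\n", f(1.5, 3));
  return 0;
}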
- if (directions == DOWN && !isa(Val)) { - if (isValueInactiveFromUsers(TR, Val, UseActivity::None)) { - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } - } else { - auto DownHypothesis = - std::unique_ptr(new ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(Val); - if (DownHypothesis->isValueInactiveFromUsers(TR, Val, - UseActivity::None)) { - insertConstantsFrom(TR, *DownHypothesis); - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } - } - } - - if (EnzymePrintActivity) - llvm::errs() << " Value nonconstant (couldn't disprove)[" << (int)directions - << "]" << *Val << "\n"; - ActiveValues.insert(Val); - return false; -} - -/// Is the instruction guaranteed to be inactive because of its operands -bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, - llvm::Value *val, - bool considerValue) { - // Must be an analyzer only searching up - assert(directions == UP); - assert(!isa(val)); - assert(!isa(val)); - - // Not an instruction and thus not legal to search for activity via operands - if (!isa(val)) { - llvm::errs() << "unknown pointer source: " << *val << "\n"; - assert(0 && "unknown pointer source"); - llvm_unreachable("unknown pointer source"); - return false; - } - - Instruction *inst = cast(val); - if (EnzymePrintActivity) - llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << "\n"; - - // cpuid is explicitly an inactive instruction - if (auto call = dyn_cast(inst)) { - if (auto iasm = dyn_cast(call->getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("cpuid")) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction from known cpuid instruction " - << *inst << "\n"; - return true; - } - } - } - - if (auto SI = dyn_cast(inst)) { - // if either src or dst is inactive, there cannot be a transfer of active - // values and thus the store is inactive - if (isConstantValue(TR, SI->getValueOperand()) || - isConstantValue(TR, SI->getPointerOperand())) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction as store operand is inactive " - << *inst << "\n"; - return true; - } - } - - if (!considerValue) { - if (auto IEI = dyn_cast(inst)) { - if ((!TR.anyFloat(IEI->getOperand(0)) || - isConstantValue(TR, IEI->getOperand(0))) && - (!TR.anyFloat(IEI->getOperand(1)) || - isConstantValue(TR, IEI->getOperand(1)))) { - if (EnzymePrintActivity) - llvm::errs() - << " constant instruction as inserting known pointer or inactive" - << *inst << "\n"; - return true; - } - } - if (auto IEI = dyn_cast(inst)) { - if ((!TR.anyFloat(IEI->getAggregateOperand()) || - isConstantValue(TR, IEI->getAggregateOperand())) && - (!TR.anyFloat(IEI->getInsertedValueOperand()) || - isConstantValue(TR, IEI->getInsertedValueOperand()))) { - if (EnzymePrintActivity) - llvm::errs() - << " constant instruction as inserting known pointer or inactive" - << *inst << "\n"; - return true; - } - } - if (auto PN = dyn_cast(inst)) { - std::deque todo = {PN}; - SmallPtrSet done; - SmallVector incoming; - while (todo.size()) { - auto cur = todo.back(); - todo.pop_back(); - if (done.count(cur)) - continue; - done.insert(cur); - for (auto &V : cur->incoming_values()) { - if (auto P = dyn_cast(V)) { - todo.push_back(P); - continue; - } - incoming.push_back(V); - } - } - bool legal = true; - for (auto V : incoming) { - if (TR.anyFloat(V) && !isConstantValue(TR, V)) { - legal = false; - break; - } - } - if (legal) { - 
if (EnzymePrintActivity) - llvm::errs() - << " constant instruction as phi of known pointer or inactive" - << *inst << "\n"; - return true; - } - } - } - - if (auto MTI = dyn_cast(inst)) { - // if either src or dst is inactive, there cannot be a transfer of active - // values and thus the store is inactive - if (isConstantValue(TR, MTI->getArgOperand(0)) || - isConstantValue(TR, MTI->getArgOperand(1))) { - if (EnzymePrintActivity) - llvm::errs() << " constant instruction as memtransfer " << *inst - << "\n"; - return true; - } - } - - if (auto op = dyn_cast(inst)) { - if (isInactiveCall(*op)) - return true; - - if (op->hasFnAttr("enzyme_inactive_val")) { - return true; - } - // Calls to print/assert/cxa guard are definitionally inactive - llvm::Value *callVal; - callVal = op->getCalledOperand(); - StringRef funcName = getFuncNameFromCall(op); - auto called = getFunctionFromCall(op); - - if (called && (called->hasFnAttribute("enzyme_inactive_val"))) { - return true; - } - if (funcName == "free" || funcName == "_ZdlPv" || funcName == "_ZdlPvm" || - funcName == "munmap") { - return true; - } - - // If requesting empty unknown functions to be considered inactive, abide - // by those rules - if (called && EnzymeEmptyFnInactive && called->empty() && - !hasMetadata(called, "enzyme_gradient") && - !hasMetadata(called, "enzyme_derivative") && - !isAllocationFunction(funcName, TLI) && - !isDeallocationFunction(funcName, TLI) && !isa(op)) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-emptyconst " - << *inst << "\n"; - return true; - } - if (!isa(callVal) && isConstantValue(TR, callVal)) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-constfn " - << *inst << " - " << *callVal << "\n"; - return true; - } - } - // Intrinsics known always to be inactive - if (auto II = dyn_cast(inst)) { - if (isIntelSubscriptIntrinsic(*II)) { - // The only argument that can make an llvm.intel.subscript intrinsic - // active is the pointer operand - const unsigned int ptrArgIdx = 3; - if (isConstantValue(TR, II->getOperand(ptrArgIdx))) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - << *inst << "\n"; - return true; - } - return false; - } - } - - if (auto gep = dyn_cast(inst)) { - // A gep's only args that could make it active is the pointer operand - if (isConstantValue(TR, gep->getPointerOperand())) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-gep " << *inst - << "\n"; - return true; - } - return false; - } else if (auto ci = dyn_cast(inst)) { - bool seenuse = false; - - propagateArgumentInformation(TLI, *ci, [&](Value *a) { - if (!isConstantValue(TR, a)) { - seenuse = true; - if (EnzymePrintActivity) - llvm::errs() << "nonconstant(" << (int)directions << ") up-call " - << *inst << " op " << *a << "\n"; - return true; - } - return false; - }); - if (EnzymeGlobalActivity) { - if (!ci->onlyAccessesArgMemory() && !ci->doesNotAccessMemory()) { - bool legalUse = false; - - StringRef funcName = getFuncNameFromCall(ci); - - if (funcName == "") { - } else if (isMemFreeLibMFunction(funcName) || - isDebugFunction(ci->getCalledFunction()) || - isCertainPrint(funcName) || - isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - legalUse = true; - } - - if (!legalUse) { - if (EnzymePrintActivity) - llvm::errs() << "nonconstant(" << (int)directions << ") up-global " - << *inst << "\n"; - seenuse = true; - } - } - } - - if (!seenuse) { 
- if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-call:" << *inst - << "\n"; - return true; - } - return !seenuse; - } else if (auto si = dyn_cast(inst)) { - - if (isConstantValue(TR, si->getTrueValue()) && - isConstantValue(TR, si->getFalseValue())) { - - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-sel:" << *inst - << "\n"; - return true; - } - return false; - } else if (isa(inst) || isa(inst) || - isa(inst) || isa(inst)) { - - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-fpcst:" << *inst - << "\n"; - return true; - } else { - bool seenuse = false; - //! TODO does not consider reading from global memory that is active and not - //! an argument - for (auto &a : inst->operands()) { - bool hypval = isConstantValue(TR, a); - if (!hypval) { - if (EnzymePrintActivity) - llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " - << *inst << " op " << *a << "\n"; - seenuse = true; - break; - } - } - - if (!seenuse) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-inst:" << *inst - << "\n"; - return true; - } - return false; - } -} - -/// Is the value free of any active uses -bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR, - llvm::Value *const val, - UseActivity PUA, - Instruction **FoundInst) { - assert(directions & DOWN); - // Must be an analyzer only searching down, unless used outside - // assert(directions == DOWN); - - // To ensure we can call down - - if (EnzymePrintActivity) - llvm::errs() << " " << *val - << " UA=" << to_string(PUA) << "\n"; - - bool seenuse = false; - // user, predecessor - std::deque> todo; - for (const auto a : val->users()) { - todo.push_back(std::make_tuple(a, val, PUA)); - } - std::set> done = {}; - - SmallSet AllocaSet; - - if (isa(val)) - AllocaSet.insert(val); - - if (PUA == UseActivity::None && isAllocationCall(val, TLI)) - AllocaSet.insert(val); - - while (todo.size()) { - auto pair = todo.front(); - todo.pop_front(); - if (done.count(pair)) - continue; - done.insert(pair); - User *a = std::get<0>(pair); - Value *parent = std::get<1>(pair); - UseActivity UA = std::get<2>(pair); - - if (auto LI = dyn_cast(a)) { - if (UA == UseActivity::OnlyStores) - continue; - if (UA == UseActivity::OnlyNonPointerStores || - UA == UseActivity::AllStores) { - if (!TR.anyPointer(LI)) - continue; - } - } - - if (EnzymePrintActivity) - llvm::errs() << " considering use of " << *val << " - " << *a - << "\n"; - - // Only ignore stores to the operand, not storing the operand - // somewhere - if (auto SI = dyn_cast(a)) { - if (SI->getValueOperand() != parent) { - if (UA == UseActivity::OnlyLoads) { - continue; - } - if (UA != UseActivity::AllStores && - (ConstantValues.count(SI->getValueOperand()) || - isa(SI->getValueOperand()))) - continue; - if (UA == UseActivity::None || - UA == UseActivity::OnlyNonPointerStores) { - // If storing into itself, all potential uses are taken care of - // elsewhere in the recursion. 
- bool shouldContinue = true; - SmallVector vtodo = {SI->getValueOperand()}; - SmallSet seen; - SmallSet newAllocaSet; - while (vtodo.size()) { - auto TmpOrig = vtodo.back(); - vtodo.pop_back(); - if (seen.count(TmpOrig)) - continue; - seen.insert(TmpOrig); - if (AllocaSet.count(TmpOrig)) { - continue; - } - // We are literally storing our value into ourselves [or relevant - // derived pointer] - if (TmpOrig == val) { - continue; - } - if (isa(TmpOrig)) { - newAllocaSet.insert(TmpOrig); - continue; - } - if (isAllocationCall(TmpOrig, TLI)) { - newAllocaSet.insert(TmpOrig); - continue; - } - if (isa(TmpOrig) || isa(TmpOrig) || - isa(TmpOrig) || isa(TmpOrig)) { - continue; - } - if (auto LI = dyn_cast(TmpOrig)) { - vtodo.push_back(LI->getPointerOperand()); - continue; - } - if (auto CD = dyn_cast(TmpOrig)) { - for (size_t i = 0, len = CD->getNumElements(); i < len; i++) - vtodo.push_back(CD->getElementAsConstant(i)); - continue; - } - if (auto CD = dyn_cast(TmpOrig)) { - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) - vtodo.push_back(CD->getOperand(i)); - continue; - } - if (auto GV = dyn_cast(TmpOrig)) { - // If operating under the assumption globals are inactive unless - // explicitly marked as active, this is inactive - if (!hasMetadata(GV, "enzyme_shadow") && - EnzymeNonmarkedGlobalsInactive) { - continue; - } - if (hasMetadata(GV, "enzyme_inactive")) { - continue; - } - if (GV->getName().contains("enzyme_const") || - InactiveGlobals.count(GV->getName())) { - continue; - } - } - auto TmpOrig_2 = getBaseObjects(TmpOrig); - if (TmpOrig_2.size() != 1 || TmpOrig != TmpOrig_2[0]) { - for (auto v : TmpOrig_2) - vtodo.push_back(v); - continue; - } - if (UA == PUA && TmpOrig == val) { - continue; - } - if (EnzymePrintActivity) - llvm::errs() << " -- cannot continuing indirect store from " - << *val << " due to " << *TmpOrig << "\n"; - shouldContinue = false; - break; - } - if (shouldContinue) { - if (EnzymePrintActivity) - llvm::errs() << " -- continuing indirect store from " << *val - << " into:\n"; - done.insert(std::make_tuple((User *)SI, SI->getValueOperand(), UA)); - for (auto TmpOrig : newAllocaSet) { - - for (const auto a : TmpOrig->users()) { - todo.push_back(std::make_tuple(a, TmpOrig, UA)); - if (EnzymePrintActivity) - llvm::errs() << " ** " << *a << "\n"; - } - AllocaSet.insert(TmpOrig); - shouldContinue = true; - } - continue; - } - } - } - if (SI->getPointerOperand() != parent) { - // If storing into itself, all potential uses are taken care of - // elsewhere in the recursion. - bool shouldContinue = true; - SmallVector vtodo = {SI->getPointerOperand()}; - SmallSet seen; - if (EnzymePrintActivity) - llvm::errs() << " @@ analyzing store2 " << *SI << " " << *val - << " via " << *SI->getPointerOperand() << "\n"; - while (vtodo.size()) { - auto TmpOrig = vtodo.back(); - vtodo.pop_back(); - if (seen.count(TmpOrig)) - continue; - seen.insert(TmpOrig); - - if (AllocaSet.count(TmpOrig)) { - if (EnzymePrintActivity) - llvm::errs() - << " -- continuing indirect store2(seen alloca) from " - << *val << " via " << *TmpOrig << "\n"; - continue; - } - if (isa(TmpOrig) || isAllocationCall(TmpOrig, TLI)) { - done.insert( - std::make_tuple((User *)SI, SI->getPointerOperand(), UA)); - // If we are capturing a variable v, we need to check any loads or - // stores into that variable, even if we are checking only for - // stores. 
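// Standalone sketch of why the walk above follows stores into local
// allocations: the pointer parked in `slot` is recovered by a later load and
// then written through, so restricting attention to stores *of* `data` alone
// would miss this active use.
#include <cstdio>

static void via_local_slot(double *data, double active_x) {
  double *slot[1];
  slot[0] = data;          // `data` is stored into a local allocation
  double *again = slot[0]; // ...and later loaded back out of it
  *again = active_x;       // active store through the recovered pointer
}

int main() {
  double v = 0.0;
  via_local_slot(&v, 7.0);
  std::printf("%f\n", v); // prints 7.000000
  return 0;
}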
- auto UA2 = UA; - if (UA == UseActivity::OnlyStores || - UA == UseActivity::OnlyNonPointerStores || - UA == UseActivity::AllStores) - UA2 = UseActivity::None; - for (const auto a : TmpOrig->users()) { - todo.push_back(std::make_tuple(a, TmpOrig, UA2)); - } - AllocaSet.insert(TmpOrig); - if (EnzymePrintActivity) - llvm::errs() - << " -- continuing indirect store2(allocation) from " - << *val << " via " << *TmpOrig << "\n"; - continue; - } - if (PUA == UseActivity::None) { - if (auto LI = dyn_cast(TmpOrig)) { - vtodo.push_back(LI->getPointerOperand()); - if (EnzymePrintActivity) - llvm::errs() - << " -- continuing indirect store2(load) from " << *val - << " via " << *TmpOrig << "\n"; - continue; - } - } - - auto TmpOrig_2 = getBaseObjects(TmpOrig); - if (TmpOrig_2.size() != 1 || TmpOrig != TmpOrig_2[0]) { - for (auto v : TmpOrig_2) { - if (EnzymePrintActivity) - llvm::errs() - << " -- continuing indirect store2(base) from " << *val - << " via " << *TmpOrig << " - v:" << *v << "\n"; - vtodo.push_back(v); - } - continue; - } - if (UA == PUA && TmpOrig == val) { - continue; - } - if (EnzymePrintActivity) - llvm::errs() << " -- failed to continue indirect store2 from " - << *val << " via " << *TmpOrig_2[0] << "\n"; - shouldContinue = false; - break; - } - if (shouldContinue) { - continue; - } - } - if (PUA == UseActivity::OnlyLoads) { - auto TmpOrig = getBaseObjects(SI->getPointerOperand()); - bool AllVals = true; - for (auto v : TmpOrig) { - if (v != val) { - AllVals = false; - break; - } - } - if (AllVals) { - continue; - } - } - } - - if (!isa(a)) { - if (auto CE = dyn_cast(a)) { - for (auto u : CE->users()) { - todo.push_back(std::make_tuple(u, (Value *)CE, UA)); - } - continue; - } - if (isa(a)) { - continue; - } - - if (EnzymePrintActivity) - llvm::errs() << " unknown non instruction use of " << *val << " - " - << *a << "\n"; - goto endloop; - } - - if (isa(a)) { - if (EnzymePrintActivity) - llvm::errs() << "found constant(" << (int)directions - << ") allocainst use:" << *val << " user " << *a << "\n"; - continue; - } - - if (isa(a) || isa(a) || isa(a) || - isa(a)) { - if (EnzymePrintActivity) - llvm::errs() << "found constant(" << (int)directions - << ") si-fp use:" << *val << " user " << *a << "\n"; - continue; - } - - // if this instruction is in a different function, conservatively assume - // it is active - { - Function *InstF = cast(a)->getParent()->getParent(); - while (PPC.CloneOrigin.find(InstF) != PPC.CloneOrigin.end()) - InstF = PPC.CloneOrigin[InstF]; - - Function *F = TR.getFunction(); - while (PPC.CloneOrigin.find(F) != PPC.CloneOrigin.end()) - F = PPC.CloneOrigin[F]; - - if (InstF != F) { - if (EnzymePrintActivity) - llvm::errs() << "found use in different function(" << (int)directions - << ") val:" << *val << " user " << *a << " in " - << InstF->getName() << "@" << InstF - << " self: " << F->getName() << "@" << F << "\n"; - goto endloop; - } - } - if (cast(a)->getParent()->getParent() != TR.getFunction()) - continue; - - // This use is only active if specified - if (isa(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT && - UA != UseActivity::AllStores) { - continue; - } else { - goto endloop; - } - } - - if (auto II = dyn_cast(a)) { - if (isIntelSubscriptIntrinsic(*II) && - (II->getOperand(/*ptrArgIdx=*/3) != parent)) { - continue; - } - } - - if (auto call = dyn_cast(a)) { - bool mayWrite = false; - bool mayRead = false; - bool mayCapture = false; - - auto F = getFunctionFromCall(call); - - size_t idx = 0; - for (auto &arg : call->args()) { - if (arg != parent) { - 
idx++; - continue; - } - - bool NoCapture = isNoCapture(call, idx); - - mayCapture |= !NoCapture; - - bool ReadOnly = isReadOnly(call, idx); - - mayWrite |= !ReadOnly; - - bool WriteOnly = isWriteOnly(call, idx); - - mayRead |= !WriteOnly; - } - - bool ConstantArg = isFunctionArgumentConstant(call, parent); - if (ConstantArg && UA != UseActivity::AllStores) { - if (EnzymePrintActivity) { - llvm::errs() << "Value found constant callinst use:" << *val - << " user " << *call << " parent " << *parent << "\n"; - } - continue; - } - - if (!mayCapture) { - if (!mayRead && UA == UseActivity::OnlyLoads) { - if (EnzymePrintActivity) { - llvm::errs() << "Value found non-loading use:" << *val << " user " - << *call << "\n"; - } - continue; - } - if (!mayWrite && UA == UseActivity::OnlyStores) { - if (EnzymePrintActivity) { - llvm::errs() << "Value found non-writing use:" << *val << " user " - << *call << "\n"; - } - continue; - } - if (!mayWrite && (UA == UseActivity::OnlyNonPointerStores || - UA == UseActivity::AllStores)) { - if (!mayRead || !TR.query(parent)[{-1, -1}].isPossiblePointer()) { - if (EnzymePrintActivity) { - llvm::errs() - << "Value found non-writing and non pointer loading use:" - << *val << " user " << *call << "\n"; - } - continue; - } - } - } - - if (F) { - if (UA == UseActivity::AllStores && - (F->getName() == "julia.write_barrier" || - F->getName() == "julia.write_barrier_binding")) - continue; - if (F->getIntrinsicID() == Intrinsic::memcpy || - F->getIntrinsicID() == Intrinsic::memmove) { - - // copies of constant string data do not impact activity. - if (auto cexpr = dyn_cast(call->getArgOperand(1))) { - if (cexpr->getOpcode() == Instruction::GetElementPtr) { - if (auto GV = dyn_cast(cexpr->getOperand(0))) { - if (GV->hasInitializer() && GV->isConstant()) { - if (auto CDA = - dyn_cast(GV->getInitializer())) { - if (CDA->getType()->getElementType()->isIntegerTy(8)) - continue; - } - } - } - } - } - - // Only need to care about loads from - if (UA == UseActivity::OnlyLoads && call->getArgOperand(1) != parent) - continue; - - // Only need to care about store from - if (call->getArgOperand(0) != parent) { - if (UA == UseActivity::OnlyStores) - continue; - else if (UA == UseActivity::OnlyNonPointerStores || - UA == UseActivity::AllStores) { - // todo can change this to query either -1 (all mem) or 0..size - // (if size of copy is const) - if (!TR.query(call->getArgOperand(1))[{-1, -1}] - .isPossiblePointer()) - continue; - } - } - - bool shouldContinue = false; - if (UA != UseActivity::AllStores) - for (int arg = 0; arg < 2; arg++) - if (call->getArgOperand(arg) != parent && - (arg == 0 || (PUA == UseActivity::None))) { - Value *TmpOrig = call->getOperand(arg); - while (1) { - if (AllocaSet.count(TmpOrig)) { - shouldContinue = true; - break; - } - if (isa(TmpOrig)) { - done.insert(std::make_tuple((User *)call, - call->getArgOperand(arg), UA)); - for (const auto a : TmpOrig->users()) { - todo.push_back(std::make_tuple(a, TmpOrig, UA)); - } - AllocaSet.insert(TmpOrig); - shouldContinue = true; - break; - } - if (PUA == UseActivity::None) { - if (auto LI = dyn_cast(TmpOrig)) { - TmpOrig = LI->getPointerOperand(); - continue; - } - if (isAllocationCall(TmpOrig, TLI)) { - done.insert(std::make_tuple( - (User *)call, call->getArgOperand(arg), UA)); - for (const auto a : TmpOrig->users()) { - todo.push_back(std::make_tuple(a, TmpOrig, UA)); - } - AllocaSet.insert(TmpOrig); - shouldContinue = true; - break; - } - } - auto TmpOrig_2 = getBaseObject(TmpOrig); - if (TmpOrig != TmpOrig_2) { - 
TmpOrig = TmpOrig_2; - continue; - } - break; - } - if (shouldContinue) - break; - } - - if (shouldContinue) - continue; - } - } else if (PUA == UseActivity::None || PUA == UseActivity::OnlyStores) { - // If calling a function derived from an alloca of this value, - // the function is only active if the function stored into - // the allocation is active (all functions not explicitly marked - // inactive), or one of the args to the call is active - Value *operand = call->getCalledOperand(); - - bool toContinue = false; - if (isa(operand)) { - bool legal = true; - - for (unsigned i = 0; i < call->arg_size() + 1; ++i) { - Value *a = call->getOperand(i); - - if (isa(a)) - continue; - - Value *ptr = a; - bool subValue = false; - while (ptr) { - auto TmpOrig2 = getBaseObject(ptr); - if (AllocaSet.count(TmpOrig2)) { - subValue = true; - break; - } - if (isa(TmpOrig2)) { - done.insert(std::make_tuple((User *)call, a, UA)); - for (const auto a : TmpOrig2->users()) { - todo.push_back(std::make_tuple(a, TmpOrig2, UA)); - } - AllocaSet.insert(TmpOrig2); - subValue = true; - break; - } - - if (PUA == UseActivity::None) { - if (isAllocationCall(TmpOrig2, TLI)) { - done.insert(std::make_tuple((User *)call, a, UA)); - for (const auto a : TmpOrig2->users()) { - todo.push_back(std::make_tuple(a, TmpOrig2, UA)); - } - AllocaSet.insert(TmpOrig2); - subValue = true; - break; - } - if (auto L = dyn_cast(TmpOrig2)) { - ptr = L->getPointerOperand(); - } else - ptr = nullptr; - } else - ptr = nullptr; - } - if (subValue) - continue; - legal = false; - break; - } - if (legal) { - toContinue = true; - } - } - if (toContinue) { - if (EnzymePrintActivity) { - llvm::errs() << "Value found indirect call use which must be " - "constant as all stored functions are constant val:" - << *val << " user " << *call << "\n"; - } - for (auto u : call->users()) { - todo.push_back(std::make_tuple(u, a, UseActivity::None)); - } - continue; - } - } - } - - // For an inbound gep, args which are not the pointer being offset - // are not used in an active way by definition. 
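// Standalone sketch of the inbounds-GEP rule stated just above: only the
// pointer being offset can make the address computation active; an integer
// index feeding the GEP carries no derivative information by itself.
#include <cstdio>

static double pick(const double *data, int i) {
  return data[i]; // gep(data, i): active iff `data` is active; `i` never is
}

int main() {
  double xs[3] = {1.0, 2.0, 3.0};
  std::printf("%f\n", pick(xs, 2));
  return 0;
}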
- if (auto gep = dyn_cast(a)) { - if (gep->isInBounds() && gep->getPointerOperand() != parent) - continue; - } - - // If this doesn't write to memory this can only be an active use - // if its return is used in an active way, therefore add this to - // the list of users to analyze - if (auto I = dyn_cast(a)) { - if (notForAnalysis.count(I->getParent())) { - if (EnzymePrintActivity) { - llvm::errs() << "Value found constant unreachable inst use:" << *val - << " user " << *I << "\n"; - } - continue; - } - if (UA != UseActivity::AllStores && ConstantInstructions.count(I)) { - if (I->getType()->isVoidTy() || I->getType()->isTokenTy() || - ConstantValues.count(I)) { - if (EnzymePrintActivity) { - llvm::errs() << "Value found constant inst use:" << *val << " user " - << *I << "\n"; - } - continue; - } - UseActivity NU = UA; - if (UA == UseActivity::OnlyLoads || UA == UseActivity::OnlyStores || - UA == UseActivity::OnlyNonPointerStores) { - if (!isPointerArithmeticInst(I)) - NU = UseActivity::None; - } - - if (EnzymePrintActivity) { - llvm::errs() << "Adding users of value " << *I << " now with sub UA " - << to_string(UA) << "\n"; - } - for (auto u : I->users()) { - todo.push_back(std::make_tuple(u, (Value *)I, NU)); - } - continue; - } - if (!I->mayWriteToMemory() || isa(I)) { - if (TR.query(I)[{-1}].isIntegral()) { - continue; - } - if (UA == UseActivity::OnlyNonPointerStores && - TR.query(I)[{-1}].isFloat()) { - continue; - } - UseActivity NU = UA; - if (UA == UseActivity::OnlyLoads || UA == UseActivity::OnlyStores) { - if (!isPointerArithmeticInst(I)) - NU = UseActivity::None; - } - - for (auto u : I->users()) { - todo.push_back(std::make_tuple(u, (Value *)I, NU)); - } - if (EnzymePrintActivity) { - llvm::errs() << "Adding users2 of value " << *I << " now with sub UA " - << to_string(UA) << "\n"; - } - continue; - } - - if (FoundInst) - *FoundInst = I; - } - - endloop:; - if (EnzymePrintActivity) - llvm::errs() << "Value nonconstant inst (uses):" << *val << " user " << *a - << "\n"; - seenuse = true; - break; - } - - if (EnzymePrintActivity) - llvm::errs() << " " << *val << "\n"; - return !seenuse; -} - -/// Is the value potentially actively returned or stored -bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults const &TR, - llvm::Value *val, - bool outside) { - // Must be an analyzer only searching down - if (!outside) - assert(directions == DOWN); - - bool ignoreStoresInto = true; - auto key = std::make_pair(ignoreStoresInto, val); - if (StoredOrReturnedCache.find(key) != StoredOrReturnedCache.end()) { - return StoredOrReturnedCache[key]; - } - - if (EnzymePrintActivity) - llvm::errs() << " " << *val - << "\n"; - - StoredOrReturnedCache[key] = false; - - for (const auto a : val->users()) { - if (isa(a)) { - continue; - } - // Loading a value prevents its pointer from being captured - if (isa(a)) { - continue; - } - - if (isa(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT) - continue; - - if (EnzymePrintActivity) - llvm::errs() << " " - << " active from-ret>" << *val << "\n"; - StoredOrReturnedCache[key] = true; - return true; - } - - if (auto call = dyn_cast(a)) { - if (!couldFunctionArgumentCapture(call, val)) { - continue; - } - bool ConstantArg = isFunctionArgumentConstant(call, val); - if (ConstantArg) { - continue; - } - } - - if (auto SI = dyn_cast(a)) { - // If we are being stored into, not storing this value - // this case can be skipped - if (SI->getValueOperand() != val) { - if (!ignoreStoresInto) { - // Storing into active value, return true - if 
(!isConstantValue(TR, SI->getValueOperand())) { - StoredOrReturnedCache[key] = true; - if (EnzymePrintActivity) - llvm::errs() << " " << *val - << " store into=" << *SI << "\n"; - return true; - } - } - continue; - } else { - // Storing into active memory, return true - if (!isConstantValue(TR, SI->getPointerOperand())) { - StoredOrReturnedCache[key] = true; - if (EnzymePrintActivity) - llvm::errs() << " " << *val << " store=" << *SI - << "\n"; - return true; - } - continue; - } - } - - if (auto inst = dyn_cast(a)) { - if (!inst->mayWriteToMemory() || - (isa(inst) && (AA.onlyReadsMemory(cast(inst)) || - isReadOnly(cast(inst))))) { - // if not written to memory and returning a known constant, this - // cannot be actively returned/stored - if (inst->getParent()->getParent() == TR.getFunction() && - isConstantValue(TR, a)) { - continue; - } - // if not written to memory and returning a value itself - // not actively stored or returned, this is not actively - // stored or returned - if (!isValueActivelyStoredOrReturned(TR, a, outside)) { - continue; - } - } - } - - if (isAllocationCall(a, TLI)) { - // if not written to memory and returning a known constant, this - // cannot be actively returned/stored - if (isConstantValue(TR, a)) { - continue; - } - // if not written to memory and returning a value itself - // not actively stored or returned, this is not actively - // stored or returned - if (!isValueActivelyStoredOrReturned(TR, a, outside)) { - continue; - } - } else if (isDeallocationCall(a, TLI)) { - // freeing memory never counts - continue; - } - // fallback and conservatively assume that if the value is written to - // it is written to active memory - // TODO handle more memory instructions above to be less conservative - - if (EnzymePrintActivity) - llvm::errs() << " " << *val << " - use=" << *a - << "\n"; - return StoredOrReturnedCache[key] = true; - } - - if (EnzymePrintActivity) - llvm::errs() << " " - << *val << "\n"; - return false; -} - -void ActivityAnalyzer::InsertConstantInstruction(TypeResults const &TR, - llvm::Instruction *I) { - ConstantInstructions.insert(I); - auto found = ReEvaluateValueIfInactiveInst.find(I); - if (found == ReEvaluateValueIfInactiveInst.end()) - return; - auto set = std::move(ReEvaluateValueIfInactiveInst[I]); - ReEvaluateValueIfInactiveInst.erase(I); - for (auto toeval : set) { - if (!ActiveValues.count(toeval)) - continue; - ActiveValues.erase(toeval); - if (EnzymePrintActivity) - llvm::errs() << " re-evaluating activity of val " << *toeval - << " due to inst " << *I << "\n"; - isConstantValue(TR, toeval); - } -} - -void ActivityAnalyzer::InsertConstantValue(TypeResults const &TR, - llvm::Value *V) { - ConstantValues.insert(V); - if (InsertConstValueRecursionHandler) { - InsertConstValueRecursionHandler->push_back(V); - return; - } - SmallVector InsertConstValueRecursionHandlerTmp; - InsertConstValueRecursionHandlerTmp.push_back(V); - InsertConstValueRecursionHandler = &InsertConstValueRecursionHandlerTmp; - while (InsertConstValueRecursionHandlerTmp.size()) { - auto V = InsertConstValueRecursionHandlerTmp.back(); - InsertConstValueRecursionHandlerTmp.pop_back(); - auto found = ReEvaluateValueIfInactiveValue.find(V); - if (found != ReEvaluateValueIfInactiveValue.end()) { - auto set = std::move(ReEvaluateValueIfInactiveValue[V]); - ReEvaluateValueIfInactiveValue.erase(V); - for (auto toeval : set) { - if (!ActiveValues.count(toeval)) - continue; - ActiveValues.erase(toeval); - if (EnzymePrintActivity) - llvm::errs() << " re-evaluating activity of val 
" << *toeval - << " due to value " << *V << "\n"; - isConstantValue(TR, toeval); - } - } - auto found2 = ReEvaluateInstIfInactiveValue.find(V); - if (found2 != ReEvaluateInstIfInactiveValue.end()) { - auto set = std::move(ReEvaluateInstIfInactiveValue[V]); - ReEvaluateInstIfInactiveValue.erase(V); - for (auto toeval : set) { - if (!ActiveInstructions.count(toeval)) - continue; - ActiveInstructions.erase(toeval); - if (EnzymePrintActivity) - llvm::errs() << " re-evaluating activity of inst " << *toeval - << " due to value " << *V << "\n"; - isConstantInstruction(TR, toeval); - } - } - } - InsertConstValueRecursionHandler = nullptr; -} diff --git a/enzyme/Enzyme/ActivityAnalysis.h b/enzyme/Enzyme/ActivityAnalysis.h deleted file mode 100644 index 5945c5b302c1..000000000000 --- a/enzyme/Enzyme/ActivityAnalysis.h +++ /dev/null @@ -1,280 +0,0 @@ -//===- ActivityAnalysis.h - Declaration of Activity Analysis -----------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the declaration of Activity Analysis -- an AD-specific -// analysis that deduces if a given instruction or value can impact the -// calculation of a derivative. This file consists of two mutually recurive -// functions that compute this for values and instructions, respectively. -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_ACTIVE_VAR_H -#define ENZYME_ACTIVE_VAR_H 1 - -#include -#include - -#include -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private - -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/Support/CommandLine.h" - -#include "llvm/ADT/StringMap.h" -#include "llvm/IR/InstVisitor.h" - -#include "TypeAnalysis/TypeAnalysis.h" -#include "Utils.h" - -extern "C" { -extern llvm::cl::opt EnzymePrintActivity; -extern llvm::cl::opt EnzymeNonmarkedGlobalsInactive; -extern llvm::cl::opt EnzymeGlobalActivity; -extern llvm::cl::opt EnzymeEmptyFnInactive; -extern llvm::cl::opt EnzymeEnableRecursiveHypotheses; -} - -class PreProcessCache; - -// A map of MPI comm allocators (otherwise inactive) to the -// argument of the Comm* they allocate into. 
-extern const llvm::StringMap MPIInactiveCommAllocators; - -/// Helper class to analyze the differential activity -class ActivityAnalyzer { - PreProcessCache &PPC; - - /// Aliasing Information - llvm::AAResults &AA; - - // Blocks not to be analyzed - const llvm::SmallPtrSetImpl ¬ForAnalysis; - - /// Library Information - llvm::TargetLibraryInfo &TLI; - -public: - /// Whether the returns of the function being analyzed are active - const DIFFE_TYPE ActiveReturns; - -private: - /// Direction of current analysis - const uint8_t directions; - /// Analyze up based off of operands - static constexpr uint8_t UP = 1; - /// Analyze down based off uses - static constexpr uint8_t DOWN = 2; - - /// Instructions that don't propagate adjoints - /// These instructions could return an active pointer, but - /// do not propagate adjoints themselves - llvm::SmallPtrSet ConstantInstructions; - - /// Instructions that could propagate adjoints - llvm::SmallPtrSet ActiveInstructions; - - /// Values that do not contain derivative information, either - /// directly or as a pointer to - llvm::SmallPtrSet ConstantValues; - - /// Values that may contain derivative information - llvm::SmallPtrSet ActiveValues; - - /// Intermediate pointers which are created by inactive instructions - /// but are marked as active values to inductively determine their - /// activity. - llvm::SmallPtrSet DeducingPointers; - -public: - /// Construct the analyzer from the a previous set of constant and active - /// values and whether returns are active. The all arguments of the functions - /// being analyzed must be in the set of constant and active values, lest an - /// error occur during analysis - ActivityAnalyzer( - PreProcessCache &PPC, llvm::AAResults &AA_, - const llvm::SmallPtrSetImpl ¬ForAnalysis_, - llvm::TargetLibraryInfo &TLI_, - const llvm::SmallPtrSetImpl &ConstantValues, - const llvm::SmallPtrSetImpl &ActiveValues, - DIFFE_TYPE ActiveReturns) - : PPC(PPC), AA(AA_), notForAnalysis(notForAnalysis_), TLI(TLI_), - ActiveReturns(ActiveReturns), directions(UP | DOWN), - ConstantValues(ConstantValues.begin(), ConstantValues.end()), - ActiveValues(ActiveValues.begin(), ActiveValues.end()) { - InsertConstValueRecursionHandler = nullptr; - } - - /// Return whether this instruction is known not to propagate adjoints - /// Note that instructions could return an active pointer, but - /// do not propagate adjoints themselves - bool isConstantInstruction(TypeResults const &TR, llvm::Instruction *inst); - - /// Return whether this values is known not to contain derivative - // information, either directly or as a pointer to - bool isConstantValue(TypeResults const &TR, llvm::Value *val); - -private: - llvm::DenseMap> - ReEvaluateValueIfInactiveInst; - llvm::DenseMap> - ReEvaluateValueIfInactiveValue; - - llvm::DenseMap> - ReEvaluateInstIfInactiveValue; - - void InsertConstantInstruction(TypeResults const &TR, llvm::Instruction *I); - llvm::SmallVector *InsertConstValueRecursionHandler; - void InsertConstantValue(TypeResults const &TR, llvm::Value *V); - - /// Create a new analyzer starting from an existing Analyzer - /// This is used to perform inductive assumptions - ActivityAnalyzer(ActivityAnalyzer &Other, uint8_t directions) - : PPC(Other.PPC), AA(Other.AA), notForAnalysis(Other.notForAnalysis), - TLI(Other.TLI), ActiveReturns(Other.ActiveReturns), - directions(directions), - ConstantInstructions(Other.ConstantInstructions), - ActiveInstructions(Other.ActiveInstructions), - ConstantValues(Other.ConstantValues), 
ActiveValues(Other.ActiveValues), - DeducingPointers(Other.DeducingPointers) { - assert(directions != 0); - assert((directions & Other.directions) == directions); - assert((directions & Other.directions) != 0); - InsertConstValueRecursionHandler = nullptr; - } - - /// Import known constants from an existing analyzer - void insertConstantsFrom(TypeResults const &TR, - ActivityAnalyzer &Hypothesis) { - for (auto I : Hypothesis.ConstantInstructions) { - InsertConstantInstruction(TR, I); - } - for (auto V : Hypothesis.ConstantValues) { - InsertConstantValue(TR, V); - } - } - - /// Import known data from an existing analyzer - void insertAllFrom(TypeResults const &TR, ActivityAnalyzer &Hypothesis, - llvm::Value *Orig, llvm::Value *Orig2 = nullptr) { - insertConstantsFrom(TR, Hypothesis); - for (auto I : Hypothesis.ActiveInstructions) { - bool inserted = ActiveInstructions.insert(I).second; - if (inserted && directions == 3 && EnzymeEnableRecursiveHypotheses) { - ReEvaluateInstIfInactiveValue[Orig].insert(I); - if (Orig2 && Orig2 != Orig) - ReEvaluateInstIfInactiveValue[Orig2].insert(I); - } - } - for (auto V : Hypothesis.ActiveValues) { - bool inserted = ActiveValues.insert(V).second; - if (inserted && directions == 3 && EnzymeEnableRecursiveHypotheses) { - ReEvaluateValueIfInactiveValue[Orig].insert(V); - if (Orig2 && Orig2 != Orig) - ReEvaluateValueIfInactiveValue[Orig2].insert(V); - } - } - - for (auto &pair : Hypothesis.ReEvaluateValueIfInactiveInst) { - ReEvaluateValueIfInactiveValue[pair.first].insert(pair.second.begin(), - pair.second.end()); - if (ConstantInstructions.count(pair.first)) { - InsertConstantInstruction(TR, pair.first); - } - } - for (auto &pair : Hypothesis.ReEvaluateInstIfInactiveValue) { - ReEvaluateInstIfInactiveValue[pair.first].insert(pair.second.begin(), - pair.second.end()); - if (ConstantValues.count(pair.first)) { - InsertConstantValue(TR, pair.first); - } - } - for (auto &pair : Hypothesis.ReEvaluateValueIfInactiveValue) { - ReEvaluateValueIfInactiveValue[pair.first].insert(pair.second.begin(), - pair.second.end()); - if (ConstantValues.count(pair.first)) { - InsertConstantValue(TR, pair.first); - } - } - } - - /// Is the use of value val as an argument of call CI known to be inactive - bool isFunctionArgumentConstant(llvm::CallInst *CI, llvm::Value *val); - - /// Is the instruction guaranteed to be inactive because of its operands. - /// \p considerValue specifies that we ask whether the returned value, rather - /// than the instruction itself is active. - bool isInstructionInactiveFromOrigin(TypeResults const &TR, llvm::Value *val, - bool considerValue); - -public: - enum class UseActivity { - // No Additional use activity info - None = 0, - - // Recursively consider loads to identify a potential active load. - // Intermediate stores into local allocations will be looked through. 
- OnlyLoads = 1, - - // Only consider active stores into - OnlyStores = 2, - - // Only consider active stores and pointer-style loads - OnlyNonPointerStores = 3, - - // Only consider any (active or not) stores into - AllStores = 4 - }; - /// Is the value free of any active uses - bool isValueInactiveFromUsers(TypeResults const &TR, llvm::Value *val, - UseActivity UA, - llvm::Instruction **FoundInst = nullptr); - - /// Is the value potentially actively returned or stored - bool isValueActivelyStoredOrReturned(TypeResults const &TR, llvm::Value *val, - bool outside = false); - -private: - /// StoredOrReturnedCache acts as an inductive cache of results for - /// isValueActivelyStoredOrReturned - std::map, bool> StoredOrReturnedCache; -}; - -constexpr inline const char *to_string(ActivityAnalyzer::UseActivity UA) { - switch (UA) { - case ActivityAnalyzer::UseActivity::None: - return "None"; - case ActivityAnalyzer::UseActivity::OnlyLoads: - return "OnlyLoads"; - case ActivityAnalyzer::UseActivity::OnlyStores: - return "OnlyStores"; - case ActivityAnalyzer::UseActivity::OnlyNonPointerStores: - return "OnlyNonPointerStores"; - case ActivityAnalyzer::UseActivity::AllStores: - return "AllStores"; - } - return ""; -} -#endif diff --git a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp deleted file mode 100644 index 9f034f256bc3..000000000000 --- a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp +++ /dev/null @@ -1,214 +0,0 @@ -// ActivityAnalysisPrinter.cpp - Printer utility pass for Activity Analysis =// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains a utility LLVM pass for printing derived Activity Analysis -// results of a given function. 
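// Standalone sketch of the kind of translation unit this printer is pointed
// at. Compiled to IR and run through opt with the options declared below
// (e.g. -print-activity-analysis -activity-analysis-func=square; the exact
// flags for loading the Enzyme plugin depend on the local build and are not
// spelled out here), it prints one "icv:"/"ici:" line per value and
// instruction of the chosen function.
extern "C" double square(double x) {
  double y = x * x; // expected to be reported active (icv:0 ici:0)
  return y;
}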
-// -//===----------------------------------------------------------------------===// -#include - -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" - -#include "llvm/ADT/SmallVector.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Metadata.h" - -#include "llvm/Support/Debug.h" -#include "llvm/Transforms/Scalar.h" - -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/ScalarEvolution.h" - -#include "llvm/Support/CommandLine.h" - -#include "ActivityAnalysis.h" -#include "ActivityAnalysisPrinter.h" -#include "FunctionUtils.h" -#include "TypeAnalysis/TypeAnalysis.h" -#include "Utils.h" - -using namespace llvm; -#ifdef DEBUG_TYPE -#undef DEBUG_TYPE -#endif -#define DEBUG_TYPE "activity-analysis-results" - -/// Function TypeAnalysis will be starting its run from -static llvm::cl::opt - FunctionToAnalyze("activity-analysis-func", cl::init(""), cl::Hidden, - cl::desc("Which function to analyze/print")); - -static llvm::cl::opt - InactiveArgs("activity-analysis-inactive-args", cl::init(false), cl::Hidden, - cl::desc("Whether all args are inactive")); - -static llvm::cl::opt - DuplicatedRet("activity-analysis-duplicated-ret", cl::init(false), - cl::Hidden, cl::desc("Whether the return is duplicated")); -namespace { - -bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { - if (F.getName() != FunctionToAnalyze) - return /*changed*/ false; - - FnTypeInfo type_args(&F); - for (auto &a : type_args.Function->args()) { - TypeTree dt; - if (a.getType()->isFPOrFPVectorTy()) { - dt = ConcreteType(a.getType()->getScalarType()); - } else if (a.getType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR < 17 - if (a.getContext().supportsTypedPointers()) { - auto et = a.getType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); - } - } -#endif - } else if (a.getType()->isIntOrIntVectorTy()) { - dt = ConcreteType(BaseType::Integer); - } - type_args.Arguments.insert( - std::pair(&a, dt.Only(-1, nullptr))); - // TODO note that here we do NOT propagate constants in type info (and - // should consider whether we should) - type_args.KnownValues.insert( - std::pair>(&a, {})); - } - - TypeTree dt; - if (F.getReturnType()->isFPOrFPVectorTy()) { - dt = ConcreteType(F.getReturnType()->getScalarType()); - } else if (F.getReturnType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR < 17 - if (F.getContext().supportsTypedPointers()) { - auto et = F.getReturnType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); - } - } -#endif - } else if (F.getReturnType()->isIntOrIntVectorTy()) { - dt = ConcreteType(BaseType::Integer); - } - type_args.Return = dt.Only(-1, nullptr); - - PreProcessCache PPC; - TypeAnalysis TA(PPC.FAM); - TypeResults TR = TA.analyzeFunction(type_args); - - llvm::SmallPtrSet ConstantValues; - llvm::SmallPtrSet ActiveValues; - for (auto &a : type_args.Function->args()) { - if (InactiveArgs) { - 
ConstantValues.insert(&a); - } else if (a.getType()->isIntOrIntVectorTy()) { - ConstantValues.insert(&a); - } else { - ActiveValues.insert(&a); - } - } - - DIFFE_TYPE ActiveReturns = F.getReturnType()->isFPOrFPVectorTy() - ? DIFFE_TYPE::OUT_DIFF - : DIFFE_TYPE::CONSTANT; - if (DuplicatedRet) - ActiveReturns = DIFFE_TYPE::DUP_ARG; - SmallPtrSet notForAnalysis(getGuaranteedUnreachable(&F)); - ActivityAnalyzer ATA(PPC, PPC.FAM.getResult(F), notForAnalysis, - TLI, ConstantValues, ActiveValues, ActiveReturns); - - for (auto &a : F.args()) { - ATA.isConstantValue(TR, &a); - llvm::errs().flush(); - } - for (auto &BB : F) { - for (auto &I : BB) { - ATA.isConstantInstruction(TR, &I); - ATA.isConstantValue(TR, &I); - llvm::errs().flush(); - } - } - - for (auto &a : F.args()) { - bool icv = ATA.isConstantValue(TR, &a); - llvm::errs().flush(); - llvm::outs() << a << ": icv:" << icv << "\n"; - llvm::outs().flush(); - } - for (auto &BB : F) { - llvm::outs() << BB.getName() << "\n"; - for (auto &I : BB) { - bool ici = ATA.isConstantInstruction(TR, &I); - bool icv = ATA.isConstantValue(TR, &I); - llvm::errs().flush(); - llvm::outs() << I << ": icv:" << icv << " ici:" << ici << "\n"; - llvm::outs().flush(); - } - } - return /*changed*/ false; -} - -class ActivityAnalysisPrinter final : public FunctionPass { -public: - static char ID; - ActivityAnalysisPrinter() : FunctionPass(ID) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); - } - - bool runOnFunction(Function &F) override { - - auto &TLI = getAnalysis().getTLI(F); - - return printActivityAnalysis(F, TLI); - } -}; - -} // namespace - -char ActivityAnalysisPrinter::ID = 0; - -static RegisterPass - X("print-activity-analysis", "Print Activity Analysis Results"); - -ActivityAnalysisPrinterNewPM::Result -ActivityAnalysisPrinterNewPM::run(llvm::Function &F, - llvm::FunctionAnalysisManager &FAM) { - bool changed = false; - changed = printActivityAnalysis(F, FAM.getResult(F)); - return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); -} -llvm::AnalysisKey ActivityAnalysisPrinterNewPM::Key; diff --git a/enzyme/Enzyme/ActivityAnalysisPrinter.h b/enzyme/Enzyme/ActivityAnalysisPrinter.h deleted file mode 100644 index 0855dd32a43c..000000000000 --- a/enzyme/Enzyme/ActivityAnalysisPrinter.h +++ /dev/null @@ -1,54 +0,0 @@ -//=- ActivityAnalysisPrinter.h - Printer utility pass for Activity Analysis =// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains a utility LLVM pass for printing derived Activity Analysis -// results of a given function. 
-// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_ACTIVITY_ANALYSIS_PRINTER_H -#define ENZYME_ACTIVITY_ANALYSIS_PRINTER_H - -#include - -#include "llvm/IR/PassManager.h" -#include "llvm/Passes/PassPlugin.h" - -namespace llvm { -class FunctionPass; -} - -class ActivityAnalysisPrinterNewPM final - : public llvm::AnalysisInfoMixin { - friend struct llvm::AnalysisInfoMixin; - -private: - static llvm::AnalysisKey Key; - -public: - using Result = llvm::PreservedAnalyses; - ActivityAnalysisPrinterNewPM() {} - - Result run(llvm::Function &M, llvm::FunctionAnalysisManager &MAM); - - static bool isRequired() { return true; } -}; - -#endif // ENZYME_ACTIVITY_ANALYSIS_PRINTER_H diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h deleted file mode 100644 index 171a93547b17..000000000000 --- a/enzyme/Enzyme/AdjointGenerator.h +++ /dev/null @@ -1,6523 +0,0 @@ -//===- AdjointGenerator.h - Implementation of Adjoint's of instructions --===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains an instruction visitor AdjointGenerator that generates -// the corresponding augmented forward pass code, and adjoints for all -// LLVM instructions. 
-// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_ADJOINT_GENERATOR_H -#define ENZYME_ADJOINT_GENERATOR_H - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IntrinsicsX86.h" -#include "llvm/IR/Value.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" - -#include "DiffeGradientUtils.h" -#include "DifferentialUseAnalysis.h" -#include "EnzymeLogic.h" -#include "FunctionUtils.h" -#include "GradientUtils.h" -#include "LibraryFuncs.h" -#include "TraceUtils.h" -#include "TypeAnalysis/TBAA.h" - -#define DEBUG_TYPE "enzyme" - -// Helper instruction visitor that generates adjoints -class AdjointGenerator : public llvm::InstVisitor { -private: - // Type of code being generated (forward, reverse, or both) - const DerivativeMode Mode; - - GradientUtils *const gutils; - llvm::ArrayRef constant_args; - DIFFE_TYPE retType; - TypeResults &TR = gutils->TR; - std::function &)> - getIndex; - const std::map>> - overwritten_args_map; - const AugmentedReturn *augmentedReturn; - const std::map *replacedReturns; - - const llvm::SmallPtrSetImpl &unnecessaryValues; - const llvm::SmallPtrSetImpl - &unnecessaryInstructions; - const llvm::SmallPtrSetImpl &unnecessaryStores; - const llvm::SmallPtrSetImpl &oldUnreachable; - -public: - AdjointGenerator( - DerivativeMode Mode, GradientUtils *gutils, - llvm::ArrayRef constant_args, DIFFE_TYPE retType, - std::function &)> - getIndex, - const std::map>> - overwritten_args_map, - const AugmentedReturn *augmentedReturn, - const std::map *replacedReturns, - const llvm::SmallPtrSetImpl &unnecessaryValues, - const llvm::SmallPtrSetImpl - &unnecessaryInstructions, - const llvm::SmallPtrSetImpl &unnecessaryStores, - const llvm::SmallPtrSetImpl &oldUnreachable) - : Mode(Mode), gutils(gutils), constant_args(constant_args), - retType(retType), getIndex(getIndex), - overwritten_args_map(overwritten_args_map), - augmentedReturn(augmentedReturn), replacedReturns(replacedReturns), - unnecessaryValues(unnecessaryValues), - unnecessaryInstructions(unnecessaryInstructions), - unnecessaryStores(unnecessaryStores), oldUnreachable(oldUnreachable) { - using namespace llvm; - - assert(TR.getFunction() == gutils->oldFunc); - for (auto &pair : TR.analyzer->analysis) { - if (auto in = dyn_cast(pair.first)) { - if (in->getParent()->getParent() != gutils->oldFunc) { - llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n"; - llvm::errs() << "gutils->oldFunc: " << *gutils->oldFunc << "\n"; - llvm::errs() << "in: " << *in << "\n"; - } - assert(in->getParent()->getParent() == gutils->oldFunc); - } - } - } - - void eraseIfUnused(llvm::Instruction &I, bool erase = true, - bool check = true) { - using namespace llvm; - - bool used = - unnecessaryInstructions.find(&I) == unnecessaryInstructions.end(); - if (!used) { - // if decided to cache a value, preserve it here for later - // replacement in EnzymeLogic - auto found = gutils->knownRecomputeHeuristic.find(&I); - if (found != gutils->knownRecomputeHeuristic.end() && !found->second) - used = true; - } - auto iload = gutils->getNewFromOriginal((llvm::Value *)&I); - if (used && check) - return; - - if (auto newi = dyn_cast(iload)) - gutils->eraseWithPlaceholder(newi, &I, "_replacementA", erase); - } - - llvm::Value *MPI_TYPE_SIZE(llvm::Value 
*DT, llvm::IRBuilder<> &B, - llvm::Type *intType) { - using namespace llvm; - - if (DT->getType()->isIntegerTy()) - DT = B.CreateIntToPtr(DT, getInt8PtrTy(DT->getContext())); - - if (Constant *C = dyn_cast(DT)) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { - return ConstantInt::get(intType, 8, false); - } else if (GV->getName() == "ompi_mpi_float") { - return ConstantInt::get(intType, 4, false); - } - } - } - Type *pargs[] = {getInt8PtrTy(DT->getContext()), - PointerType::getUnqual(intType)}; - auto FT = FunctionType::get(intType, pargs, false); - auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(intType); - llvm::Value *args[] = {DT, alloc}; - if (DT->getType() != pargs[0]) - args[0] = B.CreateBitCast(args[0], pargs[0]); - AttributeList AL; - AL = AL.addParamAttribute(DT->getContext(), 0, - Attribute::AttrKind::ReadOnly); - AL = addFunctionNoCapture(DT->getContext(), AL, 0); - AL = - AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NoAlias); - AL = - AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NonNull); - AL = AL.addParamAttribute(DT->getContext(), 1, - Attribute::AttrKind::WriteOnly); - AL = addFunctionNoCapture(DT->getContext(), AL, 1); - AL = - AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NoAlias); - AL = - AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NonNull); - AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex, - Attribute::AttrKind::NoUnwind); - AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex, - Attribute::AttrKind::NoFree); - AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex, - Attribute::AttrKind::NoSync); - AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex, - Attribute::AttrKind::WillReturn); - auto CI = B.CreateCall( - B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( - "MPI_Type_size", FT, AL), - args); -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyAccessesArgMemory(); -#else - CI->addAttributeAtIndex(AttributeList::FunctionIndex, - Attribute::ArgMemOnly); -#endif - return B.CreateLoad(intType, alloc); - } - - // To be double-checked against the functionality needed and the respective - // implementation in Adjoint-MPI - llvm::Value *MPI_COMM_RANK(llvm::Value *comm, llvm::IRBuilder<> &B, - llvm::Type *rankTy) { - using namespace llvm; - - Type *pargs[] = {comm->getType(), PointerType::getUnqual(rankTy)}; - auto FT = FunctionType::get(rankTy, pargs, false); - auto &context = comm->getContext(); - auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy); - AttributeList AL; - AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly); - AL = addFunctionNoCapture(context, AL, 0); - AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias); - AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull); - AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly); - AL = addFunctionNoCapture(context, AL, 1); - AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias); - AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::NoUnwind); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::NoFree); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - 
Attribute::AttrKind::NoSync); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::WillReturn); - llvm::Value *args[] = {comm, alloc}; - B.CreateCall( - B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( - "MPI_Comm_rank", FT, AL), - args); - return B.CreateLoad(rankTy, alloc); - } - - llvm::Value *MPI_COMM_SIZE(llvm::Value *comm, llvm::IRBuilder<> &B, - llvm::Type *rankTy) { - using namespace llvm; - - Type *pargs[] = {comm->getType(), PointerType::getUnqual(rankTy)}; - auto FT = FunctionType::get(rankTy, pargs, false); - auto &context = comm->getContext(); - auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy); - AttributeList AL; - AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly); - AL = addFunctionNoCapture(context, AL, 0); - AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias); - AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull); - AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly); - AL = addFunctionNoCapture(context, AL, 1); - AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias); - AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::NoUnwind); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::NoFree); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::NoSync); - AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex, - Attribute::AttrKind::WillReturn); - llvm::Value *args[] = {comm, alloc}; - B.CreateCall( - B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( - "MPI_Comm_size", FT, AL), - args); - return B.CreateLoad(rankTy, alloc); - } - - void visitInstruction(llvm::Instruction &inst) { - using namespace llvm; - - // TODO explicitly handle all instructions rather than using the catch all - // below - - switch (inst.getOpcode()) { -#include "InstructionDerivatives.inc" - default: - break; - } - - std::string s; - llvm::raw_string_ostream ss(s); - ss << "in Mode: " << to_string(Mode) << "\n"; - ss << "cannot handle unknown instruction\n" << inst; - IRBuilder<> Builder2(&inst); - getForwardBuilder(Builder2); - EmitNoDerivativeError(ss.str(), inst, gutils, Builder2); - if (!gutils->isConstantValue(&inst)) { - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - Mode == DerivativeMode::ForwardModeSplit) - setDiffe(&inst, - Constant::getNullValue(gutils->getShadowType(inst.getType())), - Builder2); - } - if (!inst.getType()->isVoidTy()) { - for (auto &U : - make_early_inc_range(gutils->getNewFromOriginal(&inst)->uses())) { - U.set(UndefValue::get(inst.getType())); - } - } - eraseIfUnused(inst, /*erase*/ true, /*check*/ false); - return; - } - - // Common function for falling back to the implementation - // of dual propagation, as available in invertPointerM. 
- void forwardModeInvertedPointerFallback(llvm::Instruction &I) { - using namespace llvm; - - auto found = gutils->invertedPointers.find(&I); - if (gutils->isConstantValue(&I)) { - assert(found == gutils->invertedPointers.end()); - return; - } - - assert(found != gutils->invertedPointers.end()); - auto placeholder = cast(&*found->second); - gutils->invertedPointers.erase(found); - - if (!DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, &I, Mode, oldUnreachable)) { - gutils->erase(placeholder); - return; - } - - IRBuilder<> Builder2(&I); - getForwardBuilder(Builder2); - - auto toset = gutils->invertPointerM(&I, Builder2, /*nullShadow*/ true); - - assert(toset != placeholder); - - gutils->replaceAWithB(placeholder, toset); - placeholder->replaceAllUsesWith(toset); - gutils->erase(placeholder); - gutils->invertedPointers.insert( - std::make_pair((const Value *)&I, InvertedPointerVH(gutils, toset))); - return; - } - - void visitAllocaInst(llvm::AllocaInst &I) { - eraseIfUnused(I); - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(I); - return; - } - default: - return; - } - } - - void visitICmpInst(llvm::ICmpInst &I) { eraseIfUnused(I); } - - void visitFCmpInst(llvm::FCmpInst &I) { eraseIfUnused(I); } - - void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment, - bool constantval, llvm::Value *mask = nullptr, - llvm::Value *orig_maskInit = nullptr) { - using namespace llvm; - - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 1) / 8; - - assert(Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || gutils->can_modref_map); - assert(Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - gutils->can_modref_map->find(&I) != gutils->can_modref_map->end()); - bool can_modref = (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) - ? 
false - : gutils->can_modref_map->find(&I)->second; - - constantval |= gutils->isConstantValue(&I); - - Type *type = gutils->getShadowType(I.getType()); - (void)type; - - auto *newi = dyn_cast(gutils->getNewFromOriginal(&I)); - - SmallVector scopeMD = { - gutils->getDerivativeAliasScope(I.getOperand(0), -1)}; - if (auto prev = I.getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - scopeMD.push_back(M); - } - } - auto scope = MDNode::get(I.getContext(), scopeMD); - newi->setMetadata(LLVMContext::MD_alias_scope, scope); - - SmallVector MDs; - for (size_t j = 0; j < gutils->getWidth(); j++) { - MDs.push_back(gutils->getDerivativeAliasScope(I.getOperand(0), j)); - } - if (auto prev = I.getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - MDs.push_back(M); - } - } - auto noscope = MDNode::get(I.getContext(), MDs); - newi->setMetadata(LLVMContext::MD_noalias, noscope); - - auto vd = TR.query(&I); - - IRBuilder<> BuilderZ(newi); - if (!vd.isKnown()) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of load " << I; - auto ET = I.getType(); - if (looseTypeAnalysis || true) { - vd = defaultTypeTreeForLLVM(ET, &I); - ss << ", assumed " << vd.str() << "\n"; - EmitWarning("CannotDeduceType", I, ss.str()); - goto known; - } - EmitNoTypeError(str, I, gutils, BuilderZ); - known:; - } - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - Mode == DerivativeMode::ForwardModeSplit) { - if (!constantval) { - auto found = gutils->invertedPointers.find(&I); - assert(found != gutils->invertedPointers.end()); - Instruction *placeholder = cast(&*found->second); - assert(placeholder->getType() == type); - gutils->invertedPointers.erase(found); - - // only make shadow where caching needed - if (!DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, &I, Mode, oldUnreachable)) { - gutils->erase(placeholder); - return; - } - - if (can_modref) { - if (vd[{-1}].isPossiblePointer()) { - Value *newip = gutils->cacheForReverse( - BuilderZ, placeholder, - getIndex(&I, CacheType::Shadow, BuilderZ)); - assert(newip->getType() == type); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&I, InvertedPointerVH(gutils, newip))); - } else { - gutils->erase(placeholder); - } - } else { - Value *newip = gutils->invertPointerM(&I, BuilderZ); - if (gutils->runtimeActivity && vd[{-1}].isFloat()) { - // TODO handle mask - assert(!mask); - - auto rule = [&](Value *inop, Value *newip) -> Value * { - Value *shadow = BuilderZ.CreateICmpNE( - gutils->getNewFromOriginal(I.getOperand(0)), inop); - newip = CreateSelect(BuilderZ, shadow, newip, - Constant::getNullValue(newip->getType())); - return newip; - }; - newip = applyChainRule( - I.getType(), BuilderZ, rule, - gutils->invertPointerM(I.getOperand(0), BuilderZ), newip); - } - assert(newip->getType() == type); - placeholder->replaceAllUsesWith(newip); - gutils->erase(placeholder); - gutils->invertedPointers.erase(&I); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&I, InvertedPointerVH(gutils, newip))); - } - } - return; - } - - //! Store inverted pointer loads that need to be cached for use in reverse - //! 
pass - if (vd[{-1}].isPossiblePointer()) { - auto found = gutils->invertedPointers.find(&I); - if (found != gutils->invertedPointers.end()) { - Instruction *placeholder = cast(&*found->second); - assert(placeholder->getType() == type); - gutils->invertedPointers.erase(found); - - if (!constantval) { - Value *newip = nullptr; - - // TODO: In the case of fwd mode this should be true if the loaded - // value itself is used as a pointer. - bool needShadow = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, &I, Mode, oldUnreachable); - - switch (Mode) { - - case DerivativeMode::ReverseModePrimal: - case DerivativeMode::ReverseModeCombined: { - if (!needShadow) { - gutils->erase(placeholder); - } else { - newip = gutils->invertPointerM(&I, BuilderZ); - assert(newip->getType() == type); - if (Mode == DerivativeMode::ReverseModePrimal && can_modref && - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, &I, - DerivativeMode::ReverseModeGradient, - oldUnreachable)) { - gutils->cacheForReverse( - BuilderZ, newip, getIndex(&I, CacheType::Shadow, BuilderZ)); - } - placeholder->replaceAllUsesWith(newip); - gutils->erase(placeholder); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&I, InvertedPointerVH(gutils, newip))); - } - break; - } - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - assert(0 && "impossible branch"); - return; - } - case DerivativeMode::ReverseModeGradient: { - if (!needShadow) { - gutils->erase(placeholder); - } else { - // only make shadow where caching needed - if (can_modref) { - newip = gutils->cacheForReverse( - BuilderZ, placeholder, - getIndex(&I, CacheType::Shadow, BuilderZ)); - assert(newip->getType() == type); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&I, InvertedPointerVH(gutils, newip))); - } else { - newip = gutils->invertPointerM(&I, BuilderZ); - assert(newip->getType() == type); - placeholder->replaceAllUsesWith(newip); - gutils->erase(placeholder); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&I, InvertedPointerVH(gutils, newip))); - } - } - break; - } - } - - } else { - gutils->erase(placeholder); - } - } - } - - Value *inst = newi; - - //! Store loads that need to be cached for use in reverse pass - - // Only cache value here if caching decision isn't precomputed. - // Otherwise caching will be done inside EnzymeLogic.cpp at - // the end of the function jointly. - if (Mode != DerivativeMode::ForwardMode && - Mode != DerivativeMode::ForwardModeError && - !gutils->knownRecomputeHeuristic.count(&I) && can_modref && - !gutils->unnecessaryIntermediates.count(&I)) { - // we can pre initialize all the knownRecomputeHeuristic values to false - // (not needing) as we may assume that minCutCache already preserves - // everything it requires. - std::map Seen; - bool primalNeededInReverse = false; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) { - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - if (pair.first == &I) - primalNeededInReverse = true; - } - auto cacheMode = (Mode == DerivativeMode::ReverseModePrimal) - ? 
DerivativeMode::ReverseModeGradient - : Mode; - primalNeededInReverse |= - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &I, cacheMode, Seen, oldUnreachable); - if (primalNeededInReverse) { - inst = gutils->cacheForReverse(BuilderZ, newi, - getIndex(&I, CacheType::Self, BuilderZ)); - (void)inst; - assert(inst->getType() == type); - - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit || - Mode == DerivativeMode::ForwardModeError) { - assert(inst != newi); - } else { - assert(inst == newi); - } - } - } - - if (Mode == DerivativeMode::ReverseModePrimal) - return; - - if (constantval) - return; - - if (nonmarkedglobals_inactiveloads) { - // Assume that non enzyme_shadow globals are inactive - // If we ever store to a global variable, we will error if it doesn't - // have a shadow This allows functions who only read global memory to - // have their derivative computed Note that this is too aggressive for - // general programs as if the global aliases with an argument something - // that is written to, then we will have a logical error - if (auto arg = dyn_cast(I.getOperand(0))) { - if (!hasMetadata(arg, "enzyme_shadow")) { - return; - } - } - } - - // Only propagate if instruction is active. The value can be active and not - // the instruction if the value is a potential pointer. This may not be - // caught by type analysis is the result does not have a known type. - if (!gutils->isConstantInstruction(&I)) { - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - assert(0 && "impossible branch"); - return; - } - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - - IRBuilder<> Builder2(&I); - getReverseBuilder(Builder2); - - Value *prediff = nullptr; - - for (ssize_t i = -1; i < (ssize_t)LoadSize; ++i) { - if (vd[{(int)i}].isFloat()) { - prediff = diffe(&I, Builder2); - break; - } - } - - Value *premask = nullptr; - - if (prediff && mask) { - premask = lookup(mask, Builder2); - } - - if (prediff) - ((DiffeGradientUtils *)gutils) - ->addToInvertedPtrDiffe(&I, &I, vd, LoadSize, I.getOperand(0), - prediff, Builder2, alignment, premask); - - unsigned start = 0; - unsigned size = LoadSize; - - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - if (!dt.isKnown()) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of load " << I; - ss << " vd:" << vd.str() << " start:" << start << " size: " << size - << " dt:" << dt.str() << "\n"; - EmitNoTypeError(str, I, gutils, BuilderZ); - continue; - } - assert(dt.isKnown()); - - if (Type *isfloat = dt.isFloat()) { - if (premask && !gutils->isConstantValue(orig_maskInit)) { - // Masked partial type is unhanled. 
- if (premask) - assert(start == 0 && nextStart == LoadSize); - addToDiffe(orig_maskInit, prediff, Builder2, isfloat, - Builder2.CreateNot(premask)); - } - } - - if (nextStart == size) - break; - start = nextStart; - } - break; - } - case DerivativeMode::ReverseModePrimal: - break; - } - } - } - - void visitLoadInst(llvm::LoadInst &LI) { - using namespace llvm; - - // If a load of an omp init argument, don't cache for reverse - // and don't do any adjoint propagation (assumed integral) - for (auto U : LI.getPointerOperand()->users()) { - if (auto CI = dyn_cast(U)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "__kmpc_for_static_init_4" || - F->getName() == "__kmpc_for_static_init_4u" || - F->getName() == "__kmpc_for_static_init_8" || - F->getName() == "__kmpc_for_static_init_8u") { - eraseIfUnused(LI); - return; - } - } - } - } - - auto alignment = LI.getAlign(); - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - - bool constantval = parseTBAA(LI, DL, nullptr)[{-1}].isIntegral(); - visitLoadLike(LI, alignment, constantval); - eraseIfUnused(LI); - } - - void visitAtomicRMWInst(llvm::AtomicRMWInst &I) { - using namespace llvm; - - if (gutils->isConstantInstruction(&I) && gutils->isConstantValue(&I)) { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) { - eraseIfUnused(I, /*erase*/ true, /*check*/ false); - } else { - eraseIfUnused(I); - } - return; - } - - IRBuilder<> BuilderZ(&I); - getForwardBuilder(BuilderZ); - - switch (I.getOperation()) { - case AtomicRMWInst::FAdd: - case AtomicRMWInst::FSub: { - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - Mode == DerivativeMode::ForwardModeSplit) { - auto rule = [&](Value *ptr, Value *dif) -> Value * { - if (dif == nullptr) - dif = Constant::getNullValue(I.getType()); - if (!gutils->isConstantInstruction(&I)) { - assert(ptr); - AtomicRMWInst *rmw = nullptr; - rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif, - I.getAlign(), I.getOrdering(), - I.getSyncScopeID()); - rmw->setVolatile(I.isVolatile()); - if (gutils->isConstantValue(&I)) - return Constant::getNullValue(dif->getType()); - else - return rmw; - } else { - assert(gutils->isConstantValue(&I)); - return Constant::getNullValue(dif->getType()); - } - }; - - Value *diff = applyChainRule( - I.getType(), BuilderZ, rule, - gutils->isConstantValue(I.getPointerOperand()) - ? nullptr - : gutils->invertPointerM(I.getPointerOperand(), BuilderZ), - gutils->isConstantValue(I.getValOperand()) - ? 
nullptr - : gutils->invertPointerM(I.getValOperand(), BuilderZ)); - if (!gutils->isConstantValue(&I)) - setDiffe(&I, diff, BuilderZ); - return; - } - if (Mode == DerivativeMode::ReverseModePrimal) { - eraseIfUnused(I); - return; - } - if ((Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient) && - gutils->isConstantValue(&I)) { - if (!gutils->isConstantValue(I.getValOperand())) { - assert(!gutils->isConstantValue(I.getPointerOperand())); - IRBuilder<> Builder2(&I); - getReverseBuilder(Builder2); - Value *ip = gutils->invertPointerM(I.getPointerOperand(), Builder2); - ip = lookup(ip, Builder2); - auto order = I.getOrdering(); - if (order == AtomicOrdering::Release) - order = AtomicOrdering::Monotonic; - else if (order == AtomicOrdering::AcquireRelease) - order = AtomicOrdering::Acquire; - - auto rule = [&](Value *ip) -> Value * { - LoadInst *dif1 = - Builder2.CreateLoad(I.getType(), ip, I.isVolatile()); - - dif1->setAlignment(I.getAlign()); - dif1->setOrdering(order); - dif1->setSyncScopeID(I.getSyncScopeID()); - return dif1; - }; - Value *diff = applyChainRule(I.getType(), Builder2, rule, ip); - - addToDiffe(I.getValOperand(), diff, Builder2, - I.getValOperand()->getType()->getScalarType()); - } - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(I, /*erase*/ true, /*check*/ false); - } else - eraseIfUnused(I); - return; - } - break; - } - default: - break; - } - - if (looseTypeAnalysis) { - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto valType = I.getValOperand()->getType(); - auto storeSize = DL.getTypeSizeInBits(valType) / 8; - auto fp = TR.firstPointer(storeSize, I.getPointerOperand(), &I, - /*errifnotfound*/ false, - /*pointerIntSame*/ true); - if (!fp.isKnown() && valType->isIntOrIntVectorTy()) { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(I, /*erase*/ true, /*check*/ false); - } else - eraseIfUnused(I); - return; - } - } - std::string s; - llvm::raw_string_ostream ss(s); - ss << *I.getParent()->getParent() << "\n" << I << "\n"; - ss << " Active atomic inst not yet handled"; - EmitNoDerivativeError(ss.str(), I, gutils, BuilderZ); - if (!gutils->isConstantValue(&I)) { - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - Mode == DerivativeMode::ForwardModeSplit) - setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), - BuilderZ); - } - if (!I.getType()->isVoidTy()) { - for (auto &U : - make_early_inc_range(gutils->getNewFromOriginal(&I)->uses())) { - U.set(UndefValue::get(I.getType())); - } - } - eraseIfUnused(I, /*erase*/ true, /*check*/ false); - return; - } - - void visitStoreInst(llvm::StoreInst &SI) { - using namespace llvm; - - // If a store of an omp init argument, don't delete in reverse - // and don't do any adjoint propagation (assumed integral) - for (auto U : SI.getPointerOperand()->users()) { - if (auto CI = dyn_cast(U)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "__kmpc_for_static_init_4" || - F->getName() == "__kmpc_for_static_init_4u" || - F->getName() == "__kmpc_for_static_init_8" || - F->getName() == "__kmpc_for_static_init_8u") { - return; - } - } - } - } - auto align = SI.getAlign(); - - visitCommonStore(SI, SI.getPointerOperand(), SI.getValueOperand(), align, - SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), - /*mask=*/nullptr); - - bool forceErase = false; - if (Mode == DerivativeMode::ReverseModeGradient) { - // Since we won't 
redo the store in the reverse pass, do not - // force the write barrier. - forceErase = true; - for (const auto &pair : gutils->rematerializableAllocations) { - // However, if we are rematerailizing the allocationa and not - // inside the loop level rematerialization, we do still need the - // reverse passes ``fake primal'' store and therefore write barrier - if (pair.second.stores.count(&SI) && - (!pair.second.LI || !pair.second.LI->contains(&SI))) { - forceErase = false; - } - } - } - if (forceErase) - eraseIfUnused(SI, /*erase*/ true, /*check*/ false); - else - eraseIfUnused(SI); - } - - void visitCommonStore(llvm::Instruction &I, llvm::Value *orig_ptr, - llvm::Value *orig_val, llvm::MaybeAlign prevalign, - bool isVolatile, llvm::AtomicOrdering ordering, - llvm::SyncScope::ID syncScope, llvm::Value *mask) { - using namespace llvm; - - Value *val = gutils->getNewFromOriginal(orig_val); - Type *valType = orig_val->getType(); - - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - - if (unnecessaryStores.count(&I)) { - return; - } - - if (gutils->isConstantValue(orig_ptr)) { - return; - } - - SmallVector scopeMD = { - gutils->getDerivativeAliasScope(orig_ptr, -1)}; - SmallVector prevScopes; - if (auto prev = I.getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - scopeMD.push_back(M); - prevScopes.push_back(M); - } - } - auto scope = MDNode::get(I.getContext(), scopeMD); - auto NewI = gutils->getNewFromOriginal(&I); - NewI->setMetadata(LLVMContext::MD_alias_scope, scope); - - SmallVector MDs; - SmallVector prevNoAlias; - for (size_t j = 0; j < gutils->getWidth(); j++) { - MDs.push_back(gutils->getDerivativeAliasScope(orig_ptr, j)); - } - if (auto prev = I.getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - MDs.push_back(M); - prevNoAlias.push_back(M); - } - } - auto noscope = MDNode::get(I.getContext(), MDs); - NewI->setMetadata(LLVMContext::MD_noalias, noscope); - - bool constantval = gutils->isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); - - IRBuilder<> BuilderZ(NewI); - BuilderZ.setFastMathFlags(getFast()); - - // TODO allow recognition of other types that could contain pointers [e.g. - // {void*, void*} or <2 x i64> ] - auto storeSize = (DL.getTypeSizeInBits(valType) + 7) / 8; - - auto vd = TR.query(orig_ptr).Lookup(storeSize, DL); - - if (!vd.isKnown()) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of store " << I; - if (looseTypeAnalysis || true) { - vd = defaultTypeTreeForLLVM(valType, &I); - ss << ", assumed " << vd.str() << "\n"; - EmitWarning("CannotDeduceType", I, ss.str()); - goto known; - } - EmitNoTypeError(str, I, gutils, BuilderZ); - return; - known:; - } - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - - auto dt = vd[{-1}]; - // Only need the full type in forward mode, if storing a constant - // and therefore may need to zero some floats. 
- if (constantval) - for (size_t i = 0; i < storeSize; ++i) { - bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce single type of store " << I << vd.str() - << " size: " << storeSize; - EmitNoTypeError(str, I, gutils, BuilderZ); - return; - } - } - - Value *diff = nullptr; - if (!gutils->runtimeActivity && constantval) { - if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { - if (!isa(orig_val) && - !isa(orig_val)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << I - << " const val: " << *orig_val; - if (CustomErrorHandler) - diff = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, - wrap(orig_val), wrap(&BuilderZ))); - else - EmitWarning("MixedActivityError", I, ss.str()); - } - } - } - - // TODO type analyze - if (!diff) { - if (!constantval) - diff = - gutils->invertPointerM(orig_val, BuilderZ, /*nullShadow*/ true); - else if (orig_val->getType()->isPointerTy() || - dt == BaseType::Pointer || dt == BaseType::Integer) - diff = - gutils->invertPointerM(orig_val, BuilderZ, /*nullShadow*/ false); - else - diff = - gutils->invertPointerM(orig_val, BuilderZ, /*nullShadow*/ true); - } - - gutils->setPtrDiffe(&I, orig_ptr, diff, BuilderZ, prevalign, 0, storeSize, - isVolatile, ordering, syncScope, mask, prevNoAlias, - prevScopes); - - return; - } - - unsigned start = 0; - - while (1) { - unsigned nextStart = storeSize; - - auto dt = vd[{-1}]; - for (size_t i = start; i < storeSize; ++i) { - auto nex = vd[{(int)i}]; - if ((nex == BaseType::Anything && dt.isFloat()) || - (dt == BaseType::Anything && nex.isFloat())) { - nextStart = i; - break; - } - bool Legal = true; - dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - unsigned size = nextStart - start; - if (!dt.isKnown()) { - - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of store " << I << vd.str() - << " start: " << start << " size: " << size - << " storeSize: " << storeSize; - EmitNoTypeError(str, I, gutils, BuilderZ); - break; - } - - MaybeAlign align; - if (prevalign) { - if (start % prevalign->value() == 0) - align = prevalign; - else - align = Align(1); - } - //! Storing a floating point value - if (Type *FT = dt.isFloat()) { - //! 
Only need to update the reverse function - switch (Mode) { - case DerivativeMode::ReverseModePrimal: - break; - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - IRBuilder<> Builder2(&I); - getReverseBuilder(Builder2); - - if (constantval) { - gutils->setPtrDiffe( - &I, orig_ptr, - Constant::getNullValue(gutils->getShadowType(valType)), - Builder2, align, start, size, isVolatile, ordering, syncScope, - mask, prevNoAlias, prevScopes); - } else { - Value *diff; - Value *maskL = mask; - if (!mask) { - Value *dif1Ptr = - lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2); - - size_t idx = 0; - auto rule = [&](Value *dif1Ptr) { - LoadInst *dif1 = - Builder2.CreateLoad(valType, dif1Ptr, isVolatile); - if (align) - dif1->setAlignment(*align); - dif1->setOrdering(ordering); - dif1->setSyncScopeID(syncScope); - - SmallVector scopeMD = { - gutils->getDerivativeAliasScope(orig_ptr, idx)}; - for (auto M : prevScopes) - scopeMD.push_back(M); - - SmallVector MDs; - for (ssize_t j = -1; j < gutils->getWidth(); j++) { - if (j != (ssize_t)idx) - MDs.push_back(gutils->getDerivativeAliasScope(orig_ptr, j)); - } - for (auto M : prevNoAlias) - MDs.push_back(M); - - dif1->setMetadata(LLVMContext::MD_alias_scope, - MDNode::get(I.getContext(), scopeMD)); - dif1->setMetadata(LLVMContext::MD_noalias, - MDNode::get(I.getContext(), MDs)); - dif1->setMetadata(LLVMContext::MD_tbaa, - I.getMetadata(LLVMContext::MD_tbaa)); - dif1->setMetadata(LLVMContext::MD_tbaa_struct, - I.getMetadata(LLVMContext::MD_tbaa_struct)); - idx++; - return dif1; - }; - - diff = applyChainRule(valType, Builder2, rule, dif1Ptr); - } else { - maskL = lookup(mask, Builder2); - Type *tys[] = {valType, orig_ptr->getType()}; - auto F = getIntrinsicDeclaration(gutils->oldFunc->getParent(), - Intrinsic::masked_load, tys); - Value *alignv = - ConstantInt::get(Type::getInt32Ty(mask->getContext()), - align ? align->value() : 0); - Value *ip = - lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2); - - auto rule = [&](Value *ip) { - Value *args[] = {ip, alignv, maskL, - Constant::getNullValue(valType)}; - diff = Builder2.CreateCall(F, args); - return diff; - }; - - diff = applyChainRule(valType, Builder2, rule, ip); - } - - gutils->setPtrDiffe( - &I, orig_ptr, - Constant::getNullValue(gutils->getShadowType(valType)), - Builder2, align, start, size, isVolatile, ordering, syncScope, - mask, prevNoAlias, prevScopes); - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_val, diff, Builder2, FT, start, size, {}, - maskL); - } - break; - } - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardMode: { - IRBuilder<> Builder2(&I); - getForwardBuilder(Builder2); - - Type *diffeTy = gutils->getShadowType(valType); - - Value *diff = constantval - ? Constant::getNullValue(diffeTy) - : gutils->invertPointerM(orig_val, Builder2, - /*nullShadow*/ true); - gutils->setPtrDiffe(&I, orig_ptr, diff, Builder2, align, start, size, - isVolatile, ordering, syncScope, mask, - prevNoAlias, prevScopes); - - break; - } - } - - //! Storing an integer or pointer - } else { - //! 
Only need to update the forward function - - // Don't reproduce mpi null requests - if (constantval) - if (Constant *C = dyn_cast(orig_val)) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_request_null") { - continue; - } - } - } - - bool backwardsShadow = false; - bool forwardsShadow = true; - for (auto pair : gutils->backwardsOnlyShadows) { - if (pair.second.stores.count(&I)) { - backwardsShadow = true; - forwardsShadow = pair.second.primalInitialize; - if (auto inst = dyn_cast(pair.first)) - if (!forwardsShadow && pair.second.LI && - pair.second.LI->contains(inst->getParent())) - backwardsShadow = false; - } - } - - if ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) || - (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow) || - (Mode == DerivativeMode::ReverseModeCombined && - (forwardsShadow || backwardsShadow)) || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - - Value *valueop = nullptr; - - if (constantval) { - if (!gutils->runtimeActivity) { - if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { - if (!isa(orig_val) && - !isa(orig_val)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << I - << " const val: " << *orig_val; - if (CustomErrorHandler) - valueop = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, - gutils, wrap(orig_val), wrap(&BuilderZ))); - else - EmitWarning("MixedActivityError", I, ss.str()); - } - } - } - if (!valueop) { - valueop = val; - if (gutils->getWidth() > 1) { - Value *array = - UndefValue::get(gutils->getShadowType(val->getType())); - for (unsigned i = 0; i < gutils->getWidth(); ++i) { - array = BuilderZ.CreateInsertValue(array, val, {i}); - } - valueop = array; - } - } - } else { - valueop = gutils->invertPointerM(orig_val, BuilderZ); - } - gutils->setPtrDiffe(&I, orig_ptr, valueop, BuilderZ, align, start, - size, isVolatile, ordering, syncScope, mask, - prevNoAlias, prevScopes); - } - } - - if (nextStart == storeSize) - break; - start = nextStart; - } - } - - void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { - eraseIfUnused(gep); - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(gep); - return; - } - default: - return; - } - } - - void visitPHINode(llvm::PHINode &phi) { - eraseIfUnused(phi); - - switch (Mode) { - case DerivativeMode::ReverseModePrimal: - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - return; - } - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(phi); - return; - } - } - } - - void visitCastInst(llvm::CastInst &I) { - using namespace llvm; - - eraseIfUnused(I); - - switch (Mode) { - case DerivativeMode::ReverseModePrimal: { - return; - } - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - if (gutils->isConstantInstruction(&I)) - return; - - if (I.getType()->isPointerTy() || - I.getOpcode() == CastInst::CastOps::PtrToInt) - return; - - Value *orig_op0 = I.getOperand(0); - Value *op0 = gutils->getNewFromOriginal(orig_op0); - - IRBuilder<> Builder2(&I); - getReverseBuilder(Builder2); - - if 
(!gutils->isConstantValue(orig_op0)) { - size_t size = 1; - if (orig_op0->getType()->isSized()) - size = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - orig_op0->getType()) + - 7) / - 8; - Type *FT = TR.addingType(size, orig_op0); - if (!FT && looseTypeAnalysis) { - if (auto ET = I.getSrcTy()->getScalarType()) - if (ET->isFPOrFPVectorTy()) { - FT = ET; - EmitWarning("CannotDeduceType", I, - "failed to deduce adding type of cast ", I, - " assumed ", FT, " from src"); - } - } - if (!FT && looseTypeAnalysis) { - if (auto ET = I.getDestTy()->getScalarType()) - if (ET->isFPOrFPVectorTy()) { - FT = ET; - EmitWarning("CannotDeduceType", I, - "failed to deduce adding type of cast ", I, - " assumed ", FT, " from dst"); - } - } - if (!FT) { - if (TR.query(orig_op0)[{-1}] == BaseType::Integer && - TR.query(&I)[{-1}] == BaseType::Integer) - return; - if (looseTypeAnalysis) { - if (auto ET = I.getSrcTy()->getScalarType()) - if (ET->isIntOrIntVectorTy()) { - EmitWarning("CannotDeduceType", I, - "failed to deduce adding type of cast ", I, - " assumed integral from src"); - return; - } - } - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce adding type (cast) of " << I; - EmitNoTypeError(str, I, gutils, Builder2); - } - - if (FT) { - - auto rule = [&](Value *dif) { - if (I.getOpcode() == CastInst::CastOps::FPTrunc || - I.getOpcode() == CastInst::CastOps::FPExt) { - return Builder2.CreateFPCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::BitCast) { - return Builder2.CreateBitCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::Trunc) { - // TODO CHECK THIS - return Builder2.CreateZExt(dif, op0->getType()); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << *I.getParent()->getParent() << "\n"; - ss << "cannot handle above cast " << I << "\n"; - EmitNoDerivativeError(ss.str(), I, gutils, Builder2); - return (llvm::Value *)UndefValue::get(op0->getType()); - } - }; - - Value *dif = diffe(&I, Builder2); - Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); - - addToDiffe(orig_op0, diff, Builder2, FT); - } - } - - Type *diffTy = gutils->getShadowType(I.getType()); - setDiffe(&I, Constant::getNullValue(diffTy), Builder2); - - break; - } - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(I); - return; - } - } - } - - void visitSelectInst(llvm::SelectInst &SI) { - eraseIfUnused(SI); - - switch (Mode) { - case DerivativeMode::ReverseModePrimal: - return; - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: { - if (gutils->isConstantInstruction(&SI)) - return; - if (SI.getType()->isPointerTy()) - return; - createSelectInstAdjoint(SI); - return; - } - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(SI); - return; - } - } - } - - void createSelectInstAdjoint(llvm::SelectInst &SI) { - using namespace llvm; - - Value *op0 = gutils->getNewFromOriginal(SI.getOperand(0)); - Value *orig_op1 = SI.getOperand(1); - Value *op1 = gutils->getNewFromOriginal(orig_op1); - Value *orig_op2 = SI.getOperand(2); - Value *op2 = gutils->getNewFromOriginal(orig_op2); - - // TODO fix all the reverse builders - IRBuilder<> Builder2(&SI); - getReverseBuilder(Builder2); - - Value *dif1 = nullptr; - Value *dif2 = nullptr; - - size_t size = 1; - if 
(orig_op1->getType()->isSized()) - size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - orig_op1->getType()) + - 7) / - 8; - // Required loopy phi = [in, BO, BO, ..., BO] - // 1) phi is only used in this B0 - // 2) BO dominates all latches - // 3) phi == B0 whenever not coming from preheader [implies 2] - // 4) [optional but done for ease] one exit to make it easier to - // calculation the product at that point - for (int i = 0; i < 2; i++) - if (auto P0 = dyn_cast(SI.getOperand(i + 1))) { - LoopContext lc; - SmallVector activeUses; - for (auto u : P0->users()) { - if (!gutils->isConstantInstruction(cast(u))) { - activeUses.push_back(cast(u)); - } else if (retType == DIFFE_TYPE::OUT_DIFF && isa(u)) - activeUses.push_back(cast(u)); - } - if (activeUses.size() == 1 && activeUses[0] == &SI && - gutils->getContext(gutils->getNewFromOriginal(P0->getParent()), - lc) && - gutils->getNewFromOriginal(P0->getParent()) == lc.header) { - SmallVector Latches; - gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); - bool allIncoming = true; - for (auto Latch : Latches) { - if (&SI != P0->getIncomingValueForBlock(Latch)) { - allIncoming = false; - break; - } - } - if (allIncoming && lc.exitBlocks.size() == 1) { - if (!gutils->isConstantValue(SI.getOperand(2 - i))) { - auto addingType = TR.addingType(size, SI.getOperand(2 - i)); - if (addingType || !looseTypeAnalysis) { - auto index = gutils->getOrInsertConditionalIndex( - gutils->getNewFromOriginal(SI.getOperand(0)), lc, i == 1); - IRBuilder<> EB(*lc.exitBlocks.begin()); - getReverseBuilder(EB, /*original=*/false); - Value *inc = lookup(lc.incvar, Builder2); - if (VectorType *VTy = - dyn_cast(SI.getOperand(0)->getType())) { - inc = Builder2.CreateVectorSplat(VTy->getElementCount(), inc); - } - Value *dif = CreateSelect( - Builder2, - Builder2.CreateICmpEQ(gutils->lookupM(index, EB), inc), - diffe(&SI, Builder2), - Constant::getNullValue( - gutils->getShadowType(op1->getType()))); - addToDiffe(SI.getOperand(2 - i), dif, Builder2, addingType); - } - } - return; - } - } - } - - if (!gutils->isConstantValue(orig_op1)) - dif1 = CreateSelect( - Builder2, lookup(op0, Builder2), diffe(&SI, Builder2), - Constant::getNullValue(gutils->getShadowType(op1->getType())), - "diffe" + op1->getName()); - if (!gutils->isConstantValue(orig_op2)) - dif2 = CreateSelect( - Builder2, lookup(op0, Builder2), - Constant::getNullValue(gutils->getShadowType(op2->getType())), - diffe(&SI, Builder2), "diffe" + op2->getName()); - - setDiffe(&SI, Constant::getNullValue(gutils->getShadowType(SI.getType())), - Builder2); - if (dif1) { - Type *addingType = TR.addingType(size, orig_op1); - if (addingType || !looseTypeAnalysis) - addToDiffe(orig_op1, dif1, Builder2, addingType); - else - llvm::errs() << " warning: assuming integral for " << SI << "\n"; - } - if (dif2) { - Type *addingType = TR.addingType(size, orig_op2); - if (addingType || !looseTypeAnalysis) - addToDiffe(orig_op2, dif2, Builder2, addingType); - else - llvm::errs() << " warning: assuming integral for " << SI << "\n"; - } - } - - void visitExtractElementInst(llvm::ExtractElementInst &EEI) { - using namespace llvm; - - eraseIfUnused(EEI); - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(EEI); - return; - } - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - if (gutils->isConstantInstruction(&EEI)) - return; - IRBuilder<> 
Builder2(&EEI); - getReverseBuilder(Builder2); - - Value *orig_vec = EEI.getVectorOperand(); - - if (!gutils->isConstantValue(orig_vec)) { - - size_t size = 1; - if (EEI.getType()->isSized()) - size = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - EEI.getType()) + - 7) / - 8; - auto diff = diffe(&EEI, Builder2); - if (gutils->getWidth() == 1) { - Value *sv[] = {gutils->getNewFromOriginal(EEI.getIndexOperand())}; - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_vec, diff, Builder2, TR.addingType(size, &EEI), - sv); - } else { - for (size_t i = 0; i < gutils->getWidth(); i++) { - Value *sv[] = {nullptr, - gutils->getNewFromOriginal(EEI.getIndexOperand())}; - sv[0] = ConstantInt::get(sv[1]->getType(), i); - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_vec, gutils->extractMeta(Builder2, diff, i), - Builder2, TR.addingType(size, &EEI), sv); - } - } - } - setDiffe(&EEI, - Constant::getNullValue(gutils->getShadowType(EEI.getType())), - Builder2); - return; - } - case DerivativeMode::ReverseModePrimal: { - return; - } - } - } - - void visitInsertElementInst(llvm::InsertElementInst &IEI) { - using namespace llvm; - - eraseIfUnused(IEI); - - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(IEI); - return; - } - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - if (gutils->isConstantInstruction(&IEI)) - return; - IRBuilder<> Builder2(&IEI); - getReverseBuilder(Builder2); - - Value *dif1 = diffe(&IEI, Builder2); - - Value *orig_op0 = IEI.getOperand(0); - Value *orig_op1 = IEI.getOperand(1); - Value *op1 = gutils->getNewFromOriginal(orig_op1); - Value *op2 = gutils->getNewFromOriginal(IEI.getOperand(2)); - - size_t size0 = 1; - if (orig_op0->getType()->isSized()) - size0 = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - orig_op0->getType()) + - 7) / - 8; - size_t size1 = 1; - if (orig_op1->getType()->isSized()) - size1 = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - orig_op1->getType()) + - 7) / - 8; - - if (!gutils->isConstantValue(orig_op0)) { - if (gutils->getWidth() == 1) { - addToDiffe( - orig_op0, - Builder2.CreateInsertElement( - dif1, - Constant::getNullValue(gutils->getShadowType(op1->getType())), - lookup(op2, Builder2)), - Builder2, TR.addingType(size0, orig_op0)); - } else { - for (size_t i = 0; i < gutils->getWidth(); i++) { - Value *sv[] = {ConstantInt::get(op2->getType(), i)}; - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_op0, - Builder2.CreateInsertElement( - gutils->extractMeta(Builder2, dif1, i), - Constant::getNullValue(op1->getType()), - lookup(op2, Builder2)), - Builder2, TR.addingType(size0, orig_op0), sv); - } - } - } - - if (!gutils->isConstantValue(orig_op1)) { - if (gutils->getWidth() == 1) { - addToDiffe(orig_op1, - Builder2.CreateExtractElement(dif1, lookup(op2, Builder2)), - Builder2, TR.addingType(size1, orig_op1)); - } else { - for (size_t i = 0; i < gutils->getWidth(); i++) { - Value *sv[] = {ConstantInt::get(op2->getType(), i)}; - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_op1, - Builder2.CreateExtractElement( - gutils->extractMeta(Builder2, dif1, i), - lookup(op2, Builder2)), - Builder2, TR.addingType(size1, orig_op1), sv); - } - } - } - - setDiffe(&IEI, - Constant::getNullValue(gutils->getShadowType(IEI.getType())), - Builder2); - return; - } - case DerivativeMode::ReverseModePrimal: { - return; - } - } - } - - 
void visitShuffleVectorInst(llvm::ShuffleVectorInst &SVI) { - using namespace llvm; - - eraseIfUnused(SVI); - - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(SVI); - return; - } - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - if (gutils->isConstantInstruction(&SVI)) - return; - IRBuilder<> Builder2(&SVI); - getReverseBuilder(Builder2); - - auto loaded = diffe(&SVI, Builder2); - auto count = - cast(SVI.getOperand(0)->getType())->getElementCount(); - assert(!count.isScalable()); - size_t l1 = count.getKnownMinValue(); - uint64_t instidx = 0; - - for (size_t idx : SVI.getShuffleMask()) { - auto opnum = (idx < l1) ? 0 : 1; - auto opidx = (idx < l1) ? idx : (idx - l1); - - if (!gutils->isConstantValue(SVI.getOperand(opnum))) { - size_t size = 1; - if (SVI.getOperand(opnum)->getType()->isSized()) - size = (gutils->newFunc->getParent() - ->getDataLayout() - .getTypeSizeInBits(SVI.getOperand(opnum)->getType()) + - 7) / - 8; - if (gutils->getWidth() == 1) { - Value *sv[] = { - ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx)}; - Value *toadd = Builder2.CreateExtractElement(loaded, instidx); - ((DiffeGradientUtils *)gutils) - ->addToDiffe(SVI.getOperand(opnum), toadd, Builder2, - TR.addingType(size, SVI.getOperand(opnum)), sv); - } else { - for (size_t i = 0; i < gutils->getWidth(); i++) { - Value *sv[] = { - ConstantInt::get(Type::getInt32Ty(SVI.getContext()), i), - ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx)}; - Value *toadd = Builder2.CreateExtractElement( - GradientUtils::extractMeta(Builder2, loaded, i), instidx); - ((DiffeGradientUtils *)gutils) - ->addToDiffe(SVI.getOperand(opnum), toadd, Builder2, - TR.addingType(size, SVI.getOperand(opnum)), sv); - } - } - } - ++instidx; - } - setDiffe(&SVI, - Constant::getNullValue(gutils->getShadowType(SVI.getType())), - Builder2); - return; - } - case DerivativeMode::ReverseModePrimal: { - return; - } - } - } - - void visitExtractValueInst(llvm::ExtractValueInst &EVI) { - using namespace llvm; - - eraseIfUnused(EVI); - - if (!gutils->isConstantValue(&EVI) && gutils->isConstantValue(&EVI)) { - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << EVI << "\n"; - llvm_unreachable("Illegal activity for extractvalue"); - } - - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: { - forwardModeInvertedPointerFallback(EVI); - return; - } - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - if (gutils->isConstantInstruction(&EVI)) - return; - if (EVI.getType()->isPointerTy()) - return; - IRBuilder<> Builder2(&EVI); - getReverseBuilder(Builder2); - - Value *orig_op0 = EVI.getOperand(0); - - auto prediff = diffe(&EVI, Builder2); - - // todo const - if (!gutils->isConstantValue(orig_op0)) { - SmallVector sv; - for (auto i : EVI.getIndices()) - sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i)); - size_t storeSize = 1; - if (EVI.getType()->isSized()) - storeSize = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - EVI.getType()) + - 7) / - 8; - - unsigned start = 0; - auto vd = TR.query(&EVI); - - while (1) { - unsigned nextStart = storeSize; - - auto dt = vd[{-1}]; - for (size_t i = start; i < storeSize; ++i) { - auto nex = vd[{(int)i}]; - if ((nex == BaseType::Anything && dt.isFloat()) || - (dt 
== BaseType::Anything && nex.isFloat())) { - nextStart = i; - break; - } - bool Legal = true; - dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - unsigned size = nextStart - start; - if (!dt.isKnown()) { - bool found = false; - if (looseTypeAnalysis) { - if (EVI.getType()->isFPOrFPVectorTy()) { - dt = ConcreteType(EVI.getType()->getScalarType()); - found = true; - } else if (EVI.getType()->isIntOrIntVectorTy() || - EVI.getType()->isPointerTy()) { - dt = BaseType::Integer; - found = true; - } - } - if (!found) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of extract " << EVI << vd.str() - << " start: " << start << " size: " << size - << " extractSize: " << storeSize; - EmitNoTypeError(str, EVI, gutils, Builder2); - } - } - if (auto FT = dt.isFloat()) - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_op0, prediff, Builder2, FT, start, size, sv, - nullptr, /*ignoreFirstSlicesToDiff*/ sv.size()); - - if (nextStart == storeSize) - break; - start = nextStart; - } - } - - setDiffe(&EVI, - Constant::getNullValue(gutils->getShadowType(EVI.getType())), - Builder2); - return; - } - case DerivativeMode::ReverseModePrimal: { - return; - } - } - } - - void visitInsertValueInst(llvm::InsertValueInst &IVI) { - using namespace llvm; - - eraseIfUnused(IVI); - if (gutils->isConstantValue(&IVI)) - return; - - if (Mode == DerivativeMode::ReverseModePrimal) - return; - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeSplit || - Mode == DerivativeMode::ForwardModeError) { - forwardModeInvertedPointerFallback(IVI); - return; - } - - bool hasNonPointer = false; - if (auto st = dyn_cast(IVI.getType())) { - for (unsigned i = 0; i < st->getNumElements(); ++i) { - if (!st->getElementType(i)->isPointerTy()) { - hasNonPointer = true; - } - } - } else if (auto at = dyn_cast(IVI.getType())) { - if (!at->getElementType()->isPointerTy()) { - hasNonPointer = true; - } - } - if (!hasNonPointer) - return; - - bool floatingInsertion = false; - for (InsertValueInst *iv = &IVI;;) { - size_t size0 = 1; - if (iv->getInsertedValueOperand()->getType()->isSized() && - (iv->getInsertedValueOperand()->getType()->isIntOrIntVectorTy() || - iv->getInsertedValueOperand()->getType()->isFPOrFPVectorTy())) - size0 = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - iv->getInsertedValueOperand()->getType()) + - 7) / - 8; - auto it = TR.intType(size0, iv->getInsertedValueOperand(), false); - if (it.isFloat() || !it.isKnown()) { - floatingInsertion = true; - break; - } - Value *val = iv->getAggregateOperand(); - if (gutils->isConstantValue(val)) - break; - if (auto dc = dyn_cast(val)) { - iv = dc; - } else { - // unsure where this came from, conservatively assume contains float - floatingInsertion = true; - break; - } - } - - if (!floatingInsertion) - return; - - // TODO handle pointers - // TODO type analysis handle structs - - switch (Mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: - assert(0 && "should be handled above"); - return; - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: { - IRBuilder<> Builder2(&IVI); - getReverseBuilder(Builder2); - - Value *orig_inserted = IVI.getInsertedValueOperand(); - Value *orig_agg = IVI.getAggregateOperand(); - - size_t size0 = 1; - if (orig_inserted->getType()->isSized()) - size0 = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - 
orig_inserted->getType()) + - 7) / - 8; - - if (!gutils->isConstantValue(orig_inserted)) { - auto TT = TR.query(orig_inserted); - - unsigned start = 0; - Value *dindex = nullptr; - - while (1) { - unsigned nextStart = size0; - - auto dt = TT[{-1}]; - for (size_t i = start; i < size0; ++i) { - auto nex = TT[{(int)i}]; - if ((nex == BaseType::Anything && dt.isFloat()) || - (dt == BaseType::Anything && nex.isFloat())) { - nextStart = i; - break; - } - bool Legal = true; - dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - Type *flt = dt.isFloat(); - if (!dt.isKnown()) { - bool found = false; - if (looseTypeAnalysis) { - if (orig_inserted->getType()->isFPOrFPVectorTy()) { - flt = orig_inserted->getType()->getScalarType(); - found = true; - } else if (orig_inserted->getType()->isIntOrIntVectorTy() || - orig_inserted->getType()->isPointerTy()) { - flt = nullptr; - found = true; - } - } - if (!found) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of insertvalue ins " << IVI - << " size: " << size0 << " TT: " << TT.str(); - EmitNoTypeError(str, IVI, gutils, Builder2); - } - } - - if (flt) { - if (!dindex) { - auto rule = [&](Value *prediff) { - return Builder2.CreateExtractValue(prediff, IVI.getIndices()); - }; - auto prediff = diffe(&IVI, Builder2); - dindex = applyChainRule(orig_inserted->getType(), Builder2, rule, - prediff); - } - - auto TT = TR.query(orig_inserted); - - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_inserted, dindex, Builder2, flt, start, - nextStart - start); - } - if (nextStart == size0) - break; - start = nextStart; - } - } - - size_t size1 = 1; - if (orig_agg->getType()->isSized()) - size1 = - (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - orig_agg->getType()) + - 7) / - 8; - - if (!gutils->isConstantValue(orig_agg)) { - - auto TT = TR.query(orig_agg); - - unsigned start = 0; - - Value *dindex = nullptr; - - while (1) { - unsigned nextStart = size1; - - auto dt = TT[{-1}]; - for (size_t i = start; i < size1; ++i) { - auto nex = TT[{(int)i}]; - if ((nex == BaseType::Anything && dt.isFloat()) || - (dt == BaseType::Anything && nex.isFloat())) { - nextStart = i; - break; - } - bool Legal = true; - dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - Type *flt = dt.isFloat(); - if (!dt.isKnown()) { - bool found = false; - if (looseTypeAnalysis) { - if (orig_agg->getType()->isFPOrFPVectorTy()) { - flt = orig_agg->getType()->getScalarType(); - found = true; - } else if (orig_agg->getType()->isIntOrIntVectorTy() || - orig_agg->getType()->isPointerTy()) { - flt = nullptr; - found = true; - } - } - if (!found) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of insertvalue agg " << IVI - << " start: " << start << " size: " << size1 - << " TT: " << TT.str(); - EmitNoTypeError(str, IVI, gutils, Builder2); - } - } - - if (flt) { - if (!dindex) { - auto rule = [&](Value *prediff) { - return Builder2.CreateInsertValue( - prediff, Constant::getNullValue(orig_inserted->getType()), - IVI.getIndices()); - }; - auto prediff = diffe(&IVI, Builder2); - dindex = - applyChainRule(orig_agg->getType(), Builder2, rule, prediff); - } - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_agg, dindex, Builder2, flt, start, - nextStart - start); - } - if (nextStart == size1) - break; - start = nextStart; - } - } - - setDiffe(&IVI, - Constant::getNullValue(gutils->getShadowType(IVI.getType())), - Builder2); - 
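      // A minimal scalar illustration of the accumulation emitted above for
      //   %agg2 = insertvalue {double, double} %agg, double %x, 1
      // (only the float-typed byte segments found by type analysis take part;
      // integer/pointer segments are skipped):
      //   d_x    += extractvalue d_agg2, 1
      //   d_agg  += insertvalue  d_agg2, double 0.0, 1  ; inserted slot zeroed
      //                                                 ; so x's adjoint is not
      //                                                 ; counted twice
      //   d_agg2  = 0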
return; - } - case DerivativeMode::ReverseModePrimal: { - return; - } - } - } - - void getReverseBuilder(llvm::IRBuilder<> &Builder2, bool original = true) { - ((GradientUtils *)gutils)->getReverseBuilder(Builder2, original); - } - - void getForwardBuilder(llvm::IRBuilder<> &Builder2) { - ((GradientUtils *)gutils)->getForwardBuilder(Builder2); - } - - llvm::Value *diffe(llvm::Value *val, llvm::IRBuilder<> &Builder) { - assert(Mode != DerivativeMode::ReverseModePrimal); - return ((DiffeGradientUtils *)gutils)->diffe(val, Builder); - } - - void setDiffe(llvm::Value *val, llvm::Value *dif, - llvm::IRBuilder<> &Builder) { - assert(Mode != DerivativeMode::ReverseModePrimal); - ((DiffeGradientUtils *)gutils)->setDiffe(val, dif, Builder); - } - - /// Unwraps a vector derivative from its internal representation and applies a - /// function f to each element. Return values of f are collected and wrapped. - template - llvm::Value *applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder, - Func rule, Args... args) { - return ((GradientUtils *)gutils) - ->applyChainRule(diffType, Builder, rule, args...); - } - - /// Unwraps a vector derivative from its internal representation and applies a - /// function f to each element. - template - void applyChainRule(llvm::IRBuilder<> &Builder, Func rule, Args... args) { - ((GradientUtils *)gutils)->applyChainRule(Builder, rule, args...); - } - - /// Unwraps an collection of constant vector derivatives from their internal - /// representations and applies a function f to each element. - template - void applyChainRule(llvm::ArrayRef diffs, - llvm::IRBuilder<> &Builder, Func rule) { - ((GradientUtils *)gutils)->applyChainRule(diffs, Builder, rule); - } - - bool shouldFree() { - assert(Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit); - return ((DiffeGradientUtils *)gutils)->FreeMemory; - } - - llvm::SmallVector - addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &Builder, - llvm::Type *T, llvm::Value *mask = nullptr) { - return ((DiffeGradientUtils *)gutils) - ->addToDiffe(val, dif, Builder, T, /*idxs*/ {}, mask); - } - - llvm::Value *lookup(llvm::Value *val, llvm::IRBuilder<> &Builder) { - return gutils->lookupM(val, Builder); - } - - void visitBinaryOperator(llvm::BinaryOperator &BO) { - eraseIfUnused(BO); - - if (BO.getOpcode() == llvm::Instruction::FDiv && - (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) && - !gutils->isConstantValue(&BO)) { - using namespace llvm; - // Required loopy phi = [in, BO, BO, ..., BO] - // 1) phi is only used in this B0 - // 2) BO dominates all latches - // 3) phi == B0 whenever not coming from preheader [implies 2] - // 4) [optional but done for ease] one exit to make it easier to - // calculation the product at that point - Value *orig_op0 = BO.getOperand(0); - if (auto P0 = dyn_cast(orig_op0)) { - LoopContext lc; - SmallVector activeUses; - for (auto u : P0->users()) { - if (!gutils->isConstantInstruction(cast(u))) { - activeUses.push_back(cast(u)); - } else if (retType == DIFFE_TYPE::OUT_DIFF && isa(u)) { - activeUses.push_back(cast(u)); - } - } - if (activeUses.size() == 1 && activeUses[0] == &BO && - gutils->getContext(gutils->getNewFromOriginal(P0->getParent()), - lc) && - gutils->getNewFromOriginal(P0->getParent()) == lc.header) { - SmallVector Latches; - gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); - bool allIncoming = true; - for (auto 
Latch : Latches) { - if (&BO != P0->getIncomingValueForBlock(Latch)) { - allIncoming = false; - break; - } - } - if (allIncoming && lc.exitBlocks.size() == 1) { - - IRBuilder<> Builder2(&BO); - getReverseBuilder(Builder2); - - Value *orig_op1 = BO.getOperand(1); - - Value *dif1 = nullptr; - Value *idiff = diffe(&BO, Builder2); - - Type *addingType = BO.getType(); - - if (!gutils->isConstantValue(orig_op1)) { - IRBuilder<> EB(*lc.exitBlocks.begin()); - getReverseBuilder(EB, /*original=*/false); - Value *Pstart = P0->getIncomingValueForBlock( - gutils->getOriginalFromNew(lc.preheader)); - if (gutils->isConstantValue(Pstart)) { - Value *lop0 = lookup(gutils->getNewFromOriginal(&BO), EB); - Value *lop1 = - lookup(gutils->getNewFromOriginal(orig_op1), Builder2); - auto rule = [&](Value *idiff) { - auto res = Builder2.CreateFDiv( - Builder2.CreateFNeg(Builder2.CreateFMul(idiff, lop0)), - lop1); - if (EnzymeStrongZero) { - res = CreateSelect( - Builder2, - Builder2.CreateFCmpOEQ( - idiff, Constant::getNullValue(idiff->getType())), - idiff, res); - } - return res; - }; - dif1 = - applyChainRule(orig_op1->getType(), Builder2, rule, idiff); - } else { - auto product = gutils->getOrInsertTotalMultiplicativeProduct( - gutils->getNewFromOriginal(orig_op1), lc); - IRBuilder<> EB(*lc.exitBlocks.begin()); - getReverseBuilder(EB, /*original=*/false); - Value *s = lookup(gutils->getNewFromOriginal(Pstart), Builder2); - Value *lop0 = lookup(product, EB); - Value *lop1 = - lookup(gutils->getNewFromOriginal(orig_op1), Builder2); - auto rule = [&](Value *idiff) { - auto res = Builder2.CreateFDiv( - Builder2.CreateFNeg(Builder2.CreateFMul( - s, Builder2.CreateFDiv(idiff, lop0))), - lop1); - if (EnzymeStrongZero) { - res = CreateSelect( - Builder2, - Builder2.CreateFCmpOEQ( - idiff, Constant::getNullValue(idiff->getType())), - idiff, res); - } - return res; - }; - dif1 = - applyChainRule(orig_op1->getType(), Builder2, rule, idiff); - } - addToDiffe(orig_op1, dif1, Builder2, addingType); - } - return; - } - } - } - } - - { - using namespace llvm; - switch (BO.getOpcode()) { -#include "BinopDerivatives.inc" - default: - break; - } - } - - switch (Mode) { - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: - if (gutils->isConstantInstruction(&BO)) - return; - createBinaryOperatorAdjoint(BO); - break; - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardModeSplit: - createBinaryOperatorDual(BO); - break; - case DerivativeMode::ReverseModePrimal: - return; - } - } - - void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO) { - if (gutils->isConstantInstruction(&BO)) { - return; - } - using namespace llvm; - - IRBuilder<> Builder2(&BO); - getReverseBuilder(Builder2); - - Value *orig_op0 = BO.getOperand(0); - Value *orig_op1 = BO.getOperand(1); - - Value *dif0 = nullptr; - Value *dif1 = nullptr; - Value *idiff = diffe(&BO, Builder2); - - Type *addingType = BO.getType(); - - switch (BO.getOpcode()) { - case Instruction::LShr: { - if (!gutils->isConstantValue(orig_op0)) { - if (auto ci = dyn_cast(orig_op1)) { - size_t size = 1; - if (orig_op0->getType()->isSized()) - size = (gutils->newFunc->getParent() - ->getDataLayout() - .getTypeSizeInBits(orig_op0->getType()) + - 7) / - 8; - - if (Type *flt = TR.addingType(size, orig_op0)) { - auto bits = gutils->newFunc->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(flt); - if (ci->getSExtValue() >= (int64_t)bits && - ci->getSExtValue() % bits == 0) { - auto rule = [&](Value 
*idiff) { - return Builder2.CreateShl(idiff, ci); - }; - dif0 = applyChainRule(orig_op0->getType(), Builder2, rule, idiff); - addingType = flt; - goto done; - } - } - } - } - if (looseTypeAnalysis) { - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer and is constant - return; - } - goto def; - } - case Instruction::And: { - // If & against 0b10000000000 and a float the result is 0 - auto &dl = gutils->oldFunc->getParent()->getDataLayout(); - auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - - auto FT = TR.query(&BO).IsAllFloat(size, dl); - auto eFT = FT; - if (FT) - for (int i = 0; i < 2; ++i) { - auto CI = dyn_cast(BO.getOperand(i)); - if (CI && dl.getTypeSizeInBits(eFT) == - dl.getTypeSizeInBits(CI->getType())) { - if (eFT->isDoubleTy() && CI->getValue() == -134217728) { - setDiffe( - &BO, - Constant::getNullValue(gutils->getShadowType(BO.getType())), - Builder2); - // Derivative is zero (equivalent to rounding as just chopping off - // bits of mantissa), no update - return; - } - } - } - if (looseTypeAnalysis) { - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer and is constant - return; - } - goto def; - } - case Instruction::Xor: { - auto &dl = gutils->oldFunc->getParent()->getDataLayout(); - auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - - auto FT = TR.query(&BO).IsAllFloat(size, dl); - auto eFT = FT; - // If ^ against 0b10000000000 and a float the result is a float - if (FT) - for (int i = 0; i < 2; ++i) { - if (containsOnlyAtMostTopBit(BO.getOperand(i), eFT, dl, &FT)) { - setDiffe( - &BO, - Constant::getNullValue(gutils->getShadowType(BO.getType())), - Builder2); - auto isZero = Builder2.CreateICmpEQ( - lookup(gutils->getNewFromOriginal(BO.getOperand(i)), Builder2), - Constant::getNullValue(BO.getType())); - auto rule = [&](Value *idiff) { - auto ext = Builder2.CreateBitCast(idiff, FT); - auto neg = Builder2.CreateFNeg(ext); - neg = CreateSelect(Builder2, isZero, ext, neg); - neg = Builder2.CreateBitCast(neg, BO.getType()); - return neg; - }; - auto bc = applyChainRule(BO.getOperand(1 - i)->getType(), Builder2, - rule, idiff); - addToDiffe(BO.getOperand(1 - i), bc, Builder2, FT); - return; - } - } - if (looseTypeAnalysis) { - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer and is constant - return; - } - goto def; - } - case Instruction::Or: { - auto &dl = gutils->oldFunc->getParent()->getDataLayout(); - auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - - auto FT = TR.query(&BO).IsAllFloat(size, dl); - auto eFT = FT; - // If & against 0b10000000000 and a float the result is a float - if (FT) - for (int i = 0; i < 2; ++i) { - auto CI = dyn_cast(BO.getOperand(i)); - if (auto CV = dyn_cast(BO.getOperand(i))) { - CI = dyn_cast_or_null(CV->getSplatValue()); - FT = VectorType::get(FT, CV->getType()->getElementCount()); - } - if (auto CV = dyn_cast(BO.getOperand(i))) { - CI = dyn_cast_or_null(CV->getSplatValue()); - FT = VectorType::get(FT, CV->getType()->getElementCount()); - } - if (CI && dl.getTypeSizeInBits(eFT) == - dl.getTypeSizeInBits(CI->getType())) { - auto AP = CI->getValue(); - bool validXor = false; -#if LLVM_VERSION_MAJOR > 16 - if (AP.isZero()) -#else - if (AP.isNullValue()) -#endif - { - validXor = true; - } else if ( - !AP.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (AP & 
~0b01111111100000000000000000000000ULL).isZero() -#else - && (AP & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (AP & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (AP & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - validXor = true; - } - if (validXor) { - setDiffe( - &BO, - Constant::getNullValue(gutils->getShadowType(BO.getType())), - Builder2); - - auto arg = lookup( - gutils->getNewFromOriginal(BO.getOperand(1 - i)), Builder2); - - auto rule = [&](Value *idiff) { - auto prev = Builder2.CreateOr(arg, BO.getOperand(i)); - prev = Builder2.CreateSub(prev, arg, "", /*NUW*/ true, - /*NSW*/ false); - uint64_t num = 0; - if (FT->isFloatTy()) { - num = 127ULL << 23; - } else { - assert(FT->isDoubleTy()); - num = 1023ULL << 52; - } - prev = Builder2.CreateAdd( - prev, ConstantInt::get(prev->getType(), num, false), "", - /*NUW*/ true, /*NSW*/ true); - prev = Builder2.CreateBitCast( - checkedMul(Builder2, Builder2.CreateBitCast(idiff, FT), - Builder2.CreateBitCast(prev, FT)), - prev->getType()); - return prev; - }; - - Value *prev = applyChainRule(BO.getOperand(1 - i)->getType(), - Builder2, rule, idiff); - addToDiffe(BO.getOperand(1 - i), prev, Builder2, FT); - return; - } - } - } - if (looseTypeAnalysis) { - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer or is constant - return; - } - goto def; - } - case Instruction::SDiv: - case Instruction::Shl: - case Instruction::Mul: - case Instruction::Sub: - case Instruction::Add: { - if (looseTypeAnalysis) { - llvm::errs() - << "warning: binary operator is integer and assumed constant: " - << BO << "\n"; - // if loose type analysis, assume this integer add is constant - return; - } - goto def; - } - default: - def:; - std::string s; - llvm::raw_string_ostream ss(s); - ss << *gutils->oldFunc << "\n"; - for (auto &arg : gutils->oldFunc->args()) { - ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg) - << " type: " << TR.query(&arg).str() << " - vals: {"; - for (auto v : TR.knownIntegralValues(&arg)) - ss << v << ","; - ss << "}\n"; - } - for (auto &BB : *gutils->oldFunc) - for (auto &I : BB) { - ss << " constantinst[" << I - << "] = " << gutils->isConstantInstruction(&I) - << " val:" << gutils->isConstantValue(&I) - << " type: " << TR.query(&I).str() << "\n"; - } - ss << "cannot handle unknown binary operator: " << BO << "\n"; - EmitNoDerivativeError(ss.str(), BO, gutils, Builder2); - } - - done:; - if (dif0 || dif1) - setDiffe(&BO, Constant::getNullValue(gutils->getShadowType(BO.getType())), - Builder2); - if (dif0) - addToDiffe(orig_op0, dif0, Builder2, addingType); - if (dif1) - addToDiffe(orig_op1, dif1, Builder2, addingType); - } - - void createBinaryOperatorDual(llvm::BinaryOperator &BO) { - using namespace llvm; - - if (gutils->isConstantInstruction(&BO)) { - forwardModeInvertedPointerFallback(BO); - return; - } - - IRBuilder<> Builder2(&BO); - getForwardBuilder(Builder2); - - Value *orig_op0 = BO.getOperand(0); - Value *orig_op1 = BO.getOperand(1); - - bool constantval0 = gutils->isConstantValue(orig_op0); - bool constantval1 = gutils->isConstantValue(orig_op1); - - switch (BO.getOpcode()) { - case Instruction::And: { - // If & against 0b10000000000 and a float the result is 0 - auto &dl = gutils->oldFunc->getParent()->getDataLayout(); - auto size 
= dl.getTypeSizeInBits(BO.getType()) / 8; - Type *diffTy = gutils->getShadowType(BO.getType()); - - auto FT = TR.query(&BO).IsAllFloat(size, dl); - auto eFT = FT; - if (FT) - for (int i = 0; i < 2; ++i) { - auto CI = dyn_cast(BO.getOperand(i)); - if (CI && dl.getTypeSizeInBits(eFT) == - dl.getTypeSizeInBits(CI->getType())) { - if (eFT->isDoubleTy() && CI->getValue() == -134217728) { - setDiffe(&BO, Constant::getNullValue(diffTy), Builder2); - // Derivative is zero (equivalent to rounding as just chopping off - // bits of mantissa), no update - return; - } - } - } - if (looseTypeAnalysis) { - forwardModeInvertedPointerFallback(BO); - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer and is constant - return; - } - goto def; - } - case Instruction::Xor: { - auto &dl = gutils->oldFunc->getParent()->getDataLayout(); - auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - - auto FT = TR.query(&BO).IsAllFloat(size, dl); - auto eFT = FT; - - Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2), - constantval1 ? nullptr : diffe(orig_op1, Builder2)}; - - for (int i = 0; i < 2; ++i) { - if (containsOnlyAtMostTopBit(BO.getOperand(i), eFT, dl, &FT) && - dif[1 - i] && !dif[i]) { - auto isZero = Builder2.CreateICmpEQ( - gutils->getNewFromOriginal(BO.getOperand(i)), - Constant::getNullValue(BO.getType())); - auto rule = [&](Value *idiff) { - auto ext = Builder2.CreateBitCast(idiff, FT); - auto neg = Builder2.CreateFNeg(ext); - neg = CreateSelect(Builder2, isZero, ext, neg); - neg = Builder2.CreateBitCast(neg, BO.getType()); - return neg; - }; - auto bc = applyChainRule(BO.getOperand(1 - i)->getType(), Builder2, - rule, dif[1 - i]); - setDiffe(&BO, bc, Builder2); - return; - } - } - if (looseTypeAnalysis) { - forwardModeInvertedPointerFallback(BO); - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer and is constant - return; - } - goto def; - } - case Instruction::Or: { - auto &dl = gutils->oldFunc->getParent()->getDataLayout(); - auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - - Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2), - constantval1 ? 
nullptr : diffe(orig_op1, Builder2)}; - - auto FT = TR.query(&BO).IsAllFloat(size, dl); - auto eFT = FT; - // If & against 0b10000000000 and a float the result is a float - if (FT) - for (int i = 0; i < 2; ++i) { - auto CI = dyn_cast(BO.getOperand(i)); - if (auto CV = dyn_cast(BO.getOperand(i))) { - CI = dyn_cast_or_null(CV->getSplatValue()); - FT = VectorType::get(FT, CV->getType()->getElementCount()); - } - if (auto CV = dyn_cast(BO.getOperand(i))) { - CI = dyn_cast_or_null(CV->getSplatValue()); - } - if (CI && dl.getTypeSizeInBits(eFT) == - dl.getTypeSizeInBits(CI->getType())) { - auto AP = CI->getValue(); - bool validXor = false; -#if LLVM_VERSION_MAJOR > 16 - if (AP.isZero()) -#else - if (AP.isNullValue()) -#endif - { - validXor = true; - } else if ( - !AP.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (AP & ~0b01111111100000000000000000000000ULL).isZero() -#else - && (AP & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (AP & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (AP & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - validXor = true; - } - if (validXor) { - auto rule = [&](Value *difi) { - auto arg = gutils->getNewFromOriginal(BO.getOperand(1 - i)); - auto prev = Builder2.CreateOr(arg, BO.getOperand(i)); - prev = Builder2.CreateSub(prev, arg, "", /*NUW*/ true, - /*NSW*/ false); - uint64_t num = 0; - if (FT->isFloatTy()) { - num = 127ULL << 23; - } else { - assert(FT->isDoubleTy()); - num = 1023ULL << 52; - } - prev = Builder2.CreateAdd( - prev, ConstantInt::get(prev->getType(), num, false), "", - /*NUW*/ true, /*NSW*/ true); - prev = Builder2.CreateBitCast( - checkedMul(Builder2, Builder2.CreateBitCast(difi, FT), - Builder2.CreateBitCast(prev, FT)), - prev->getType()); - - return prev; - }; - - auto diffe = - applyChainRule(BO.getType(), Builder2, rule, dif[1 - i]); - setDiffe(&BO, diffe, Builder2); - return; - } - } - } - if (looseTypeAnalysis) { - forwardModeInvertedPointerFallback(BO); - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer or is constant - return; - } - goto def; - } - case Instruction::Shl: - case Instruction::Mul: - case Instruction::Sub: - case Instruction::Add: { - if (looseTypeAnalysis) { - forwardModeInvertedPointerFallback(BO); - llvm::errs() << "warning: binary operator is integer and constant: " - << BO << "\n"; - // if loose type analysis, assume this integer add is constant - return; - } - goto def; - } - default: - def:; - std::string s; - llvm::raw_string_ostream ss(s); - ss << *gutils->oldFunc << "\n"; - for (auto &arg : gutils->oldFunc->args()) { - ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg) - << " type: " << TR.query(&arg).str() << " - vals: {"; - for (auto v : TR.knownIntegralValues(&arg)) - ss << v << ","; - ss << "}\n"; - } - for (auto &BB : *gutils->oldFunc) - for (auto &I : BB) { - ss << " constantinst[" << I - << "] = " << gutils->isConstantInstruction(&I) - << " val:" << gutils->isConstantValue(&I) - << " type: " << TR.query(&I).str() << "\n"; - } - ss << "cannot handle unknown binary operator: " << BO << "\n"; - auto rval = EmitNoDerivativeError(ss.str(), BO, gutils, Builder2); - if (!rval) - rval = Constant::getNullValue(gutils->getShadowType(BO.getType())); - auto ifound = gutils->invertedPointers.find(&BO); - if 
(!gutils->isConstantValue(&BO)) { - if (ifound != gutils->invertedPointers.end()) { - auto placeholder = cast(&*ifound->second); - gutils->invertedPointers.erase(ifound); - gutils->replaceAWithB(placeholder, rval); - gutils->erase(placeholder); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&BO, InvertedPointerVH(gutils, rval))); - } - } else { - assert(ifound == gutils->invertedPointers.end()); - } - break; - } - } - - void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } - - void visitMemSetCommon(llvm::CallInst &MS) { - using namespace llvm; - - IRBuilder<> BuilderZ(&MS); - getForwardBuilder(BuilderZ); - - IRBuilder<> Builder2(&MS); - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) - getReverseBuilder(Builder2); - - bool forceErase = false; - if (Mode == DerivativeMode::ReverseModeGradient) { - for (const auto &pair : gutils->rematerializableAllocations) { - if (pair.second.stores.count(&MS) && pair.second.LI) { - forceErase = true; - } - } - } - if (forceErase) - eraseIfUnused(MS, /*erase*/ true, /*check*/ false); - else - eraseIfUnused(MS); - - Value *orig_op0 = MS.getArgOperand(0); - Value *orig_op1 = MS.getArgOperand(1); - - // If constant destination then no operation needs doing - if (gutils->isConstantValue(orig_op0)) { - return; - } - - bool activeValToSet = !gutils->isConstantValue(orig_op1); - if (activeValToSet) - if (auto CI = dyn_cast(orig_op1)) - if (CI->isZero()) - activeValToSet = false; - if (activeValToSet) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "couldn't handle non constant inst in memset to " - "propagate differential to\n" - << MS; - EmitNoDerivativeError(ss.str(), MS, gutils, BuilderZ); - } - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ); - Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1)); - Value *op2 = gutils->getNewFromOriginal(MS.getArgOperand(2)); - Value *op3 = nullptr; - if (3 < MS.arg_size()) { - op3 = gutils->getNewFromOriginal(MS.getOperand(3)); - } - - auto Defs = - gutils->getInvertedBundles(&MS, - {ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - BuilderZ, /*lookup*/ false); - - auto funcName = getFuncNameFromCall(&MS); - applyChainRule( - BuilderZ, - [&](Value *op0) { - SmallVector args = {op0, op1, op2}; - if (op3) - args.push_back(op3); - - CallInst *cal; - if (startsWith(funcName, "memset_pattern")) - cal = Builder2.CreateMemSet( - op0, ConstantInt::get(Builder2.getInt8Ty(), 0), op2, {}); - else - cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs); - - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - cal->copyMetadata(MS, ToCopy2); - if (auto m = hasMetadata(&MS, "enzyme_zerostack")) - cal->setMetadata("enzyme_zerostack", m); - - if (startsWith(funcName, "memset_pattern")) { - AttributeList NewAttrs; - for (auto idx : - {AttributeList::ReturnIndex, AttributeList::FunctionIndex, - AttributeList::FirstArgIndex}) - for (auto attr : MS.getAttributes().getAttributes(idx)) - NewAttrs = - NewAttrs.addAttributeAtIndex(MS.getContext(), idx, attr); - cal->setAttributes(NewAttrs); - } else - cal->setAttributes(MS.getAttributes()); - cal->setCallingConv(MS.getCallingConv()); - cal->setTailCallKind(MS.getTailCallKind()); - cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc())); - }, - op0); - return; - } - - bool backwardsShadow = false; - bool 
forwardsShadow = true; - for (auto pair : gutils->backwardsOnlyShadows) { - if (pair.second.stores.count(&MS)) { - backwardsShadow = true; - forwardsShadow = pair.second.primalInitialize; - if (auto inst = dyn_cast(pair.first)) - if (!forwardsShadow && pair.second.LI && - pair.second.LI->contains(inst->getParent())) - backwardsShadow = false; - } - } - - size_t size = 1; - if (auto ci = dyn_cast(MS.getOperand(2))) { - size = ci->getLimitedValue(); - } - - // TODO note that we only handle memset of ONE type (aka memset of {int, - // double} not allowed) - - if (size == 0) { - llvm::errs() << MS << "\n"; - } - assert(size != 0); - - // Offsets of the form Optional, segment start, segment size - std::vector> toIterate; - - // Special handling mechanism to bypass TA limitations by supporting - // arbitrary sized types. - if (auto MD = hasMetadata(&MS, "enzyme_truetype")) { - toIterate = parseTrueType(MD, Mode, false); - } else { - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0); - - if (!vd.isKnownPastPointer()) { - // If unknown type results, and zeroing known undef allocation, consider - // integers - if (auto CI = dyn_cast(MS.getOperand(1))) - if (CI->isZero()) { - auto root = getBaseObject(MS.getOperand(0)); - bool writtenTo = false; - bool undefMemory = - isa(root) || isAllocationCall(root, gutils->TLI); - if (auto arg = dyn_cast(root)) - if (arg->hasStructRetAttr()) - undefMemory = true; - if (undefMemory) { - Instruction *cur = MS.getPrevNode(); - while (cur) { - if (cur == root) - break; - if (auto MCI = dyn_cast(MS.getOperand(2))) { - if (auto II = dyn_cast(cur)) { - // If the start of the lifetime for more memory than being - // memset, its valid. - if (II->getIntrinsicID() == Intrinsic::lifetime_start) { - if (getBaseObject(II->getOperand(1)) == root) { - if (auto CI2 = - dyn_cast(II->getOperand(0))) { - if (MCI->getValue().ule(CI2->getValue())) - break; - } - } - cur = cur->getPrevNode(); - continue; - } - } - } - if (cur->mayWriteToMemory()) { - writtenTo = true; - break; - } - cur = cur->getPrevNode(); - } - - if (!writtenTo) { - vd = TypeTree(BaseType::Pointer); - vd.insert({-1}, BaseType::Integer); - } - } - } - } - - if (!vd.isKnownPastPointer()) { - // If unknown type results, consider the intersection of all incoming. 
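        // (i.e. when the destination is a phi or select over several pointers,
        // walk every incoming pointer below, intersect their type trees for
        // this byte range, and skip operands that impose no constraint such as
        // null pointers or literal zero addresses)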
- if (isa(MS.getOperand(0)) || - isa(MS.getOperand(0))) { - SmallVector todo = {MS.getOperand(0)}; - bool set = false; - SmallSet seen; - TypeTree vd2; - while (todo.size()) { - Value *cur = todo.back(); - todo.pop_back(); - if (seen.count(cur)) - continue; - seen.insert(cur); - if (auto PN = dyn_cast(cur)) { - for (size_t i = 0, end = PN->getNumIncomingValues(); i < end; - i++) { - todo.push_back(PN->getIncomingValue(i)); - } - continue; - } - if (auto S = dyn_cast(cur)) { - todo.push_back(S->getTrueValue()); - todo.push_back(S->getFalseValue()); - continue; - } - if (auto CE = dyn_cast(cur)) { - if (CE->isCast()) { - todo.push_back(CE->getOperand(0)); - continue; - } - } - if (auto CI = dyn_cast(cur)) { - todo.push_back(CI->getOperand(0)); - continue; - } - if (isa(cur)) - continue; - if (auto CI = dyn_cast(cur)) - if (CI->isZero()) - continue; - auto curTT = TR.query(cur).Data0().ShiftIndices(DL, 0, size, 0); - if (!set) - vd2 = curTT; - else - vd2 &= curTT; - set = true; - } - vd = vd2; - } - } - if (!vd.isKnownPastPointer()) { - if (looseTypeAnalysis) { -#if LLVM_VERSION_MAJOR < 17 - if (auto CI = dyn_cast(MS.getOperand(0))) { - if (auto PT = dyn_cast(CI->getSrcTy())) { - auto ET = PT->getPointerElementType(); - while (1) { - if (auto ST = dyn_cast(ET)) { - if (ST->getNumElements()) { - ET = ST->getElementType(0); - continue; - } - } - if (auto AT = dyn_cast(ET)) { - ET = AT->getElementType(); - continue; - } - break; - } - if (ET->isFPOrFPVectorTy()) { - vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MS); - goto known; - } - if (ET->isPointerTy()) { - vd = TypeTree(BaseType::Pointer).Only(0, &MS); - goto known; - } - if (ET->isIntOrIntVectorTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MS); - goto known; - } - } - } -#endif - if (auto gep = dyn_cast(MS.getOperand(0))) { - if (auto AT = dyn_cast(gep->getSourceElementType())) { - if (AT->getElementType()->isIntegerTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MS); - goto known; - } - } - } - EmitWarning("CannotDeduceType", MS, - "failed to deduce type of memset ", MS); - vd = TypeTree(BaseType::Pointer).Only(0, &MS); - goto known; - } - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of memset " << MS; - EmitNoTypeError(str, MS, gutils, BuilderZ); - return; - } - known:; - { - unsigned start = 0; - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - if (!dt.isKnown()) { - TR.dump(); - llvm::errs() << " vd:" << vd.str() << " start:" << start - << " size: " << size << " dt:" << dt.str() << "\n"; - } - assert(dt.isKnown()); - toIterate.emplace_back(dt.isFloat(), start, nextStart - start); - - if (nextStart == size) - break; - start = nextStart; - } - } - } - -#if 0 - unsigned dstalign = dstAlign.valueOrOne().value(); - unsigned srcalign = srcAlign.valueOrOne().value(); -#endif - - Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1)); - Value *new_size = gutils->getNewFromOriginal(MS.getArgOperand(2)); - Value *op3 = nullptr; - if (3 < MS.arg_size()) { - op3 = gutils->getNewFromOriginal(MS.getOperand(3)); - } - - for (auto &&[secretty_ref, seg_start_ref, seg_size_ref] : toIterate) { - auto secretty = secretty_ref; - auto seg_start = seg_start_ref; - auto seg_size = seg_size_ref; - - Value *length = new_size; - if (seg_start != std::get<1>(toIterate.back())) { - length = 
ConstantInt::get(new_size->getType(), seg_start + seg_size); - } - if (seg_start != 0) - length = BuilderZ.CreateSub( - length, ConstantInt::get(new_size->getType(), seg_start)); - -#if 0 - unsigned subdstalign = dstalign; - // todo make better alignment calculation - if (dstalign != 0) { - if (start % dstalign != 0) { - dstalign = 1; - } - } - unsigned subsrcalign = srcalign; - // todo make better alignment calculation - if (srcalign != 0) { - if (start % srcalign != 0) { - srcalign = 1; - } - } -#endif - - Value *shadow_dst = gutils->invertPointerM(MS.getOperand(0), BuilderZ); - - // TODO ponder forward split mode - if (!secretty && - ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) || - (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) || - (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow))) { - auto Defs = - gutils->getInvertedBundles(&MS, - {ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - BuilderZ, /*lookup*/ false); - auto rule = [&](Value *op0) { - if (seg_start != 0) { - Value *idxs[] = {ConstantInt::get( - Type::getInt32Ty(op0->getContext()), seg_start)}; - op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()), - op0, idxs); - } - SmallVector args = {op0, op1, length}; - if (op3) - args.push_back(op3); - auto cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - if (auto m = hasMetadata(&MS, "enzyme_zerostack")) - cal->setMetadata("enzyme_zerostack", m); - cal->copyMetadata(MS, ToCopy2); - cal->setAttributes(MS.getAttributes()); - cal->setCallingConv(MS.getCallingConv()); - cal->setTailCallKind(MS.getTailCallKind()); - cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc())); - }; - - applyChainRule(BuilderZ, rule, shadow_dst); - } - if (secretty && (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined)) { - - auto Defs = - gutils->getInvertedBundles(&MS, - {ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - BuilderZ, /*lookup*/ true); - Value *op1l = gutils->lookupM(op1, Builder2); - Value *op3l = op3; - if (op3l) - op3l = gutils->lookupM(op3l, BuilderZ); - length = gutils->lookupM(length, Builder2); - auto rule = [&](Value *op0) { - if (seg_start != 0) { - Value *idxs[] = {ConstantInt::get( - Type::getInt32Ty(op0->getContext()), seg_start)}; - op0 = Builder2.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()), - op0, idxs); - } - SmallVector args = {op0, op1l, length}; - if (op3l) - args.push_back(op3l); - CallInst *cal; - auto funcName = getFuncNameFromCall(&MS); - if (startsWith(funcName, "memset_pattern")) - cal = Builder2.CreateMemSet( - op0, ConstantInt::get(Builder2.getInt8Ty(), 0), length, {}); - else - cal = Builder2.CreateCall(MS.getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - cal->copyMetadata(MS, ToCopy2); - if (auto m = hasMetadata(&MS, "enzyme_zerostack")) - cal->setMetadata("enzyme_zerostack", m); - - if (startsWith(funcName, "memset_pattern")) { - AttributeList NewAttrs; - for (auto idx : - {AttributeList::ReturnIndex, AttributeList::FunctionIndex, - AttributeList::FirstArgIndex}) - for (auto attr : MS.getAttributes().getAttributes(idx)) - NewAttrs = - NewAttrs.addAttributeAtIndex(MS.getContext(), idx, attr); - cal->setAttributes(NewAttrs); - } else - 
cal->setAttributes(MS.getAttributes()); - cal->setCallingConv(MS.getCallingConv()); - cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc())); - }; - - applyChainRule(Builder2, rule, gutils->lookupM(shadow_dst, Builder2)); - } - } - } - - void visitMemTransferInst(llvm::MemTransferInst &MTI) { - using namespace llvm; - Value *isVolatile = gutils->getNewFromOriginal(MTI.getOperand(3)); - auto srcAlign = MTI.getSourceAlign(); - auto dstAlign = MTI.getDestAlign(); - visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI, - MTI.getOperand(0), MTI.getOperand(1), - gutils->getNewFromOriginal(MTI.getOperand(2)), - isVolatile); - } - - void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, - llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, - llvm::Value *orig_dst, llvm::Value *orig_src, - llvm::Value *new_size, llvm::Value *isVolatile) { - using namespace llvm; - - if (gutils->isConstantValue(MTI.getOperand(0))) { - eraseIfUnused(MTI); - return; - } - - if (unnecessaryStores.count(&MTI)) { - eraseIfUnused(MTI); - return; - } - - // memcpy of size 1 cannot move differentiable data [single byte copy] - if (auto ci = dyn_cast(new_size)) { - if (ci->getValue() == 1) { - eraseIfUnused(MTI); - return; - } - } - - // copying into nullptr is invalid (not sure why it exists here), but we - // shouldn't do it in reverse pass or shadow - if (isa(orig_dst) || - TR.query(orig_dst)[{-1}] == BaseType::Anything) { - eraseIfUnused(MTI); - return; - } - - size_t size = 1; - if (auto ci = dyn_cast(new_size)) { - size = ci->getLimitedValue(); - } - - // TODO note that we only handle memcpy/etc of ONE type (aka memcpy of {int, - // double} not allowed) - if (size == 0) { - eraseIfUnused(MTI); - return; - } - - if ((Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) && - gutils->isConstantValue(orig_dst)) { - eraseIfUnused(MTI); - return; - } - - // Offsets of the form Optional, segment start, segment size - std::vector> toIterate; - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI)); - - // Special handling mechanism to bypass TA limitations by supporting - // arbitrary sized types. 
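    // (when an "enzyme_truetype" annotation is present it already enumerates
    // the (type, start, size) segments of the copied range, so mixed
    // aggregates can be split without type analysis having to prove the
    // layout byte by byte)
    //
    // For each float segment, the reverse pass below ultimately emits, via
    // SubTransferHelper, the usual adjoint of a copy; a minimal sketch for a
    // contiguous run of doubles with preallocated shadows d_dst / d_src:
    //   for (size_t i = 0; i < n; ++i) {
    //     d_src[i] += d_dst[i]; // gradient flows back from the destination
    //     d_dst[i]  = 0;        // the overwritten destination's adjoint is dropped
    //   }
    // Non-float (integer/pointer) segments are instead handled by copying the
    // shadow alongside the primal.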
- if (auto MD = hasMetadata(&MTI, "enzyme_truetype")) { - toIterate = parseTrueType(MD, Mode, - !gutils->isConstantValue(orig_src) && - !gutils->runtimeActivity); - } else { - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto vd = TR.query(orig_dst).Data0().ShiftIndices(DL, 0, size, 0); - vd |= TR.query(orig_src).Data0().PurgeAnything().ShiftIndices(DL, 0, size, - 0); - for (size_t i = 0; i < MTI.getNumOperands(); i++) - if (MTI.getOperand(i) == orig_dst) - if (MTI.getAttributes().hasParamAttr(i, "enzyme_type")) { - auto attr = MTI.getAttributes().getParamAttr(i, "enzyme_type"); - auto TT = - TypeTree::parse(attr.getValueAsString(), MTI.getContext()); - vd |= TT.Data0().ShiftIndices(DL, 0, size, 0); - break; - } - - bool errorIfNoType = true; - if ((Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) && - (!gutils->isConstantValue(orig_src) && !gutils->runtimeActivity)) { - errorIfNoType = false; - } - - if (!vd.isKnownPastPointer()) { - if (looseTypeAnalysis) { - for (auto val : {orig_dst, orig_src}) { -#if LLVM_VERSION_MAJOR < 17 - if (auto CI = dyn_cast(val)) { - if (auto PT = dyn_cast(CI->getSrcTy())) { - auto ET = PT->getPointerElementType(); - while (1) { - if (auto ST = dyn_cast(ET)) { - if (ST->getNumElements()) { - ET = ST->getElementType(0); - continue; - } - } - if (auto AT = dyn_cast(ET)) { - ET = AT->getElementType(); - continue; - } - break; - } - if (ET->isFPOrFPVectorTy()) { - vd = - TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MTI); - goto known; - } - if (ET->isPointerTy()) { - vd = TypeTree(BaseType::Pointer).Only(0, &MTI); - goto known; - } - if (ET->isIntOrIntVectorTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MTI); - goto known; - } - } - } -#endif - if (auto gep = dyn_cast(val)) { - if (auto AT = dyn_cast(gep->getSourceElementType())) { - if (AT->getElementType()->isIntegerTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MTI); - goto known; - } - } - } - } - // If the type is known, but outside of the known range - // (but the memcpy size is a variable), attempt to use - // the first type out of range as the memcpy type. - if (size == 1 && !isa(new_size)) { - for (auto ptr : {orig_dst, orig_src}) { - vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0); - if (vd.isKnownPastPointer()) { - ConcreteType mv(BaseType::Unknown); - size_t minInt = 0xFFFFFFFF; - for (const auto &pair : vd.getMapping()) { - if (pair.first.size() != 1) - continue; - if (minInt < (size_t)pair.first[0]) - continue; - minInt = pair.first[0]; - mv = pair.second; - } - assert(mv != BaseType::Unknown); - vd.insert({0}, mv); - goto known; - } - } - } - if (errorIfNoType) - EmitWarning("CannotDeduceType", MTI, - "failed to deduce type of copy ", MTI); - vd = TypeTree(BaseType::Pointer).Only(0, &MTI); - goto known; - } - if (errorIfNoType) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of copy " << MTI; - EmitNoTypeError(str, MTI, gutils, BuilderZ); - vd = TypeTree(BaseType::Integer).Only(0, &MTI); - } else { - vd = TypeTree(BaseType::Pointer).Only(0, &MTI); - } - } - - known:; - { - - unsigned start = 0; - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - auto tmp = dt; - auto next = vd[{(int)i}]; - tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal); - // Prevent fusion of {Anything, Float} since anything is an int rule - // but float requires zeroing. 
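            // (an Anything byte can be handled with the integer rule, i.e. a
            // plain shadow copy, but a Float byte also requires zeroing the
            // destination shadow in the reverse pass, so the two must not be
            // merged into one segment)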
- if ((dt == BaseType::Anything && - (next != BaseType::Anything && next.isKnown())) || - (next == BaseType::Anything && - (dt != BaseType::Anything && dt.isKnown()))) - Legal = false; - if (!Legal) { - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - // if both are floats (of any type), forward mode is the same. - // + [potentially zero if const, otherwise copy] - // if both are int/pointer (of any type), also the same - // + copy - // if known non-constant, also the same - // + copy - if ((dt.isFloat() == nullptr) == - (vd[{(int)i}].isFloat() == nullptr)) { - Legal = true; - } - if (!gutils->isConstantValue(orig_src) && - !gutils->runtimeActivity) { - Legal = true; - } - } - if (!Legal) { - nextStart = i; - break; - } - } else - dt = tmp; - } - if (!dt.isKnown()) { - TR.dump(); - llvm::errs() << " vd:" << vd.str() << " start:" << start - << " size: " << size << " dt:" << dt.str() << "\n"; - } - assert(dt.isKnown()); - toIterate.emplace_back(dt.isFloat(), start, nextStart - start); - - if (nextStart == size) - break; - start = nextStart; - } - } - } - - // llvm::errs() << "MIT: " << MTI << "|size: " << size << " vd: " << - // vd.str() << "\n"; - - unsigned dstalign = dstAlign.valueOrOne().value(); - unsigned srcalign = srcAlign.valueOrOne().value(); - - bool backwardsShadow = false; - bool forwardsShadow = true; - for (auto pair : gutils->backwardsOnlyShadows) { - if (pair.second.stores.count(&MTI)) { - backwardsShadow = true; - forwardsShadow = pair.second.primalInitialize; - if (auto inst = dyn_cast(pair.first)) - if (!forwardsShadow && pair.second.LI && - pair.second.LI->contains(inst->getParent())) - backwardsShadow = false; - } - } - - for (auto &&[floatTy_ref, seg_start_ref, seg_size_ref] : toIterate) { - auto floatTy = floatTy_ref; - auto seg_start = seg_start_ref; - auto seg_size = seg_size_ref; - - Value *length = new_size; - if (seg_start != std::get<1>(toIterate.back())) { - length = ConstantInt::get(new_size->getType(), seg_start + seg_size); - } - if (seg_start != 0) - length = BuilderZ.CreateSub( - length, ConstantInt::get(new_size->getType(), seg_start)); - - unsigned subdstalign = dstalign; - // todo make better alignment calculation - if (dstalign != 0) { - if (seg_start % dstalign != 0) { - dstalign = 1; - } - } - unsigned subsrcalign = srcalign; - // todo make better alignment calculation - if (srcalign != 0) { - if (seg_start % srcalign != 0) { - srcalign = 1; - } - } - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI)); - Value *shadow_dst = gutils->isConstantValue(orig_dst) - ? nullptr - : gutils->invertPointerM(orig_dst, BuilderZ); - Value *shadow_src = gutils->isConstantValue(orig_src) - ? 
nullptr - : gutils->invertPointerM(orig_src, BuilderZ); - - auto rev_rule = [&](Value *shadow_dst, Value *shadow_src) { - if (shadow_dst == nullptr) - shadow_dst = gutils->getNewFromOriginal(orig_dst); - if (shadow_src == nullptr) - shadow_src = gutils->getNewFromOriginal(orig_src); - SubTransferHelper( - gutils, Mode, floatTy, ID, subdstalign, subsrcalign, - /*offset*/ seg_start, gutils->isConstantValue(orig_dst), shadow_dst, - gutils->isConstantValue(orig_src), shadow_src, - /*length*/ length, /*volatile*/ isVolatile, &MTI, - /*allowForward*/ forwardsShadow, /*shadowsLookedup*/ false, - /*backwardsShadow*/ backwardsShadow); - }; - - auto fwd_rule = [&](Value *ddst, Value *dsrc) { - if (ddst == nullptr) - ddst = gutils->getNewFromOriginal(orig_dst); - if (dsrc == nullptr) - dsrc = gutils->getNewFromOriginal(orig_src); - MaybeAlign dalign; - if (subdstalign) - dalign = MaybeAlign(subdstalign); - MaybeAlign salign; - if (subsrcalign) - salign = MaybeAlign(subsrcalign); - if (ddst->getType()->isIntegerTy()) - ddst = - BuilderZ.CreateIntToPtr(ddst, getInt8PtrTy(ddst->getContext())); - if (seg_start != 0) { - ddst = BuilderZ.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(ddst->getContext()), ddst, seg_start); - } - CallInst *call; - // TODO add gutils->runtimeActivity (correctness) - if (floatTy && gutils->isConstantValue(orig_src)) { - call = BuilderZ.CreateMemSet( - ddst, ConstantInt::get(Type::getInt8Ty(ddst->getContext()), 0), - length, salign, isVolatile); - } else { - if (dsrc->getType()->isIntegerTy()) - dsrc = - BuilderZ.CreateIntToPtr(dsrc, getInt8PtrTy(dsrc->getContext())); - if (seg_start != 0) { - dsrc = BuilderZ.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(ddst->getContext()), dsrc, seg_start); - } - if (ID == Intrinsic::memmove) { - call = BuilderZ.CreateMemMove(ddst, dalign, dsrc, salign, length); - } else { - call = BuilderZ.CreateMemCpy(ddst, dalign, dsrc, salign, length); - } - call->setAttributes(MTI.getAttributes()); - } - // TODO shadow scope/noalias (performance) - call->setMetadata(LLVMContext::MD_alias_scope, - MTI.getMetadata(LLVMContext::MD_alias_scope)); - call->setMetadata(LLVMContext::MD_noalias, - MTI.getMetadata(LLVMContext::MD_noalias)); - call->setMetadata(LLVMContext::MD_tbaa, - MTI.getMetadata(LLVMContext::MD_tbaa)); - call->setMetadata(LLVMContext::MD_tbaa_struct, - MTI.getMetadata(LLVMContext::MD_tbaa_struct)); - call->setMetadata(LLVMContext::MD_invariant_group, - MTI.getMetadata(LLVMContext::MD_invariant_group)); - call->setTailCallKind(MTI.getTailCallKind()); - }; - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) - applyChainRule(BuilderZ, fwd_rule, shadow_dst, shadow_src); - else - applyChainRule(BuilderZ, rev_rule, shadow_dst, shadow_src); - } - - eraseIfUnused(MTI); - } - - void visitFenceInst(llvm::FenceInst &FI) { - using namespace llvm; - - switch (Mode) { - default: - break; - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: { - bool emitReverse = true; - if (EnzymeJuliaAddrLoad) { - if (auto prev = dyn_cast_or_null(FI.getPrevNode())) { - if (auto F = prev->getCalledFunction()) - if (F->getName() == "julia.safepoint") - emitReverse = false; - } - if (auto prev = dyn_cast_or_null(FI.getNextNode())) { - if (auto F = prev->getCalledFunction()) - if (F->getName() == "julia.safepoint") - emitReverse = false; - } - } - if (emitReverse) { - IRBuilder<> Builder2(&FI); - getReverseBuilder(Builder2); - auto order = FI.getOrdering(); - switch (order) { - case 
AtomicOrdering::Acquire: - order = AtomicOrdering::Release; - break; - case AtomicOrdering::Release: - order = AtomicOrdering::Acquire; - break; - default: - break; - } - Builder2.CreateFence(order, FI.getSyncScopeID()); - } - } - } - eraseIfUnused(FI); - } - - void visitIntrinsicInst(llvm::IntrinsicInst &II) { - using namespace llvm; - - if (II.getIntrinsicID() == Intrinsic::stacksave) { - eraseIfUnused(II, /*erase*/ true, /*check*/ false); - return; - } - if (II.getIntrinsicID() == Intrinsic::stackrestore || - II.getIntrinsicID() == Intrinsic::lifetime_end) { - eraseIfUnused(II, /*erase*/ true, /*check*/ false); - return; - } - - // When compiling Enzyme against standard LLVM, and not Intel's - // modified version of LLVM, the intrinsic `llvm.intel.subscript` is - // not fully understood by LLVM. One of the results of this is that the ID - // of the intrinsic is set to Intrinsic::not_intrinsic - hence we are - // handling the intrinsic here. - if (isIntelSubscriptIntrinsic(II)) { - if (Mode == DerivativeMode::ForwardModeSplit || - Mode == DerivativeMode::ForwardModeError || - Mode == DerivativeMode::ForwardMode) { - forwardModeInvertedPointerFallback(II); - } - } else { - SmallVector orig_ops(II.getNumOperands()); - - for (unsigned i = 0; i < II.getNumOperands(); ++i) { - orig_ops[i] = II.getOperand(i); - } - if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) - return; - } - if (gutils->knownRecomputeHeuristic.find(&II) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&II]) { - CallInst *const newCall = - cast(gutils->getNewFromOriginal(&II)); - IRBuilder<> BuilderZ(newCall); - BuilderZ.setFastMathFlags(getFast()); - - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&II, CacheType::Self, BuilderZ)); - } - } - eraseIfUnused(II); - } - - bool - handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, - llvm::SmallVectorImpl &orig_ops) { - using namespace llvm; - - Module *M = I.getParent()->getParent()->getParent(); - - switch (ID) { -#if LLVM_VERSION_MAJOR < 20 - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: -#endif - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: { - auto CI = cast(I.getOperand(1)); - visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()), - /*constantval*/ false); - return false; - } - default: - break; - } - - if (ID == Intrinsic::masked_store) { - auto align0 = cast(I.getOperand(2))->getZExtValue(); - auto align = MaybeAlign(align0); - visitCommonStore(I, /*orig_ptr*/ I.getOperand(1), - /*orig_val*/ I.getOperand(0), align, - /*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic, - SyncScope::SingleThread, - /*mask*/ gutils->getNewFromOriginal(I.getOperand(3))); - return false; - } - if (ID == Intrinsic::masked_load) { - auto align0 = cast(I.getOperand(1))->getZExtValue(); - auto align = MaybeAlign(align0); - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); - visitLoadLike(I, align, constantval, - /*mask*/ gutils->getNewFromOriginal(I.getOperand(2)), - /*orig_maskInit*/ I.getOperand(3)); - return false; - } - - auto mod = I.getParent()->getParent()->getParent(); - auto called = cast(&I)->getCalledFunction(); - (void)called; - switch (ID) { -#include "IntrinsicDerivatives.inc" - default: - break; - } - - switch (Mode) { - case DerivativeMode::ReverseModePrimal: { - switch (ID) { - case 
Intrinsic::nvvm_barrier0: - case Intrinsic::nvvm_barrier0_popc: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: - case Intrinsic::nvvm_membar_cta: - case Intrinsic::nvvm_membar_gl: - case Intrinsic::nvvm_membar_sys: - case Intrinsic::amdgcn_s_barrier: - return false; - default: - if (gutils->isConstantInstruction(&I)) - return false; - if (ID == Intrinsic::umax || ID == Intrinsic::smax || - ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow || - ID == Intrinsic::uadd_with_overflow || - ID == Intrinsic::smul_with_overflow || - ID == Intrinsic::umul_with_overflow || - ID == Intrinsic::ssub_with_overflow || - ID == Intrinsic::usub_with_overflow) - if (looseTypeAnalysis) { - EmitWarning("CannotDeduceType", I, - "failed to deduce type of intrinsic ", I); - return false; - } - std::string s; - llvm::raw_string_ostream ss(s); - ss << *gutils->oldFunc << "\n"; - ss << *gutils->newFunc << "\n"; - ss << "cannot handle (augmented) unknown intrinsic\n" << I; - IRBuilder<> BuilderZ(&I); - getForwardBuilder(BuilderZ); - EmitNoDerivativeError(ss.str(), I, gutils, BuilderZ); - return false; - } - return false; - } - - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: { - - IRBuilder<> Builder2(&I); - getReverseBuilder(Builder2); - - Value *vdiff = nullptr; - if (!gutils->isConstantValue(&I)) { - vdiff = diffe(&I, Builder2); - setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), - Builder2); - } - (void)vdiff; - - switch (ID) { - - case Intrinsic::nvvm_barrier0_popc: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: { - SmallVector args = {}; - auto cal = cast(Builder2.CreateCall( - getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0), args)); - cal->setCallingConv(getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0) - ->getCallingConv()); - cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc())); - return false; - } - - case Intrinsic::nvvm_barrier0: - case Intrinsic::amdgcn_s_barrier: - case Intrinsic::nvvm_membar_cta: - case Intrinsic::nvvm_membar_gl: - case Intrinsic::nvvm_membar_sys: { - SmallVector args = {}; - auto cal = cast( - Builder2.CreateCall(getIntrinsicDeclaration(M, ID), args)); - cal->setCallingConv(getIntrinsicDeclaration(M, ID)->getCallingConv()); - cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc())); - return false; - } - - case Intrinsic::lifetime_start: { - if (gutils->isConstantInstruction(&I)) - return false; - SmallVector args = { - lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2), - lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)}; - Type *tys[] = {args[1]->getType()}; - auto cal = Builder2.CreateCall( - getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys), args); - cal->setCallingConv( - getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys) - ->getCallingConv()); - return false; - } - - case Intrinsic::vector_reduce_fmax: { - if (vdiff && !gutils->isConstantValue(orig_ops[0])) { - auto prev = lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2); - auto VT = cast(orig_ops[0]->getType()); - - assert(!VT->getElementCount().isScalable()); - size_t numElems = VT->getElementCount().getKnownMinValue(); - SmallVector elems; - SmallVector cmps; - - for (size_t i = 0; i < numElems; ++i) - elems.push_back(Builder2.CreateExtractElement(prev, (uint64_t)i)); - - Value *curmax = elems[0]; - for (size_t i = 0; i < numElems - 1; ++i) { - cmps.push_back(Builder2.CreateFCmpOLT(curmax, elems[i + 1])); - if (i + 2 != numElems) - curmax = 
CreateSelect(Builder2, cmps[i], elems[i + 1], curmax); - } - - auto rule = [&](Value *vdiff) { - auto nv = Constant::getNullValue(orig_ops[0]->getType()); - Value *res = Builder2.CreateInsertElement(nv, vdiff, (uint64_t)0); - - for (size_t i = 0; i < numElems - 1; ++i) { - auto rhs_v = Builder2.CreateInsertElement(nv, vdiff, i + 1); - res = CreateSelect(Builder2, cmps[i], rhs_v, res); - } - return res; - }; - Value *dif0 = - applyChainRule(orig_ops[0]->getType(), Builder2, rule, vdiff); - addToDiffe(orig_ops[0], dif0, Builder2, I.getType()); - } - return false; - } - default: - if (gutils->isConstantInstruction(&I)) - return false; - if (ID == Intrinsic::umax || ID == Intrinsic::smax || - ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow || - ID == Intrinsic::uadd_with_overflow || - ID == Intrinsic::smul_with_overflow || - ID == Intrinsic::umul_with_overflow || - ID == Intrinsic::ssub_with_overflow || - ID == Intrinsic::usub_with_overflow) - if (looseTypeAnalysis) { - EmitWarning("CannotDeduceType", I, - "failed to deduce type of intrinsic ", I); - return false; - } - std::string s; - llvm::raw_string_ostream ss(s); - ss << *gutils->oldFunc << "\n"; - ss << *gutils->newFunc << "\n"; - if (Intrinsic::isOverloaded(ID)) - ss << "cannot handle (reverse) unknown intrinsic\n" - << Intrinsic::getName(ID, ArrayRef(), - gutils->oldFunc->getParent(), nullptr) - << "\n" - << I; - else - ss << "cannot handle (reverse) unknown intrinsic\n" - << Intrinsic::getName(ID) << "\n" - << I; - EmitNoDerivativeError(ss.str(), I, gutils, Builder2); - return false; - } - return false; - } - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardMode: { - - IRBuilder<> Builder2(&I); - getForwardBuilder(Builder2); - - switch (ID) { - - case Intrinsic::vector_reduce_fmax: { - if (gutils->isConstantInstruction(&I)) - return false; - auto prev = gutils->getNewFromOriginal(orig_ops[0]); - auto VT = cast(orig_ops[0]->getType()); - - assert(!VT->getElementCount().isScalable()); - size_t numElems = VT->getElementCount().getKnownMinValue(); - SmallVector elems; - SmallVector cmps; - - for (size_t i = 0; i < numElems; ++i) - elems.push_back(Builder2.CreateExtractElement(prev, (uint64_t)i)); - - Value *curmax = elems[0]; - for (size_t i = 0; i < numElems - 1; ++i) { - cmps.push_back(Builder2.CreateFCmpOLT(curmax, elems[i + 1])); - if (i + 2 != numElems) - curmax = CreateSelect(Builder2, cmps[i], elems[i + 1], curmax); - } - - auto rule = [&](Value *vdiff) { - Value *res = Builder2.CreateExtractElement(vdiff, (uint64_t)0); - - for (size_t i = 0; i < numElems - 1; ++i) { - auto rhs_v = Builder2.CreateExtractElement(vdiff, i + 1); - res = CreateSelect(Builder2, cmps[i], rhs_v, res); - } - return res; - }; - auto vdiff = diffe(orig_ops[0], Builder2); - - Value *dif = applyChainRule(I.getType(), Builder2, rule, vdiff); - setDiffe(&I, dif, Builder2); - return false; - } - default: - if (gutils->isConstantInstruction(&I)) - return false; - if (ID == Intrinsic::umax || ID == Intrinsic::smax || - ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow || - ID == Intrinsic::uadd_with_overflow || - ID == Intrinsic::smul_with_overflow || - ID == Intrinsic::umul_with_overflow || - ID == Intrinsic::ssub_with_overflow || - ID == Intrinsic::usub_with_overflow) - if (looseTypeAnalysis) { - EmitWarning("CannotDeduceType", I, - "failed to deduce type of intrinsic ", I); - return false; - } - std::string s; - llvm::raw_string_ostream ss(s); - if (Intrinsic::isOverloaded(ID)) - 
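// Editorial sketch: the scalar analogue of the llvm.vector.reduce.fmax
// tangent rule implemented above. The primal computes a running max; the
// tangent reuses the same comparison chain (the FCmpOLT selects) to pick
// the derivative of whichever lane ends up winning. Assumes n >= 1 and
// ignores NaN subtleties; the function name is hypothetical.
#include <cstddef>
#include <utility>

// Returns {max(x), d/dt max(x)} given primal lanes x[] and tangents dx[].
std::pair<double, double> reduce_fmax_fwd(const double *x, const double *dx,
                                          std::size_t n) {
  double curmax = x[0];
  double dmax = dx[0];
  for (std::size_t i = 1; i < n; ++i) {
    bool take_rhs = curmax < x[i]; // same ordered-less-than test as the IR
    if (take_rhs) {
      curmax = x[i];
      dmax = dx[i]; // the winning lane's tangent flows to the result
    }
  }
  return {curmax, dmax};
}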
ss << "cannot handle (forward) unknown intrinsic\n" - << Intrinsic::getName(ID, ArrayRef(), - gutils->oldFunc->getParent(), nullptr) - << "\n" - << I; - else - ss << "cannot handle (forward) unknown intrinsic\n" - << Intrinsic::getName(ID) << "\n" - << I; - EmitNoDerivativeError(ss.str(), I, gutils, Builder2); - if (!gutils->isConstantValue(&I)) - setDiffe(&I, - Constant::getNullValue(gutils->getShadowType(I.getType())), - Builder2); - return false; - } - return false; - } - } - - return false; - } - -// first one allows adding attributes to blas functions declared in the second -#include "BlasAttributor.inc" -#include "BlasDerivatives.inc" - - void visitOMPCall(llvm::CallInst &call) { - using namespace llvm; - - Function *kmpc = call.getCalledFunction(); - - if (overwritten_args_map.find(&call) == overwritten_args_map.end()) { - llvm::errs() << " call: " << call << "\n"; - for (auto &pair : overwritten_args_map) { - llvm::errs() << " + " << *pair.first << "\n"; - } - } - - auto found_ow = overwritten_args_map.find(&call); - assert(found_ow != overwritten_args_map.end()); - const bool subsequent_calls_may_write = found_ow->second.first; - const std::vector &overwritten_args = found_ow->second.second; - - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call)); - BuilderZ.setFastMathFlags(getFast()); - - Function *task = dyn_cast(call.getArgOperand(2)); - if (task == nullptr && isa(call.getArgOperand(2))) { - task = dyn_cast( - cast(call.getArgOperand(2))->getOperand(0)); - } - if (task == nullptr) { - llvm::errs() << "could not derive underlying task from omp call: " << call - << "\n"; - llvm_unreachable("could not derive underlying task from omp call"); - } - if (task->empty()) { - llvm::errs() - << "could not derive underlying task contents from omp call: " << call - << "\n"; - llvm_unreachable( - "could not derive underlying task contents from omp call"); - } - - auto called = task; - // bool modifyPrimal = true; - - bool foreignFunction = called == nullptr; - - SmallVector args = {0, 0, 0}; - SmallVector pre_args = {0, 0, 0}; - std::vector argsInverted = {DIFFE_TYPE::CONSTANT, - DIFFE_TYPE::CONSTANT}; - SmallVector postCreate; - SmallVector userReplace; - - SmallVector OutTypes; - SmallVector OutFPTypes; - - for (unsigned i = 3; i < call.arg_size(); ++i) { - - auto argi = gutils->getNewFromOriginal(call.getArgOperand(i)); - - pre_args.push_back(argi); - - if (Mode != DerivativeMode::ReverseModePrimal) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - args.push_back(lookup(argi, Builder2)); - } - - auto argTy = gutils->getDiffeType(call.getArgOperand(i), foreignFunction); - argsInverted.push_back(argTy); - - if (argTy == DIFFE_TYPE::CONSTANT) { - continue; - } - - auto argType = argi->getType(); - - if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) { - if (Mode != DerivativeMode::ReverseModePrimal) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - args.push_back( - lookup(gutils->invertPointerM(call.getArgOperand(i), Builder2), - Builder2)); - } - pre_args.push_back( - gutils->invertPointerM(call.getArgOperand(i), BuilderZ)); - - // Note sometimes whattype mistakenly says something should be constant - // [because composed of integer pointers alone] - assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || - whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); - } else { - assert(TR.query(call.getArgOperand(i))[{-1}].isFloat()); - OutTypes.push_back(call.getArgOperand(i)); - OutFPTypes.push_back(argType); - assert(whatType(argType, Mode) 
== DIFFE_TYPE::OUT_DIFF || - whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); - } - } - - DIFFE_TYPE subretType = DIFFE_TYPE::CONSTANT; - - Value *tape = nullptr; - CallInst *augmentcall = nullptr; - // Value *cachereplace = nullptr; - - // TODO consider reduction of int 0 args - FnTypeInfo nextTypeInfo(called); - - if (called) { - std::map> intseen; - - TypeTree IntPtr; - IntPtr.insert({-1, -1}, BaseType::Integer); - IntPtr.insert({-1}, BaseType::Pointer); - - int argnum = 0; - for (auto &arg : called->args()) { - if (argnum <= 1) { - nextTypeInfo.Arguments.insert( - std::pair(&arg, IntPtr)); - nextTypeInfo.KnownValues.insert( - std::pair>(&arg, {0})); - } else { - nextTypeInfo.Arguments.insert(std::pair( - &arg, TR.query(call.getArgOperand(argnum - 2 + 3)))); - nextTypeInfo.KnownValues.insert( - std::pair>( - &arg, - TR.knownIntegralValues(call.getArgOperand(argnum - 2 + 3)))); - } - - ++argnum; - } - nextTypeInfo.Return = TR.query(&call); - } - - // std::optional, unsigned>> - // sub_index_map; - // Optional tapeIdx; - // Optional returnIdx; - // Optional differetIdx; - - const AugmentedReturn *subdata = nullptr; - if (Mode == DerivativeMode::ReverseModeGradient) { - assert(augmentedReturn); - if (augmentedReturn) { - auto fd = augmentedReturn->subaugmentations.find(&call); - if (fd != augmentedReturn->subaugmentations.end()) { - subdata = fd->second; - } - } - } - - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - if (called) { - subdata = &gutils->Logic.CreateAugmentedPrimal( - RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer->interprocedural, - /*return is used*/ false, - /*shadowReturnUsed*/ false, nextTypeInfo, - subsequent_calls_may_write, overwritten_args, false, - gutils->runtimeActivity, gutils->getWidth(), - /*AtomicAdd*/ true, - /*OpenMP*/ true); - if (Mode == DerivativeMode::ReverseModePrimal) { - assert(augmentedReturn); - auto subaugmentations = - (std::map - *)&augmentedReturn->subaugmentations; - insert_or_assign2( - *subaugmentations, &call, (AugmentedReturn *)subdata); - } - - assert(subdata); - auto newcalled = subdata->fn; - - if (subdata->returns.find(AugmentedStruct::Tape) != - subdata->returns.end()) { - ValueToValueMapTy VMap; - newcalled = CloneFunction(newcalled, VMap); - auto tapeArg = newcalled->arg_end(); - tapeArg--; - Type *tapeElemType = subdata->tapeType; - SmallVector, 4> geps; - SmallPtrSet gepsToErase; - for (auto a : tapeArg->users()) { - if (auto gep = dyn_cast(a)) { - auto idx = gep->idx_begin(); - idx++; - auto cidx = cast(idx->get()); - assert(gep->getNumIndices() == 2); - SmallPtrSet storesToErase; - for (auto st : gep->users()) { - auto SI = cast(st); - Value *op = SI->getValueOperand(); - storesToErase.insert(SI); - geps.emplace_back(cidx->getLimitedValue(), op); - } - for (auto SI : storesToErase) - SI->eraseFromParent(); - gepsToErase.insert(gep); - } else if (auto SI = dyn_cast(a)) { - Value *op = SI->getValueOperand(); - gepsToErase.insert(SI); - geps.emplace_back(-1, op); - } else { - llvm::errs() << "unknown tape user: " << a << "\n"; - assert(0 && "unknown tape user"); - llvm_unreachable("unknown tape user"); - } - } - for (auto gep : gepsToErase) - gep->eraseFromParent(); - IRBuilder<> ph(&*newcalled->getEntryBlock().begin()); - tape = UndefValue::get(tapeElemType); - ValueToValueMapTy available; - auto subarg = newcalled->arg_begin(); - subarg++; - subarg++; - for (size_t i = 3; i < pre_args.size(); ++i) { - available[&*subarg] = pre_args[i]; - 
subarg++; - } - for (auto pair : geps) { - Value *op = pair.second; - Value *alloc = op; - Value *replacement = gutils->unwrapM(op, BuilderZ, available, - UnwrapMode::LegalFullUnwrap); - tape = - pair.first == -1 - ? replacement - : BuilderZ.CreateInsertValue(tape, replacement, pair.first); - if (auto ci = dyn_cast(alloc)) { - alloc = ci->getOperand(0); - } - if (auto uload = dyn_cast(replacement)) { - gutils->unwrappedLoads.erase(uload); - if (auto ci = dyn_cast(replacement)) { - if (auto ucast = dyn_cast(ci->getOperand(0))) - gutils->unwrappedLoads.erase(ucast); - } - } - if (auto ci = dyn_cast(alloc)) { - if (auto F = ci->getCalledFunction()) { - // Store cached values - if (F->getName() == "malloc") { - const_cast(subdata) - ->tapeIndiciesToFree.emplace(pair.first); - Value *Idxs[] = { - ConstantInt::get(Type::getInt64Ty(tapeArg->getContext()), - 0), - ConstantInt::get(Type::getInt32Ty(tapeArg->getContext()), - pair.first)}; - op->replaceAllUsesWith(ph.CreateLoad( - op->getType(), - pair.first == -1 - ? tapeArg - : ph.CreateInBoundsGEP(tapeElemType, tapeArg, Idxs))); - cast(op)->eraseFromParent(); - if (op != alloc) - ci->eraseFromParent(); - continue; - } - } - } - Value *Idxs[] = { - ConstantInt::get(Type::getInt64Ty(tapeArg->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(tapeArg->getContext()), - pair.first)}; - op->replaceAllUsesWith(ph.CreateLoad( - op->getType(), - pair.first == -1 - ? tapeArg - : ph.CreateInBoundsGEP(tapeElemType, tapeArg, Idxs))); - cast(op)->eraseFromParent(); - } - assert(tape); - auto alloc = - IRBuilder<>(gutils->inversionAllocs).CreateAlloca(tapeElemType); - BuilderZ.CreateStore(tape, alloc); - pre_args.push_back(alloc); - assert(tape); - gutils->cacheForReverse(BuilderZ, tape, - getIndex(&call, CacheType::Tape, BuilderZ)); - } - - auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()), - pre_args.size() - 3); - pre_args[0] = gutils->getNewFromOriginal(call.getArgOperand(0)); - pre_args[1] = numargs; - pre_args[2] = BuilderZ.CreatePointerCast( - newcalled, kmpc->getFunctionType()->getParamType(2)); - augmentcall = - BuilderZ.CreateCall(kmpc->getFunctionType(), kmpc, pre_args); - augmentcall->setCallingConv(call.getCallingConv()); - augmentcall->setDebugLoc( - gutils->getNewFromOriginal(call.getDebugLoc())); - BuilderZ.SetInsertPoint( - gutils->getNewFromOriginal(&call)->getNextNode()); - gutils->erase(gutils->getNewFromOriginal(&call)); - } else { - assert(0 && "unhandled unknown outline"); - } - } - - { - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (!subdata && !isMemFreeLibMFunction(getFuncNameFromCall(&call), &ID)) { - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - llvm::errs() << *called << "\n"; - llvm_unreachable("no subdata"); - } - } - - if (subdata) { - auto found = subdata->returns.find(AugmentedStruct::DifferentialReturn); - assert(found == subdata->returns.end()); - } - if (subdata) { - auto found = subdata->returns.find(AugmentedStruct::Return); - assert(found == subdata->returns.end()); - } - - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - if (Mode == DerivativeMode::ReverseModeGradient) { - BuilderZ.SetInsertPoint( - gutils->getNewFromOriginal(&call)->getNextNode()); - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - - Function *newcalled = nullptr; - if (called) { - if (subdata && 
subdata->returns.find(AugmentedStruct::Tape) != - subdata->returns.end()) { - if (Mode == DerivativeMode::ReverseModeGradient) { - if (tape == nullptr) { -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - tape = BuilderZ.CreatePHI(subdata->tapeType, 0, "tapeArg"); - } - tape = gutils->cacheForReverse( - BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ)); - } - tape = lookup(tape, Builder2); - auto alloc = IRBuilder<>(gutils->inversionAllocs) - .CreateAlloca(tape->getType()); - Builder2.CreateStore(tape, alloc); - args.push_back(alloc); - } - - if (Mode == DerivativeMode::ReverseModeGradient && subdata) { - for (size_t i = 0; i < argsInverted.size(); i++) { - if (subdata->constant_args[i] == argsInverted[i]) - continue; - assert(subdata->constant_args[i] == DIFFE_TYPE::DUP_ARG); - assert(argsInverted[i] == DIFFE_TYPE::DUP_NONEED); - argsInverted[i] = DIFFE_TYPE::DUP_ARG; - } - } - - newcalled = gutils->Logic.CreatePrimalAndGradient( - RequestContext(&call, &Builder2), - (ReverseCacheKey){ - .todiff = cast(called), - .retType = subretType, - .constant_args = argsInverted, - .subsequent_calls_may_write = subsequent_calls_may_write, - .overwritten_args = overwritten_args, - .returnUsed = false, - .shadowReturnUsed = false, - .mode = DerivativeMode::ReverseModeGradient, - .width = gutils->getWidth(), - .freeMemory = true, - .AtomicAdd = true, - .additionalType = - tape ? PointerType::getUnqual(tape->getType()) : nullptr, - .forceAnonymousTape = false, - .typeInfo = nextTypeInfo, - .runtimeActivity = gutils->runtimeActivity, - }, - TR.analyzer->interprocedural, subdata, - /*omp*/ true); - - if (subdata && subdata->returns.find(AugmentedStruct::Tape) != - subdata->returns.end()) { - auto tapeArg = newcalled->arg_end(); - tapeArg--; - LoadInst *tape = nullptr; - for (auto u : tapeArg->users()) { - assert(!tape); - if (!isa(u)) { - llvm::errs() << " newcalled: " << *newcalled << "\n"; - llvm::errs() << " u: " << *u << "\n"; - } - tape = cast(u); - } - assert(tape); - SmallVector extracts; - if (subdata->tapeIndices.size() == 1) { - assert(subdata->tapeIndices.begin()->second == -1); - extracts.push_back(tape); - } else { - for (auto a : tape->users()) { - extracts.push_back(a); - } - } - SmallVector geps; - for (auto E : extracts) { - AllocaInst *AI = nullptr; - for (auto U : E->users()) { - if (auto SI = dyn_cast(U)) { - assert(SI->getValueOperand() == E); - AI = cast(SI->getPointerOperand()); - } - } - if (AI) { - for (auto U : AI->users()) { - if (auto LI = dyn_cast(U)) { - geps.push_back(LI); - } - } - } - } - for (auto LI : geps) { - CallInst *freeCall = nullptr; - for (auto LU : LI->users()) { - if (auto CI = dyn_cast(LU)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "free") { - freeCall = CI; - break; - } - } - } else if (auto BC = dyn_cast(LU)) { - for (auto CU : BC->users()) { - if (auto CI = dyn_cast(CU)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "free") { - freeCall = CI; - break; - } - } - } - } - if (freeCall) - break; - } - } - if (freeCall) { - freeCall->eraseFromParent(); - } - } - } - - Value *OutAlloc = nullptr; - auto ST = StructType::get(newcalled->getContext(), OutFPTypes); - if (OutTypes.size()) { - OutAlloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(ST); - args.push_back(OutAlloc); - - SmallVector MetaTypes; - for (auto P : - cast(newcalled)->getFunctionType()->params()) { - MetaTypes.push_back(P); - } - 
MetaTypes.push_back(PointerType::getUnqual(ST)); - auto FT = FunctionType::get(Type::getVoidTy(newcalled->getContext()), - MetaTypes, false); - Function *F = - Function::Create(FT, GlobalVariable::InternalLinkage, - cast(newcalled)->getName() + "#out", - *task->getParent()); - BasicBlock *entry = - BasicBlock::Create(newcalled->getContext(), "entry", F); - IRBuilder<> B(entry); - SmallVector SubArgs; - for (auto &arg : F->args()) - SubArgs.push_back(&arg); - Value *cacheArg = SubArgs.back(); - SubArgs.pop_back(); - Value *outdiff = B.CreateCall(newcalled, SubArgs); - for (size_t ee = 0; ee < OutTypes.size(); ee++) { - Value *dif = B.CreateExtractValue(outdiff, ee); - Value *Idxs[] = { - ConstantInt::get(Type::getInt64Ty(ST->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(ST->getContext()), ee)}; - Value *ptr = B.CreateInBoundsGEP(ST, cacheArg, Idxs); - - if (dif->getType()->isIntOrIntVectorTy()) { - - ptr = B.CreateBitCast( - ptr, - PointerType::get( - IntToFloatTy(dif->getType()), - cast(ptr->getType())->getAddressSpace())); - dif = B.CreateBitCast(dif, IntToFloatTy(dif->getType())); - } - - MaybeAlign align; - AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd; - if (auto vt = dyn_cast(dif->getType())) { - assert(!vt->getElementCount().isScalable()); - size_t numElems = vt->getElementCount().getKnownMinValue(); - for (size_t i = 0; i < numElems; ++i) { - auto vdif = B.CreateExtractElement(dif, i); - Value *Idxs[] = { - ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)}; - auto vptr = B.CreateInBoundsGEP(vt, ptr, Idxs); - B.CreateAtomicRMW(op, vptr, vdif, align, - AtomicOrdering::Monotonic, SyncScope::System); - } - } else { - B.CreateAtomicRMW(op, ptr, dif, align, AtomicOrdering::Monotonic, - SyncScope::System); - } - } - B.CreateRetVoid(); - newcalled = F; - } - - auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()), - args.size() - 3); - args[0] = - lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2); - args[1] = numargs; - args[2] = Builder2.CreatePointerCast( - newcalled, kmpc->getFunctionType()->getParamType(2)); - - CallInst *diffes = - Builder2.CreateCall(kmpc->getFunctionType(), kmpc, args); - diffes->setCallingConv(call.getCallingConv()); - diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - - for (size_t i = 0; i < OutTypes.size(); i++) { - - size_t size = 1; - if (OutTypes[i]->getType()->isSized()) - size = (gutils->newFunc->getParent() - ->getDataLayout() - .getTypeSizeInBits(OutTypes[i]->getType()) + - 7) / - 8; - Value *Idxs[] = { - ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), - ConstantInt::get(Type::getInt32Ty(call.getContext()), i)}; - ((DiffeGradientUtils *)gutils) - ->addToDiffe(OutTypes[i], - Builder2.CreateLoad( - OutFPTypes[i], - Builder2.CreateInBoundsGEP(ST, OutAlloc, Idxs)), - Builder2, TR.addingType(size, OutTypes[i])); - } - - if (tape && shouldFree()) { - for (auto idx : subdata->tapeIndiciesToFree) { - CreateDealloc(Builder2, - idx == -1 ? 
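// Editorial sketch: the "#out" wrapper built above funnels each thread's
// OUT_DIFF contribution into one shared slot with an atomic FAdd
// (AtomicRMW, monotonic ordering), because every OpenMP worker runs the
// same reverse body concurrently. The same idea in plain C++, written as a
// CAS loop so it does not rely on C++20's floating-point fetch_add; the
// function name is hypothetical.
#include <atomic>

void atomic_add(std::atomic<double> &slot, double contribution) {
  double expected = slot.load(std::memory_order_relaxed);
  // Retry until no other thread raced us between the load and the store.
  while (!slot.compare_exchange_weak(expected, expected + contribution,
                                     std::memory_order_relaxed))
    ;
}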
tape - : Builder2.CreateExtractValue(tape, idx)); - } - } - } else { - assert(0 && "openmp indirect unhandled"); - } - } - } - - void DifferentiableMemCopyFloats( - llvm::CallInst &call, llvm::Value *origArg, llvm::Value *dsto, - llvm::Value *srco, llvm::Value *len_arg, llvm::IRBuilder<> &Builder2, - llvm::ArrayRef ReverseDefs) { - using namespace llvm; - - size_t size = 1; - if (auto ci = dyn_cast(len_arg)) { - size = ci->getLimitedValue(); - } - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto vd = TR.query(origArg).Data0().ShiftIndices(DL, 0, size, 0); - if (!vd.isKnownPastPointer()) { -#if LLVM_VERSION_MAJOR < 17 - if (looseTypeAnalysis) { - if (isa(origArg) && - cast(origArg)->getSrcTy()->isPointerTy() && - cast(origArg) - ->getSrcTy() - ->getPointerElementType() - ->isFPOrFPVectorTy()) { - vd = TypeTree(ConcreteType(cast(origArg) - ->getSrcTy() - ->getPointerElementType() - ->getScalarType())) - .Only(0, &call); - goto knownF; - } - } -#endif - TR.dump(); - EmitFailure("CannotDeduceType", call.getDebugLoc(), &call, - "failed to deduce type of copy ", call); - } -#if LLVM_VERSION_MAJOR < 17 - knownF: -#endif - unsigned start = 0; - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - if (!dt.isKnown()) { - TR.dump(); - llvm::errs() << " vd:" << vd.str() << " start:" << start - << " size: " << size << " dt:" << dt.str() << "\n"; - } - assert(dt.isKnown()); - - Value *length = len_arg; - if (nextStart != size) { - length = ConstantInt::get(len_arg->getType(), nextStart); - } - if (start != 0) - length = Builder2.CreateSub( - length, ConstantInt::get(len_arg->getType(), start)); - - if (auto secretty = dt.isFloat()) { - auto offset = start; - if (dsto->getType()->isIntegerTy()) - dsto = - Builder2.CreateIntToPtr(dsto, getInt8PtrTy(dsto->getContext())); - unsigned dstaddr = - cast(dsto->getType())->getAddressSpace(); - auto secretpt = PointerType::get(secretty, dstaddr); - if (offset != 0) { - dsto = Builder2.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(dsto->getContext()), dsto, offset); - } - if (srco->getType()->isIntegerTy()) - srco = - Builder2.CreateIntToPtr(srco, getInt8PtrTy(dsto->getContext())); - unsigned srcaddr = - cast(srco->getType())->getAddressSpace(); - secretpt = PointerType::get(secretty, srcaddr); - - if (offset != 0) { - srco = Builder2.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(srco->getContext()), srco, offset); - } - Value *args[3] = { - Builder2.CreatePointerCast(dsto, secretpt), - Builder2.CreatePointerCast(srco, secretpt), - Builder2.CreateUDiv( - length, - - ConstantInt::get(length->getType(), - Builder2.GetInsertBlock() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(secretty) / - 8))}; - - auto dmemcpy = getOrInsertDifferentialFloatMemcpy( - *Builder2.GetInsertBlock()->getParent()->getParent(), secretty, - /*dstalign*/ 1, /*srcalign*/ 1, dstaddr, srcaddr, - cast(length->getType())->getBitWidth()); - - Builder2.CreateCall(dmemcpy, args, ReverseDefs); - } - - if (nextStart == size) - break; - start = nextStart; - } - } - - void recursivelyHandleSubfunction(llvm::CallInst &call, - llvm::Function *called, - bool subsequent_calls_may_write, - const std::vector &overwritten_args, - bool shadowReturnUsed, - DIFFE_TYPE subretType, bool subretused) { - using namespace llvm; - - IRBuilder<> 
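// Editorial sketch of what the differential float memcpy requested above
// (getOrInsertDifferentialFloatMemcpy) amounts to for a region known to
// hold doubles: copying dst <- src in the primal means the adjoint flows
// from the destination shadow into the source shadow, after which the
// destination shadow is consumed. Simplified -- the generated helper also
// handles alignment, address spaces and other element types -- and the
// function name is hypothetical.
#include <cstddef>

void diffe_memcpy_double(double *shadow_dst, double *shadow_src,
                         std::size_t nelems) {
  for (std::size_t i = 0; i < nelems; ++i) {
    shadow_src[i] += shadow_dst[i]; // accumulate adjoint into the source
    shadow_dst[i] = 0.0;            // the copied-over slot's adjoint is spent
  }
}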
BuilderZ(gutils->getNewFromOriginal(&call)); - BuilderZ.setFastMathFlags(getFast()); - - CallInst *newCall = cast(gutils->getNewFromOriginal(&call)); - Module &M = *call.getParent()->getParent()->getParent(); - - bool foreignFunction = called == nullptr; - - FnTypeInfo nextTypeInfo(called); - - if (called) { - nextTypeInfo = TR.getCallInfo(call, *called); - } - - const AugmentedReturn *subdata = nullptr; - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) { - assert(augmentedReturn); - if (augmentedReturn) { - auto fd = augmentedReturn->subaugmentations.find(&call); - if (fd != augmentedReturn->subaugmentations.end()) { - subdata = fd->second; - } - } - } - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - Mode == DerivativeMode::ForwardModeSplit) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - - SmallVector args; - std::vector argsInverted; - std::map gradByVal; - std::map> structAttrs; - - for (unsigned i = 0; i < call.arg_size(); ++i) { - - if (call.paramHasAttr(i, Attribute::StructRet)) { - structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzyme_sret")); - // TODO - // structAttrs[args.size()].push_back(Attribute::get( - // call.getContext(), Attribute::AttrKind::ElementType, - // call.getParamAttr(i, Attribute::StructRet).getValueAsType())); - } - for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", - "enzymejl_parmtype_ref", "enzyme_type"}) - if (call.getAttributes().hasParamAttr(i, attr)) { - structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); - } - for (auto ty : PrimalParamAttrsToPreserve) - if (call.getAttributes().hasParamAttr(i, ty)) { - auto attr = call.getAttributes().getParamAttr(i, ty); - structAttrs[args.size()].push_back(attr); - } - - auto argi = gutils->getNewFromOriginal(call.getArgOperand(i)); - - if (call.isByValArgument(i)) { - gradByVal[args.size()] = call.getParamByValType(i); - } - - bool writeOnlyNoCapture = true; - bool readOnly = true; - if (!isNoCapture(&call, i)) { - writeOnlyNoCapture = false; - } - if (!isWriteOnly(&call, i)) { - writeOnlyNoCapture = false; - } - if (!isReadOnly(&call, i)) { - readOnly = false; - } - - if (shouldDisableNoWrite(&call)) - writeOnlyNoCapture = false; - - auto argTy = - gutils->getDiffeType(call.getArgOperand(i), foreignFunction); - - bool replace = - (argTy == DIFFE_TYPE::DUP_NONEED && - (writeOnlyNoCapture || - !isa(getBaseObject(call.getArgOperand(i))))) || - (writeOnlyNoCapture && Mode == DerivativeMode::ForwardModeSplit) || - (writeOnlyNoCapture && readOnly); - - if (replace) { - argi = getUndefinedValueForType(M, argi->getType()); - } - argsInverted.push_back(argTy); - args.push_back(argi); - - if (argTy == DIFFE_TYPE::CONSTANT) { - continue; - } - - if (gutils->getWidth() == 1) - for (auto ty : ShadowParamAttrsToPreserve) - if (call.getAttributes().hasParamAttr(i, ty)) { - auto attr = call.getAttributes().getParamAttr(i, ty); - structAttrs[args.size()].push_back(attr); - } - - for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", - "enzymejl_parmtype_ref", "enzyme_type"}) - if (call.getAttributes().hasParamAttr(i, attr)) { - if (gutils->getWidth() == 1) { - structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); - } else if (attr == std::string("enzymejl_returnRoots")) { - structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); - } - } - if (call.paramHasAttr(i, Attribute::StructRet)) { - if (gutils->getWidth() 
== 1) { - structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzyme_sret") - // orig->getParamAttr(i, - // Attribute::StructRet).getValueAsType()); - ); - // TODO - // structAttrs[args.size()].push_back(Attribute::get( - // call.getContext(), Attribute::AttrKind::ElementType, - // call.getParamAttr(i, - // Attribute::StructRet).getValueAsType())); - } else { - structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzyme_sret_v")); - // TODO - // structAttrs[args.size()].push_back(Attribute::get( - // call.getContext(), Attribute::AttrKind::ElementType, - // call.getParamAttr(i, - // Attribute::StructRet).getValueAsType())); - } - } - - assert(argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED); - - args.push_back(gutils->invertPointerM(call.getArgOperand(i), Builder2)); - } -#if LLVM_VERSION_MAJOR >= 16 - std::optional tapeIdx; -#else - Optional tapeIdx; -#endif - if (subdata) { - auto found = subdata->returns.find(AugmentedStruct::Tape); - if (found != subdata->returns.end()) { - tapeIdx = found->second; - } - } - Value *tape = nullptr; - if (tapeIdx) { - - auto idx = *tapeIdx; - FunctionType *FT = subdata->fn->getFunctionType(); -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - tape = BuilderZ.CreatePHI( - (tapeIdx == -1) - ? FT->getReturnType() - : cast(FT->getReturnType())->getElementType(idx), - 1, "tapeArg"); - - assert(!tape->getType()->isEmptyTy()); - gutils->TapesToPreventRecomputation.insert(cast(tape)); - tape = gutils->cacheForReverse( - BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ)); - args.push_back(tape); - } - - Value *newcalled = nullptr; - FunctionType *FT = nullptr; - - if (called) { - newcalled = gutils->Logic.CreateForwardDiff( - RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer->interprocedural, - /*returnValue*/ subretused, Mode, - ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->runtimeActivity, - gutils->getWidth(), tape ? tape->getType() : nullptr, nextTypeInfo, - subsequent_calls_may_write, overwritten_args, - /*augmented*/ subdata); - FT = cast(newcalled)->getFunctionType(); - } else { - auto callval = call.getCalledOperand(); - newcalled = gutils->invertPointerM(callval, BuilderZ); - - if (gutils->getWidth() > 1) { - newcalled = BuilderZ.CreateExtractValue(newcalled, {0}); - } - - ErrorIfRuntimeInactive( - BuilderZ, gutils->getNewFromOriginal(callval), newcalled, - "Attempting to call an indirect active function " - "whose runtime value is inactive", - gutils->getNewFromOriginal(call.getDebugLoc()), &call); - - auto ft = call.getFunctionType(); - bool retActive = subretType != DIFFE_TYPE::CONSTANT; - - ReturnType subretVal = - subretused - ? (retActive ? ReturnType::TwoReturns : ReturnType::Return) - : (retActive ? ReturnType::Return : ReturnType::Void); - - FT = getFunctionTypeForClone( - ft, Mode, gutils->getWidth(), tape ? 
tape->getType() : nullptr, - argsInverted, false, subretVal, subretType); - PointerType *fptype = PointerType::getUnqual(FT); - newcalled = BuilderZ.CreatePointerCast(newcalled, - PointerType::getUnqual(fptype)); - newcalled = BuilderZ.CreateLoad(fptype, newcalled); - } - - assert(newcalled); - assert(FT); - - SmallVector BundleTypes; - for (auto A : argsInverted) - if (A == DIFFE_TYPE::CONSTANT) - BundleTypes.push_back(ValueType::Primal); - else - BundleTypes.push_back(ValueType::Both); - - auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2, - /*lookup*/ false); - - CallInst *diffes = Builder2.CreateCall(FT, newcalled, args, Defs); - diffes->setCallingConv(call.getCallingConv()); - diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - - for (auto pair : gradByVal) { - diffes->addParamAttr( - pair.first, - Attribute::getWithByValType(diffes->getContext(), pair.second)); - } - - for (auto &pair : structAttrs) { - for (auto val : pair.second) - diffes->addParamAttr(pair.first, val); - } - - auto newcall = gutils->getNewFromOriginal(&call); - auto ifound = gutils->invertedPointers.find(&call); - Value *primal = nullptr; - Value *diffe = nullptr; - - if (subretused && subretType != DIFFE_TYPE::CONSTANT) { - primal = Builder2.CreateExtractValue(diffes, 0); - diffe = Builder2.CreateExtractValue(diffes, 1); - } else if (subretType != DIFFE_TYPE::CONSTANT) { - diffe = diffes; - } else if (!FT->getReturnType()->isVoidTy()) { - primal = diffes; - } - - if (ifound != gutils->invertedPointers.end()) { - auto placeholder = cast(&*ifound->second); - if (primal) { - gutils->replaceAWithB(newcall, primal); - gutils->erase(newcall); - } else { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - if (diffe) { - gutils->replaceAWithB(placeholder, diffe); - } else { - gutils->invertedPointers.erase(ifound); - } - gutils->erase(placeholder); - } else { - if (primal && diffe) { - gutils->replaceAWithB(newcall, primal); - if (!gutils->isConstantValue(&call)) { - setDiffe(&call, diffe, Builder2); - } - gutils->erase(newcall); - } else if (diffe) { - setDiffe(&call, diffe, Builder2); - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } else if (primal) { - gutils->replaceAWithB(newcall, primal); - gutils->erase(newcall); - } else { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - } - - return; - } - - bool modifyPrimal = shouldAugmentCall(&call, gutils); - - SmallVector args; - SmallVector pre_args; - std::vector argsInverted; - SmallVector postCreate; - SmallVector userReplace; - std::map preByVal; - std::map gradByVal; - std::map> structAttrs; - - bool replaceFunction = false; - - if (Mode == DerivativeMode::ReverseModeCombined && !foreignFunction) { - replaceFunction = legalCombinedForwardReverse( - &call, *replacedReturns, postCreate, userReplace, gutils, - unnecessaryInstructions, oldUnreachable, subretused); - if (replaceFunction) { - modifyPrimal = false; - } - } - - SmallVector PreBundleTypes; - SmallVector BundleTypes; - - for (unsigned i = 0; i < call.arg_size(); ++i) { - - auto argi = gutils->getNewFromOriginal(call.getArgOperand(i)); - - if (call.isByValArgument(i)) { - preByVal[pre_args.size()] = call.getParamByValType(i); - } - for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", - "enzymejl_parmtype_ref", "enzyme_type"}) - if (call.getAttributes().hasParamAttr(i, attr)) { - structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr)); - } - if (call.paramHasAttr(i, Attribute::StructRet)) { - 
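// Editorial sketch of the forward-mode calling convention assembled above:
// when the original call's result is both used and active
// (ReturnType::TwoReturns), the generated function returns {primal, tangent}
// and the caller extracts both, exactly as the two CreateExtractValue calls
// do. Hand-written analogue for f(x) = x*x; fwddiffe_square and Pair are
// hypothetical names, not generated code.
#include <cstdio>

struct Pair {
  double primal, tangent;
};

double square(double x) { return x * x; }

Pair fwddiffe_square(double x, double dx) {
  return {square(x), 2.0 * x * dx}; // d(x*x)/dt = 2*x*dx
}

int main() {
  Pair r = fwddiffe_square(3.0, 1.0);
  std::printf("primal=%f tangent=%f\n", r.primal, r.tangent); // 9 and 6
}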
structAttrs[pre_args.size()].push_back( - // TODO persist types - Attribute::get(call.getContext(), "enzyme_sret") - // Attribute::get(orig->getContext(), "enzyme_sret", - // orig->getParamAttr(ii, Attribute::StructRet).getValueAsType()); - ); - } - for (auto ty : PrimalParamAttrsToPreserve) - if (call.getAttributes().hasParamAttr(i, ty)) { - auto attr = call.getAttributes().getParamAttr(i, ty); - structAttrs[pre_args.size()].push_back(attr); - } - - auto argTy = gutils->getDiffeType(call.getArgOperand(i), foreignFunction); - - bool writeOnlyNoCapture = true; - bool readNoneNoCapture = false; - if (!isNoCapture(&call, i)) { - writeOnlyNoCapture = false; - readNoneNoCapture = false; - } - if (!isWriteOnly(&call, i)) { - writeOnlyNoCapture = false; - } - if (!(isReadOnly(&call, i) && isWriteOnly(&call, i))) { - readNoneNoCapture = false; - } - - if (shouldDisableNoWrite(&call)) { - writeOnlyNoCapture = false; - readNoneNoCapture = false; - } - - Value *prearg = argi; - - ValueType preType = ValueType::Primal; - ValueType revType = ValueType::Primal; - - // Keep the existing passed value if coming from outside. - if (readNoneNoCapture || - (argTy == DIFFE_TYPE::DUP_NONEED && - (writeOnlyNoCapture || - !isa(getBaseObject(call.getArgOperand(i)))))) { - prearg = getUndefinedValueForType(M, argi->getType()); - preType = ValueType::None; - } - pre_args.push_back(prearg); - - if (Mode != DerivativeMode::ReverseModePrimal) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - if (call.isByValArgument(i)) { - gradByVal[args.size()] = call.getParamByValType(i); - } - - if ((writeOnlyNoCapture && !replaceFunction) || - (readNoneNoCapture || - (argTy == DIFFE_TYPE::DUP_NONEED && - (writeOnlyNoCapture || - !isa(getBaseObject(call.getOperand(i))))))) { - argi = getUndefinedValueForType(M, argi->getType()); - revType = ValueType::None; - } - args.push_back(lookup(argi, Builder2)); - } - - argsInverted.push_back(argTy); - - if (argTy == DIFFE_TYPE::CONSTANT) { - PreBundleTypes.push_back(preType); - BundleTypes.push_back(revType); - continue; - } - - auto argType = argi->getType(); - - if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) { - if (gutils->getWidth() == 1) - for (auto ty : ShadowParamAttrsToPreserve) - if (call.getAttributes().hasParamAttr(i, ty)) { - auto attr = call.getAttributes().getParamAttr(i, ty); - structAttrs[pre_args.size()].push_back(attr); - } - - for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", - "enzymejl_parmtype_ref", "enzyme_type"}) - if (call.getAttributes().hasParamAttr(i, attr)) { - if (gutils->getWidth() == 1) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, attr)); - } else if (attr == std::string("enzymejl_returnRoots")) { - structAttrs[pre_args.size()].push_back( - Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); - } - } - if (call.paramHasAttr(i, Attribute::StructRet)) { - if (gutils->getWidth() == 1) { - structAttrs[pre_args.size()].push_back( - // TODO persist types - Attribute::get(call.getContext(), "enzyme_sret") - // Attribute::get(orig->getContext(), "enzyme_sret", - // orig->getParamAttr(ii, - // Attribute::StructRet).getValueAsType()); - ); - } else { - structAttrs[pre_args.size()].push_back( - // TODO persist types - Attribute::get(call.getContext(), "enzyme_sret_v") - // Attribute::get(orig->getContext(), "enzyme_sret_v", - // gutils->getShadowType(orig->getParamAttr(ii, - // Attribute::StructRet).getValueAsType())); - ); - } - } - if (Mode != DerivativeMode::ReverseModePrimal) { - 
IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - Value *darg = nullptr; - - if (writeOnlyNoCapture && !replaceFunction && - TR.query(call.getArgOperand(i))[{-1, -1}] == BaseType::Pointer) { - darg = getUndefinedValueForType(M, argi->getType()); - } else { - darg = gutils->invertPointerM(call.getArgOperand(i), Builder2); - revType = (revType == ValueType::None) ? ValueType::Shadow - : ValueType::Both; - } - args.push_back(lookup(darg, Builder2)); - } - pre_args.push_back( - gutils->invertPointerM(call.getArgOperand(i), BuilderZ)); - preType = - (preType == ValueType::None) ? ValueType::Shadow : ValueType::Both; - - // Note sometimes whattype mistakenly says something should be - // constant [because composed of integer pointers alone] - (void)argType; - assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || - whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); - } else { - if (foreignFunction) - assert(!argType->isIntOrIntVectorTy()); - assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF || - whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); - } - PreBundleTypes.push_back(preType); - BundleTypes.push_back(revType); - } - if (called) { - if (call.arg_size() != - cast(called)->getFunctionType()->getNumParams()) { - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << call << "\n"; - llvm::errs() << " number of arg operands != function parameters\n"; - EmitFailure("MismatchArgs", call.getDebugLoc(), &call, - "Number of arg operands != function parameters\n", call); - } - } - - Value *tape = nullptr; - CallInst *augmentcall = nullptr; - Value *cachereplace = nullptr; - - // std::optional, - // unsigned>> sub_index_map; -#if LLVM_VERSION_MAJOR >= 16 - std::optional tapeIdx; - std::optional returnIdx; - std::optional differetIdx; -#else - Optional tapeIdx; - Optional returnIdx; - Optional differetIdx; -#endif - if (modifyPrimal) { - - Value *newcalled = nullptr; - FunctionType *FT = nullptr; - const AugmentedReturn *fnandtapetype = nullptr; - - if (!called) { - auto callval = call.getCalledOperand(); - Value *uncast = callval; - while (auto CE = dyn_cast(uncast)) { - if (CE->isCast()) { - uncast = CE->getOperand(0); - continue; - } - break; - } - if (isa(uncast)) { - std::string str; - raw_string_ostream ss(str); - ss << "cannot find shadow for " << *callval - << " for use as function in " << call; - EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ); - } - newcalled = gutils->invertPointerM(callval, BuilderZ); - - if (Mode != DerivativeMode::ReverseModeGradient) - ErrorIfRuntimeInactive( - BuilderZ, gutils->getNewFromOriginal(callval), newcalled, - "Attempting to call an indirect active function " - "whose runtime value is inactive", - gutils->getNewFromOriginal(call.getDebugLoc()), &call); - - FunctionType *ft = call.getFunctionType(); - - std::set seen; - DIFFE_TYPE subretType = whatType(call.getType(), Mode, - /*intAreConstant*/ false, seen); - auto res = getDefaultFunctionTypeForAugmentation( - ft, /*returnUsed*/ true, /*subretType*/ subretType); - FT = FunctionType::get( - StructType::get(newcalled->getContext(), res.second), res.first, - ft->isVarArg()); - auto fptype = PointerType::getUnqual(FT); - newcalled = BuilderZ.CreatePointerCast(newcalled, - PointerType::getUnqual(fptype)); - newcalled = BuilderZ.CreateLoad(fptype, newcalled); - tapeIdx = 0; - - if (!call.getType()->isVoidTy()) { - returnIdx = 1; - if (subretType == DIFFE_TYPE::DUP_ARG || - subretType == DIFFE_TYPE::DUP_NONEED) { - differetIdx = 
2; - } - } - } else { - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - subdata = &gutils->Logic.CreateAugmentedPrimal( - RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer->interprocedural, - /*return is used*/ subretused, shadowReturnUsed, nextTypeInfo, - subsequent_calls_may_write, overwritten_args, false, - gutils->runtimeActivity, gutils->getWidth(), gutils->AtomicAdd); - if (Mode == DerivativeMode::ReverseModePrimal) { - assert(augmentedReturn); - auto subaugmentations = - (std::map - *)&augmentedReturn->subaugmentations; - insert_or_assign2( - *subaugmentations, &call, (AugmentedReturn *)subdata); - } - } - { - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (!subdata && - !isMemFreeLibMFunction(getFuncNameFromCall(&call), &ID)) { - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - llvm::errs() << *called << "\n"; - assert(subdata); - } - } - - if (subdata) { - fnandtapetype = subdata; - newcalled = subdata->fn; - FT = cast(newcalled)->getFunctionType(); - - auto found = - subdata->returns.find(AugmentedStruct::DifferentialReturn); - if (found != subdata->returns.end()) { - differetIdx = found->second; - } else { - assert(!shadowReturnUsed); - } - - found = subdata->returns.find(AugmentedStruct::Return); - if (found != subdata->returns.end()) { - returnIdx = found->second; - } else { - assert(!subretused); - } - - found = subdata->returns.find(AugmentedStruct::Tape); - if (found != subdata->returns.end()) { - tapeIdx = found->second; - } - } - } - // sub_index_map = fnandtapetype.tapeIndices; - - // llvm::errs() << "seeing sub_index_map of " << sub_index_map->size() - // << " in ap " << cast(called)->getName() << "\n"; - if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModePrimal) { - - assert(newcalled); - assert(FT); - - if (false) { - badaugmentedfn:; - auto NC = dyn_cast(newcalled); - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - if (NC) - llvm::errs() << " trying to call " << NC->getName() << " " << *FT - << "\n"; - else - llvm::errs() << " trying to call " << *newcalled << " " << *FT - << "\n"; - - for (unsigned i = 0; i < pre_args.size(); ++i) { - assert(pre_args[i]); - assert(pre_args[i]->getType()); - llvm::errs() << "args[" << i << "] = " << *pre_args[i] - << " FT:" << *FT->getParamType(i) << "\n"; - } - assert(0 && "calling with wrong number of arguments"); - exit(1); - } - - if (pre_args.size() != FT->getNumParams()) - goto badaugmentedfn; - - for (unsigned i = 0; i < pre_args.size(); ++i) { - if (pre_args[i]->getType() == FT->getParamType(i)) - continue; - else if (!call.getCalledFunction()) - pre_args[i] = - BuilderZ.CreateBitCast(pre_args[i], FT->getParamType(i)); - else - goto badaugmentedfn; - } - - augmentcall = BuilderZ.CreateCall( - FT, newcalled, pre_args, - gutils->getInvertedBundles(&call, PreBundleTypes, BuilderZ, - /*lookup*/ false)); - augmentcall->setCallingConv(call.getCallingConv()); - augmentcall->setDebugLoc( - gutils->getNewFromOriginal(call.getDebugLoc())); - - for (auto pair : preByVal) { - augmentcall->addParamAttr( - pair.first, Attribute::getWithByValType(augmentcall->getContext(), - pair.second)); - } - - for (auto &pair : structAttrs) { - for (auto val : pair.second) - augmentcall->addParamAttr(pair.first, val); - } - - if (!augmentcall->getType()->isVoidTy()) - 
augmentcall->setName(call.getName() + "_augmented"); - - if (tapeIdx) { - auto tval = *tapeIdx; - tape = (tval == -1) ? augmentcall - : BuilderZ.CreateExtractValue( - augmentcall, {(unsigned)tval}, "subcache"); - if (tape->getType()->isEmptyTy()) { - auto tt = tape->getType(); - gutils->erase(cast(tape)); - tape = UndefValue::get(tt); - } else { - gutils->TapesToPreventRecomputation.insert(cast(tape)); - } - tape = gutils->cacheForReverse( - BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ)); - } - - if (subretused) { - Value *dcall = nullptr; - assert(returnIdx); - assert(augmentcall); - auto rval = *returnIdx; - dcall = (rval < 0) ? augmentcall - : BuilderZ.CreateExtractValue(augmentcall, - {(unsigned)rval}); - gutils->originalToNewFn[&call] = dcall; - gutils->newToOriginalFn.erase(newCall); - gutils->newToOriginalFn[dcall] = &call; - - assert(dcall->getType() == call.getType()); - assert(dcall); - - if (!gutils->isConstantValue(&call)) { - if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { - } else if (Mode != DerivativeMode::ReverseModePrimal) { - ((DiffeGradientUtils *)gutils)->differentials[dcall] = - ((DiffeGradientUtils *)gutils)->differentials[newCall]; - ((DiffeGradientUtils *)gutils)->differentials.erase(newCall); - } - } - assert(dcall->getType() == call.getType()); - gutils->replaceAWithB(newCall, dcall); - - if (isa(dcall) && !isa(dcall)) { - cast(dcall)->takeName(newCall); - } - - if (Mode == DerivativeMode::ReverseModePrimal && - !gutils->unnecessaryIntermediates.count(&call)) { - - std::map Seen; - bool primalNeededInReverse = false; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) { - if (pair.first == &call) { - primalNeededInReverse = true; - break; - } else { - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - } - } - if (!primalNeededInReverse) { - - auto minCutMode = (Mode == DerivativeMode::ReverseModePrimal) - ? DerivativeMode::ReverseModeGradient - : Mode; - primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, minCutMode, Seen, - oldUnreachable); - } - if (primalNeededInReverse) - gutils->cacheForReverse( - BuilderZ, dcall, getIndex(&call, CacheType::Self, BuilderZ)); - } - BuilderZ.SetInsertPoint(newCall->getNextNode()); - gutils->erase(newCall); - } else { - BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode()); - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - gutils->originalToNewFn[&call] = augmentcall; - gutils->newToOriginalFn[augmentcall] = &call; - } - - } else { - if (subdata && subdata->returns.find(AugmentedStruct::Tape) == - subdata->returns.end()) { - } else { - // assert(!tape); - // assert(subdata); - if (FT) { - if (!tape) { - assert(tapeIdx); - auto tval = *tapeIdx; -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - tape = BuilderZ.CreatePHI( - (tapeIdx == -1) ? 
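// Editorial sketch of the augmented-primal / tape contract that
// CreateAugmentedPrimal sets up above: the augmented forward pass returns
// the primal result plus a tape of the values the reverse pass will need
// (extracted here via extractvalue and stashed with cacheForReverse); the
// reverse pass later consumes that tape together with the incoming adjoint.
// Hand-written analogue for f(x, y) = x * y; every name is hypothetical.
struct MulTape {
  double x, y;
};

struct Augmented {
  double ret;
  MulTape tape;
};

Augmented augmented_mul(double x, double y) {
  return {x * y, {x, y}}; // stash the operands the reverse pass needs
}

struct MulGrad {
  double dx, dy;
};

MulGrad gradient_mul(MulTape tape, double dret) {
  return {dret * tape.y, dret * tape.x}; // adjoints of x and y
}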
FT->getReturnType() - : cast(FT->getReturnType()) - ->getElementType(tval), - 1, "tapeArg"); - } - tape = gutils->cacheForReverse( - BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ)); - } - } - - if (subretused) { - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, Mode, oldUnreachable) && - !gutils->unnecessaryIntermediates.count(&call)) { - - if (!isMemFreeLibMFunction(getFuncNameFromCall(&call), &ID)) { - -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - auto idx = getIndex(&call, CacheType::Self, BuilderZ); - if (idx == IndexMappingError) { - std::string str; - raw_string_ostream ss(str); - ss << "Failed to compute consistent cache index for operation: " - << call << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(&call), - ErrorType::InternalError, nullptr, nullptr, - nullptr); - } else { - EmitFailure("GetIndexError", call.getDebugLoc(), &call, - ss.str()); - } - } else { - if (Mode == DerivativeMode::ReverseModeCombined) - cachereplace = newCall; - else - cachereplace = BuilderZ.CreatePHI( - call.getType(), 1, call.getName() + "_tmpcacheB"); - cachereplace = - gutils->cacheForReverse(BuilderZ, cachereplace, idx); - } - } - } else { -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - auto pn = BuilderZ.CreatePHI( - call.getType(), 1, (call.getName() + "_replacementE").str()); - gutils->fictiousPHIs[pn] = &call; - cachereplace = pn; - } - } else { - // TODO move right after newCall for the insertion point of BuilderZ - - BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode()); - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - } - - auto ifound = gutils->invertedPointers.find(&call); - if (ifound != gutils->invertedPointers.end()) { - auto placeholder = cast(&*ifound->second); - - bool subcheck = (subretType == DIFFE_TYPE::DUP_ARG || - subretType == DIFFE_TYPE::DUP_NONEED); - - //! We only need the shadow pointer for non-forward Mode if it is used - //! in a non return setting - bool hasNonReturnUse = false; - for (auto use : call.users()) { - if (Mode == DerivativeMode::ReverseModePrimal || - !isa(use)) { - hasNonReturnUse = true; - } - } - - if (subcheck && hasNonReturnUse) { - - Value *newip = nullptr; - if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModePrimal) { - - if (!differetIdx) { - std::string str; - raw_string_ostream ss(str); - ss << "Did not have return index set when differentiating " - "function\n"; - ss << " call" << call << "\n"; - ss << " augmentcall" << *augmentcall << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(&call), - ErrorType::InternalError, nullptr, nullptr, - nullptr); - } else { - EmitFailure("GetIndexError", call.getDebugLoc(), &call, - ss.str()); - } - placeholder->replaceAllUsesWith( - UndefValue::get(placeholder->getType())); - if (placeholder == &*BuilderZ.GetInsertPoint()) { - BuilderZ.SetInsertPoint(placeholder->getNextNode()); - } - gutils->erase(placeholder); - } else { - auto drval = *differetIdx; - newip = (drval < 0) - ? 
augmentcall - : BuilderZ.CreateExtractValue(augmentcall, - {(unsigned)drval}, - call.getName() + "'ac"); - assert(newip->getType() == placeholder->getType()); - placeholder->replaceAllUsesWith(newip); - if (placeholder == &*BuilderZ.GetInsertPoint()) { - BuilderZ.SetInsertPoint(placeholder->getNextNode()); - } - gutils->erase(placeholder); - } - } else { - newip = placeholder; - } - - newip = gutils->cacheForReverse( - BuilderZ, newip, getIndex(&call, CacheType::Shadow, BuilderZ)); - - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&call, InvertedPointerVH(gutils, newip))); - } else { - gutils->invertedPointers.erase(ifound); - if (placeholder == &*BuilderZ.GetInsertPoint()) { - BuilderZ.SetInsertPoint(placeholder->getNextNode()); - } - gutils->erase(placeholder); - } - } - - if (fnandtapetype && fnandtapetype->tapeType && - (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) && - shouldFree()) { - assert(tape); - auto tapep = BuilderZ.CreatePointerCast( - tape, PointerType::get( - fnandtapetype->tapeType, - cast(tape->getType())->getAddressSpace())); - auto truetape = - BuilderZ.CreateLoad(fnandtapetype->tapeType, tapep, "tapeld"); - truetape->setMetadata("enzyme_mustcache", - MDNode::get(truetape->getContext(), {})); - - CreateDealloc(BuilderZ, tape); - tape = truetape; - } - } else { - auto ifound = gutils->invertedPointers.find(&call); - if (ifound != gutils->invertedPointers.end()) { - auto placeholder = cast(&*ifound->second); - gutils->invertedPointers.erase(ifound); - gutils->erase(placeholder); - } - if (/*!topLevel*/ Mode != DerivativeMode::ReverseModeCombined && - subretused && !call.doesNotAccessMemory()) { - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, Mode, oldUnreachable) && - !gutils->unnecessaryIntermediates.count(&call)) { - assert(!replaceFunction); -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - cachereplace = BuilderZ.CreatePHI(call.getType(), 1, - call.getName() + "_cachereplace2"); - cachereplace = gutils->cacheForReverse( - BuilderZ, cachereplace, - getIndex(&call, CacheType::Self, BuilderZ)); - } else { -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - auto pn = BuilderZ.CreatePHI(call.getType(), 1, - call.getName() + "_replacementC"); - gutils->fictiousPHIs[pn] = &call; - cachereplace = pn; - } - } - - if (!subretused && !replaceFunction) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - - // Note here down only contains the reverse bits - if (Mode == DerivativeMode::ReverseModePrimal) { - return; - } - - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - Value *newcalled = nullptr; - FunctionType *FT = nullptr; - - DerivativeMode subMode = (replaceFunction || !modifyPrimal) - ? 
DerivativeMode::ReverseModeCombined - : DerivativeMode::ReverseModeGradient; - if (called) { - if (Mode == DerivativeMode::ReverseModeGradient && subdata) { - for (size_t i = 0; i < argsInverted.size(); i++) { - if (subdata->constant_args[i] == argsInverted[i]) - continue; - assert(subdata->constant_args[i] == DIFFE_TYPE::DUP_ARG); - assert(argsInverted[i] == DIFFE_TYPE::DUP_NONEED); - argsInverted[i] = DIFFE_TYPE::DUP_ARG; - } - } - - newcalled = gutils->Logic.CreatePrimalAndGradient( - RequestContext(&call, &Builder2), - (ReverseCacheKey){ - .todiff = cast(called), - .retType = subretType, - .constant_args = argsInverted, - .subsequent_calls_may_write = subsequent_calls_may_write, - .overwritten_args = overwritten_args, - .returnUsed = replaceFunction && subretused, - .shadowReturnUsed = shadowReturnUsed && replaceFunction, - .mode = subMode, - .width = gutils->getWidth(), - .freeMemory = true, - .AtomicAdd = gutils->AtomicAdd, - .additionalType = tape ? tape->getType() : nullptr, - .forceAnonymousTape = false, - .typeInfo = nextTypeInfo, - .runtimeActivity = gutils->runtimeActivity}, - TR.analyzer->interprocedural, subdata); - if (!newcalled) - return; - FT = cast(newcalled)->getFunctionType(); - } else { - - assert(subMode != DerivativeMode::ReverseModeCombined); - - auto callval = call.getCalledOperand(); - - if (gutils->isConstantValue(callval)) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << *gutils->oldFunc << "\n"; - ss << "in Mode: " << to_string(Mode) << "\n"; - ss << " orig: " << call << " callval: " << *callval << "\n"; - ss << " constant function being called, but active call instruction\n"; - auto val = EmitNoDerivativeError(ss.str(), call, gutils, Builder2); - if (val) - newcalled = val; - else - newcalled = - UndefValue::get(gutils->getShadowType(callval->getType())); - } else { - newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2); - } - - auto ft = call.getFunctionType(); - - auto res = - getDefaultFunctionTypeForGradient(ft, /*subretType*/ subretType); - // TODO Note there is empty tape added here, replace with generic - res.first.push_back(getInt8PtrTy(newcalled->getContext())); - FT = FunctionType::get( - StructType::get(newcalled->getContext(), res.second), res.first, - ft->isVarArg()); - auto fptype = PointerType::getUnqual(FT); - newcalled = - Builder2.CreatePointerCast(newcalled, PointerType::getUnqual(fptype)); - newcalled = Builder2.CreateLoad( - fptype, Builder2.CreateConstGEP1_64(fptype, newcalled, 1)); - } - - if (subretType == DIFFE_TYPE::OUT_DIFF) { - args.push_back(diffe(&call, Builder2)); - } - - if (tape) { - auto ntape = gutils->lookupM(tape, Builder2); - assert(ntape); - assert(ntape->getType()); - args.push_back(ntape); - } - - assert(newcalled); - assert(FT); - - if (false) { - badfn:; - auto NC = dyn_cast(newcalled); - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - if (NC) - llvm::errs() << " trying to call " << NC->getName() << " " << *FT - << "\n"; - else - llvm::errs() << " trying to call " << *newcalled << " " << *FT << "\n"; - - for (unsigned i = 0; i < args.size(); ++i) { - assert(args[i]); - assert(args[i]->getType()); - llvm::errs() << "args[" << i << "] = " << *args[i] - << " FT:" << *FT->getParamType(i) << "\n"; - } - assert(0 && "calling with wrong number of arguments"); - exit(1); - } - - if (args.size() != FT->getNumParams()) - goto badfn; - - for (unsigned i = 0; i < args.size(); ++i) { - if (args[i]->getType() == FT->getParamType(i)) - continue; - else if 
(!call.getCalledFunction()) - args[i] = Builder2.CreateBitCast(args[i], FT->getParamType(i)); - else - goto badfn; - } - - CallInst *diffes = - Builder2.CreateCall(FT, newcalled, args, - gutils->getInvertedBundles( - &call, BundleTypes, Builder2, /*lookup*/ true)); - diffes->setCallingConv(call.getCallingConv()); - diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - - for (auto pair : gradByVal) { - diffes->addParamAttr(pair.first, Attribute::getWithByValType( - diffes->getContext(), pair.second)); - } - - for (auto &pair : structAttrs) { - for (auto val : pair.second) - diffes->addParamAttr(pair.first, val); - } - - unsigned structidx = 0; - if (replaceFunction) { - if (subretused) - structidx++; - if (shadowReturnUsed) - structidx++; - } - - for (unsigned i = 0; i < call.arg_size(); ++i) { - if (argsInverted[i] == DIFFE_TYPE::OUT_DIFF) { - Value *diffeadd = Builder2.CreateExtractValue(diffes, {structidx}); - ++structidx; - - if (!gutils->isConstantValue(call.getArgOperand(i))) { - size_t size = 1; - if (call.getArgOperand(i)->getType()->isSized()) - size = (gutils->newFunc->getParent() - ->getDataLayout() - .getTypeSizeInBits(call.getArgOperand(i)->getType()) + - 7) / - 8; - - addToDiffe(call.getArgOperand(i), diffeadd, Builder2, - TR.addingType(size, call.getArgOperand(i))); - } - } - } - - if (diffes->getType()->isVoidTy()) { - if (structidx != 0) { - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << "diffes: " << *diffes << " structidx=" << structidx - << " subretused=" << subretused - << " shadowReturnUsed=" << shadowReturnUsed << "\n"; - } - assert(structidx == 0); - } else { - assert(cast(diffes->getType())->getNumElements() == - structidx); - } - - if (subretType == DIFFE_TYPE::OUT_DIFF) - setDiffe(&call, - Constant::getNullValue(gutils->getShadowType(call.getType())), - Builder2); - - if (replaceFunction) { - - // if a function is replaced for joint forward/reverse, handle inverted - // pointers - auto ifound = gutils->invertedPointers.find(&call); - if (ifound != gutils->invertedPointers.end()) { - auto placeholder = cast(&*ifound->second); - gutils->invertedPointers.erase(ifound); - if (shadowReturnUsed) { - dumpMap(gutils->invertedPointers); - auto dretval = cast( - Builder2.CreateExtractValue(diffes, {subretused ? 1U : 0U})); - /* todo handle this case later */ - assert(!subretused); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&call, InvertedPointerVH(gutils, dretval))); - } - gutils->erase(placeholder); - } - - Instruction *retval = nullptr; - - if (subretused) { - retval = cast(Builder2.CreateExtractValue(diffes, {0})); - if (retval) { - gutils->replaceAndRemoveUnwrapCacheFor(newCall, retval); - } - gutils->replaceAWithB(newCall, retval, /*storeInCache*/ true); - } else { - eraseIfUnused(call, /*erase*/ false, /*check*/ false); - } - - for (auto a : postCreate) { - a->moveBefore(*Builder2.GetInsertBlock(), Builder2.GetInsertPoint()); - } - - gutils->originalToNewFn[&call] = retval ? retval : diffes; - gutils->newToOriginalFn.erase(newCall); - gutils->newToOriginalFn[retval ? 
retval : diffes] = &call; - - gutils->erase(newCall); - - return; - } - - if (cachereplace) { - if (subretused) { - Value *dcall = nullptr; - assert(cachereplace->getType() == call.getType()); - assert(dcall == nullptr); - dcall = cachereplace; - assert(dcall); - - if (!gutils->isConstantValue(&call)) { - gutils->originalToNewFn[&call] = dcall; - gutils->newToOriginalFn.erase(newCall); - gutils->newToOriginalFn[dcall] = &call; - if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { - } else { - ((DiffeGradientUtils *)gutils)->differentials[dcall] = - ((DiffeGradientUtils *)gutils)->differentials[newCall]; - ((DiffeGradientUtils *)gutils)->differentials.erase(newCall); - } - } - assert(dcall->getType() == call.getType()); - newCall->replaceAllUsesWith(dcall); - if (isa(dcall) && !isa(dcall)) { - cast(dcall)->takeName(&call); - } - gutils->erase(newCall); - } else { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - if (augmentcall) { - gutils->originalToNewFn[&call] = augmentcall; - gutils->newToOriginalFn.erase(newCall); - gutils->newToOriginalFn[augmentcall] = &call; - } - } - } - return; - } - - void handleMPI(llvm::CallInst &call, llvm::Function *called, - llvm::StringRef funcName); - - bool handleKnownCallDerivatives(llvm::CallInst &call, llvm::Function *called, - llvm::StringRef funcName, - bool subsequent_calls_may_write, - const std::vector &overwritten_args, - llvm::CallInst *const newCall); - - // Return - void visitCallInst(llvm::CallInst &call) { - using namespace llvm; - - // When compiling Enzyme against standard LLVM, and not Intel's - // modified version of LLVM, the intrinsic `llvm.intel.subscript` is - // not fully understood by LLVM. One of the results of this is that the - // visitor dispatches to visitCallInst, rather than visitIntrinsicInst, when - // presented with the intrinsic - hence why we are handling it here. - if (startsWith(getFuncNameFromCall(&call), ("llvm.intel.subscript"))) { - assert(isa(call)); - visitIntrinsicInst(cast(call)); - return; - } - - CallInst *const newCall = cast(gutils->getNewFromOriginal(&call)); - IRBuilder<> BuilderZ(newCall); - BuilderZ.setFastMathFlags(getFast()); - - if (overwritten_args_map.find(&call) == overwritten_args_map.end() && - Mode != DerivativeMode::ForwardMode && - Mode != DerivativeMode::ForwardModeError) { - llvm::errs() << " call: " << call << "\n"; - for (auto &pair : overwritten_args_map) { - llvm::errs() << " + " << *pair.first << "\n"; - } - } - - assert(overwritten_args_map.find(&call) != overwritten_args_map.end() || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError); - const bool subsequent_calls_may_write = - (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) - ? false - : overwritten_args_map.find(&call)->second.first; - const std::vector &overwritten_args = - (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) - ? 
std::vector() - : overwritten_args_map.find(&call)->second.second; - - auto called = getFunctionFromCall(&call); - StringRef funcName = getFuncNameFromCall(&call); - - bool subretused = false; - bool shadowReturnUsed = false; - auto smode = Mode; - if (smode == DerivativeMode::ReverseModeGradient) - smode = DerivativeMode::ReverseModePrimal; - DIFFE_TYPE subretType = gutils->getReturnDiffeType( - &call, &subretused, &shadowReturnUsed, smode); - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - auto found = customFwdCallHandlers.find(funcName); - if (found != customFwdCallHandlers.end()) { - Value *invertedReturn = nullptr; - auto ifound = gutils->invertedPointers.find(&call); - if (ifound != gutils->invertedPointers.end()) { - invertedReturn = cast(&*ifound->second); - } - - Value *normalReturn = subretused ? newCall : nullptr; - - bool noMod = found->second(BuilderZ, &call, *gutils, normalReturn, - invertedReturn); - if (noMod) { - if (subretused) - assert(normalReturn == newCall); - eraseIfUnused(call); - } - - ifound = gutils->invertedPointers.find(&call); - if (ifound != gutils->invertedPointers.end()) { - auto placeholder = cast(&*ifound->second); - if (invertedReturn && invertedReturn != placeholder) { - if (invertedReturn->getType() != - gutils->getShadowType(call.getType())) { - llvm::errs() << " o: " << call << "\n"; - llvm::errs() << " ot: " << *call.getType() << "\n"; - llvm::errs() << " ir: " << *invertedReturn << "\n"; - llvm::errs() << " irt: " << *invertedReturn->getType() << "\n"; - llvm::errs() << " p: " << *placeholder << "\n"; - llvm::errs() << " PT: " << *placeholder->getType() << "\n"; - llvm::errs() << " newCall: " << *newCall << "\n"; - llvm::errs() << " newCallT: " << *newCall->getType() << "\n"; - } - assert(invertedReturn->getType() == - gutils->getShadowType(call.getType())); - placeholder->replaceAllUsesWith(invertedReturn); - gutils->erase(placeholder); - gutils->invertedPointers.insert( - std::make_pair((const Value *)&call, - InvertedPointerVH(gutils, invertedReturn))); - } else { - gutils->invertedPointers.erase(&call); - gutils->erase(placeholder); - } - } - - if (normalReturn && normalReturn != newCall) { - assert(normalReturn->getType() == newCall->getType()); - gutils->replaceAWithB(newCall, normalReturn); - gutils->erase(newCall); - } - return; - } - } - - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient) { - auto found = customCallHandlers.find(funcName); - if (found != customCallHandlers.end()) { - IRBuilder<> Builder2(&call); - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) - getReverseBuilder(Builder2); - - Value *invertedReturn = nullptr; - auto ifound = gutils->invertedPointers.find(&call); - PHINode *placeholder = nullptr; - if (ifound != gutils->invertedPointers.end()) { - placeholder = cast(&*ifound->second); - if (shadowReturnUsed) - invertedReturn = placeholder; - } - - Value *normalReturn = subretused ? 
newCall : nullptr; - - Value *tape = nullptr; - - Type *tapeType = nullptr; - - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - bool noMod = found->second.first(BuilderZ, &call, *gutils, - normalReturn, invertedReturn, tape); - if (noMod) { - if (subretused) - assert(normalReturn == newCall); - eraseIfUnused(call); - } - if (tape) { - tapeType = tape->getType(); - gutils->cacheForReverse(BuilderZ, tape, - getIndex(&call, CacheType::Tape, BuilderZ)); - } - if (Mode == DerivativeMode::ReverseModePrimal) { - assert(augmentedReturn); - auto subaugmentations = - (std::map - *)&augmentedReturn->subaugmentations; - insert_or_assign2( - *subaugmentations, &call, (AugmentedReturn *)tapeType); - } - } - - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - if (Mode == DerivativeMode::ReverseModeGradient && - augmentedReturn->tapeIndices.find( - std::make_pair(&call, CacheType::Tape)) != - augmentedReturn->tapeIndices.end()) { - assert(augmentedReturn); - auto subaugmentations = - (std::map - *)&augmentedReturn->subaugmentations; - auto fd = subaugmentations->find(&call); - assert(fd != subaugmentations->end()); - // Note we are using the storage space here to persist - // the LLVM type, as storing a new augmentedReturn has issues - // regarding persisting the data structure, and when it will - // be freed, since it will no longer live in the map in - // EnzymeLogic. - tapeType = (llvm::Type *)fd->second; - -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - tape = BuilderZ.CreatePHI(tapeType, 0); - tape = gutils->cacheForReverse( - BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ), - /*ignoreType*/ true); - } - if (tape) - tape = gutils->lookupM(tape, Builder2); - found->second.second(Builder2, &call, *(DiffeGradientUtils *)gutils, - tape); - } - - if (placeholder) { - if (!shadowReturnUsed) { - gutils->invertedPointers.erase(&call); - gutils->erase(placeholder); - } else { - if (invertedReturn && invertedReturn != placeholder) { - if (invertedReturn->getType() != - gutils->getShadowType(call.getType())) { - llvm::errs() << " o: " << call << "\n"; - llvm::errs() << " ot: " << *call.getType() << "\n"; - llvm::errs() << " ir: " << *invertedReturn << "\n"; - llvm::errs() << " irt: " << *invertedReturn->getType() << "\n"; - llvm::errs() << " p: " << *placeholder << "\n"; - llvm::errs() << " PT: " << *placeholder->getType() << "\n"; - llvm::errs() << " newCall: " << *newCall << "\n"; - llvm::errs() << " newCallT: " << *newCall->getType() << "\n"; - } - assert(invertedReturn->getType() == - gutils->getShadowType(call.getType())); - placeholder->replaceAllUsesWith(invertedReturn); - gutils->erase(placeholder); - invertedReturn = gutils->cacheForReverse( - BuilderZ, invertedReturn, - getIndex(&call, CacheType::Shadow, BuilderZ)); - } else { - auto idx = getIndex(&call, CacheType::Shadow, BuilderZ); - invertedReturn = - gutils->cacheForReverse(BuilderZ, placeholder, idx); - if (idx == IndexMappingError) { - if (placeholder->getType() != invertedReturn->getType()) - llvm::errs() << " place: " << *placeholder - << " invRet: " << *invertedReturn; - placeholder->replaceAllUsesWith(invertedReturn); - gutils->erase(placeholder); - } - } - - gutils->invertedPointers.insert( - std::make_pair((const Value *)&call, - InvertedPointerVH(gutils, invertedReturn))); - } - } - - bool primalNeededInReverse; - - if 
(gutils->knownRecomputeHeuristic.count(&call)) { - primalNeededInReverse = !gutils->knownRecomputeHeuristic[&call]; - } else { - std::map Seen; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, Mode, Seen, oldUnreachable); - } - if (subretused && primalNeededInReverse) { - if (normalReturn != newCall) { - assert(normalReturn->getType() == newCall->getType()); - gutils->replaceAWithB(newCall, normalReturn); - BuilderZ.SetInsertPoint(newCall->getNextNode()); - gutils->erase(newCall); - } - normalReturn = gutils->cacheForReverse( - BuilderZ, normalReturn, - getIndex(&call, CacheType::Self, BuilderZ)); - } else { - if (normalReturn && normalReturn != newCall) { - assert(normalReturn->getType() == newCall->getType()); - assert(Mode != DerivativeMode::ReverseModeGradient); - gutils->replaceAWithB(newCall, normalReturn); - BuilderZ.SetInsertPoint(newCall->getNextNode()); - gutils->erase(newCall); - } else if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - return; - } - } - - if (called) { - if (funcName == "__kmpc_fork_call") { - visitOMPCall(call); - return; - } - } - - if (handleKnownCallDerivatives(call, called, funcName, - subsequent_calls_may_write, overwritten_args, - newCall)) - return; - - bool useConstantFallback = - DifferentialUseAnalysis::callShouldNotUseDerivative(gutils, call); - if (!useConstantFallback) { - if (gutils->isConstantInstruction(&call) && - gutils->isConstantValue(&call)) { - EmitWarning("ConstnatFallback", call, - "Call was deduced inactive but still doing differential " - "rewrite as it may escape an allocation", - call); - } - } - if (useConstantFallback) { - if (!gutils->isConstantValue(&call)) { - auto found = gutils->invertedPointers.find(&call); - if (found != gutils->invertedPointers.end()) { - PHINode *placeholder = cast(&*found->second); - gutils->invertedPointers.erase(found); - gutils->erase(placeholder); - } - } - bool noFree = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - noFree |= call.hasFnAttr(Attribute::NoFree); - if (!noFree && called) { - noFree |= called->hasFnAttribute(Attribute::NoFree); - } - - std::map CacheResults; - for (auto pair : gutils->knownRecomputeHeuristic) { - if (!pair.second || gutils->unnecessaryIntermediates.count( - cast(pair.first))) { - CacheResults[UsageKey(pair.first, QueryType::Primal)] = false; - } - } - - if (!noFree && !EnzymeGlobalActivity) { - bool mayActiveFree = false; - for (unsigned i = 0; i < call.arg_size(); ++i) { - Value *a = call.getOperand(i); - - if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType())) - continue; - // if could not be a pointer, it cannot be freed - if (!TR.query(a)[{-1}].isPossiblePointer()) - continue; - // if active value, we need to do memory preservation - if (!gutils->isConstantValue(a)) { - mayActiveFree = true; - break; - } - // if used in reverse (even if just primal), need to do - // memory preservation - const auto obj = getBaseObject(a); - // If not allocation/allocainst, it is possible this aliases - // a pointer needed in the reverse pass - bool isAllocation = false; - for (auto objv = obj;;) { - if (isAllocationCall(objv, gutils->TLI)) { - isAllocation = true; - break; - } - if (auto objC = dyn_cast(objv)) - if (auto F = getFunctionFromCall(objC)) - if (!F->empty()) { - SmallPtrSet set; - 
for (auto &B : *F) { - if (auto RI = dyn_cast(B.getTerminator())) { - auto v = getBaseObject(RI->getOperand(0)); - if (isa(v)) - continue; - set.insert(v); - } - } - if (set.size() == 1) { - objv = *set.begin(); - continue; - } - } - break; - } - if (!isAllocation) { - mayActiveFree = true; - break; - } - { - auto found = gutils->knownRecomputeHeuristic.find(obj); - if (found != gutils->knownRecomputeHeuristic.end()) { - if (!found->second) { - auto CacheResults2(CacheResults); - CacheResults2.erase(UsageKey(obj, QueryType::Primal)); - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, obj, - DerivativeMode::ReverseModeGradient, - CacheResults2, oldUnreachable)) { - mayActiveFree = true; - break; - } - } - continue; - } - } - auto CacheResults2(CacheResults); - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, obj, - DerivativeMode::ReverseModeGradient, - CacheResults2, oldUnreachable)) { - mayActiveFree = true; - break; - } - } - if (!mayActiveFree) - noFree = true; - } - if (!noFree) { - auto callval = call.getCalledOperand(); - if (!isa(callval)) - callval = gutils->getNewFromOriginal(callval); - newCall->setCalledOperand(gutils->Logic.CreateNoFree( - RequestContext(&call, &BuilderZ), callval)); - } - if (gutils->knownRecomputeHeuristic.find(&call) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&call]) { - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - eraseIfUnused(call); - return; - } - } - - // If we need this value and it is illegal to recompute it (it writes or - // may load overwritten data) - // Store and reload it - if (Mode != DerivativeMode::ReverseModeCombined && - Mode != DerivativeMode::ForwardMode && - Mode != DerivativeMode::ForwardModeError && subretused && - (call.mayWriteToMemory() || - !gutils->legalRecompute(&call, ValueToValueMapTy(), nullptr))) { - if (!gutils->unnecessaryIntermediates.count(&call)) { - - std::map Seen; - bool primalNeededInReverse = false; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) { - if (pair.first == &call) { - primalNeededInReverse = true; - break; - } else { - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - } - } - if (!primalNeededInReverse) { - - auto minCutMode = (Mode == DerivativeMode::ReverseModePrimal) - ? 
DerivativeMode::ReverseModeGradient - : Mode; - primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, minCutMode, Seen, - oldUnreachable); - } - if (primalNeededInReverse) { - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - eraseIfUnused(call); - return; - } - } - // Force erasure in reverse pass, since cached if needed - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - else - eraseIfUnused(call); - return; - } - - // If this call may write to memory and is a copy (in the just reverse - // pass), erase it - // Any uses of it should be handled by the case above so it is safe to - // RAUW - if (call.mayWriteToMemory() && - (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit)) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // if call does not write memory and isn't used, we can erase it - if (!call.mayWriteToMemory() && !subretused) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - return; - } - - return recursivelyHandleSubfunction( - call, called, subsequent_calls_may_write, overwritten_args, - shadowReturnUsed, subretType, subretused); - } -}; - -#endif // ENZYME_ADJOINT_GENERATOR_H diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp deleted file mode 100644 index c64118ff6821..000000000000 --- a/enzyme/Enzyme/CApi.cpp +++ /dev/null @@ -1,2079 +0,0 @@ -//===- CApi.cpp - Enzyme API exported to C for external use -----------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. 
and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file defines various utility functions of Enzyme for access via C -// -//===----------------------------------------------------------------------===// -#include "CApi.h" -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "DiffeGradientUtils.h" -#include "DifferentialUseAnalysis.h" -#include "EnzymeLogic.h" -#include "GradientUtils.h" -#include "LibraryFuncs.h" -#if LLVM_VERSION_MAJOR >= 16 -#include "llvm/Analysis/TargetLibraryInfo.h" -#else -#include "SCEV/TargetLibraryInfo.h" -#endif -#include "TraceInterface.h" - -// #include "llvm/ADT/Triple.h" -#include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/IR/DIBuilder.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/Transforms/Utils/Cloning.h" - -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/Transforms/IPO/Attributor.h" - -#define addAttribute addAttributeAtIndex -#define removeAttribute removeAttributeAtIndex -#define getAttribute getAttributeAtIndex -#define hasAttribute hasAttributeAtIndex - -using namespace llvm; - -TargetLibraryInfo eunwrap(LLVMTargetLibraryInfoRef P) { - return TargetLibraryInfo(*reinterpret_cast(P)); -} - -EnzymeLogic &eunwrap(EnzymeLogicRef LR) { return *(EnzymeLogic *)LR; } - -TraceInterface *eunwrap(EnzymeTraceInterfaceRef Ref) { - return (TraceInterface *)Ref; -} - -TypeAnalysis &eunwrap(EnzymeTypeAnalysisRef TAR) { - return *(TypeAnalysis *)TAR; -} -AugmentedReturn *eunwrap(EnzymeAugmentedReturnPtr ARP) { - return (AugmentedReturn *)ARP; -} -EnzymeAugmentedReturnPtr ewrap(const AugmentedReturn &AR) { - return (EnzymeAugmentedReturnPtr)(&AR); -} - -ConcreteType eunwrap(CConcreteType CDT, llvm::LLVMContext &ctx) { - switch (CDT) { - case DT_Anything: - return BaseType::Anything; - case DT_Integer: - return BaseType::Integer; - case DT_Pointer: - return BaseType::Pointer; - case DT_Half: - return ConcreteType(llvm::Type::getHalfTy(ctx)); - case DT_Float: - return ConcreteType(llvm::Type::getFloatTy(ctx)); - case DT_Double: - return ConcreteType(llvm::Type::getDoubleTy(ctx)); - case DT_X86_FP80: - return ConcreteType(llvm::Type::getX86_FP80Ty(ctx)); - case DT_BFloat16: - return ConcreteType(llvm::Type::getBFloatTy(ctx)); - case DT_Unknown: - return BaseType::Unknown; - } - llvm_unreachable("Unknown concrete type to unwrap"); -} - -std::vector eunwrap(IntList IL) { - std::vector v; - for (size_t i = 0; i < IL.size; i++) { - v.push_back((int)IL.data[i]); - } - return v; -} -std::set eunwrap64(IntList IL) { - std::set v; - for (size_t i = 0; i < IL.size; i++) { - v.insert((int64_t)IL.data[i]); - } - return v; -} -TypeTree eunwrap(CTypeTreeRef CTT) { return *(TypeTree *)CTT; } - -CConcreteType ewrap(const ConcreteType &CT) { - if (auto flt = CT.isFloat()) { - if (flt->isHalfTy()) - return DT_Half; - if (flt->isFloatTy()) - return DT_Float; - if (flt->isDoubleTy()) - return DT_Double; - if (flt->isX86_FP80Ty()) - return DT_X86_FP80; - if (flt->isBFloatTy()) - return DT_BFloat16; - } else { - switch (CT.SubTypeEnum) { - case BaseType::Integer: - return DT_Integer; - case BaseType::Pointer: - return 
DT_Pointer; - case BaseType::Anything: - return DT_Anything; - case BaseType::Unknown: - return DT_Unknown; - case BaseType::Float: - llvm_unreachable("Illegal conversion of concretetype"); - } - } - llvm_unreachable("Illegal conversion of concretetype"); -} - -IntList ewrap(const std::vector &offsets) { - IntList IL; - IL.size = offsets.size(); - IL.data = new int64_t[IL.size]; - for (size_t i = 0; i < offsets.size(); i++) { - IL.data[i] = offsets[i]; - } - return IL; -} - -CTypeTreeRef ewrap(const TypeTree &TT) { - return (CTypeTreeRef)(new TypeTree(TT)); -} - -FnTypeInfo eunwrap(CFnTypeInfo CTI, llvm::Function *F) { - FnTypeInfo FTI(F); - // auto &ctx = F->getContext(); - FTI.Return = eunwrap(CTI.Return); - - size_t argnum = 0; - for (auto &arg : F->args()) { - FTI.Arguments[&arg] = eunwrap(CTI.Arguments[argnum]); - FTI.KnownValues[&arg] = eunwrap64(CTI.KnownValues[argnum]); - argnum++; - } - return FTI; -} - -extern "C" { - -void EnzymeSetCLBool(void *ptr, uint8_t val) { - auto cl = (llvm::cl::opt *)ptr; - cl->setValue((bool)val); -} - -uint8_t EnzymeGetCLBool(void *ptr) { - auto cl = (llvm::cl::opt *)ptr; - return (uint8_t)(bool)cl->getValue(); -} - -void EnzymeSetCLInteger(void *ptr, int64_t val) { - auto cl = (llvm::cl::opt *)ptr; - cl->setValue((int)val); -} - -int64_t EnzymeGetCLInteger(void *ptr) { - auto cl = (llvm::cl::opt *)ptr; - return (int64_t)cl->getValue(); -} - -EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt) { - return (EnzymeLogicRef)(new EnzymeLogic((bool)PostOpt)); -} - -EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M) { - return (EnzymeTraceInterfaceRef)(new StaticTraceInterface(unwrap(M))); -} - -EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface( - LLVMContextRef C, LLVMValueRef getTraceFunction, - LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, - LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, - LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, - LLVMValueRef insertChoiceGradientFunction, - LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, - LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, - LLVMValueRef hasChoiceFunction) { - return (EnzymeTraceInterfaceRef)(new StaticTraceInterface( - *unwrap(C), cast(unwrap(getTraceFunction)), - cast(unwrap(getChoiceFunction)), - cast(unwrap(insertCallFunction)), - cast(unwrap(insertChoiceFunction)), - cast(unwrap(insertArgumentFunction)), - cast(unwrap(insertReturnFunction)), - cast(unwrap(insertFunctionFunction)), - cast(unwrap(insertChoiceGradientFunction)), - cast(unwrap(insertArgumentGradientFunction)), - cast(unwrap(newTraceFunction)), - cast(unwrap(freeTraceFunction)), - cast(unwrap(hasCallFunction)), - cast(unwrap(hasChoiceFunction)))); -}; - -EnzymeTraceInterfaceRef -CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F) { - return (EnzymeTraceInterfaceRef)(new DynamicTraceInterface( - unwrap(interface), cast(unwrap(F)))); -} - -void ClearEnzymeLogic(EnzymeLogicRef Ref) { eunwrap(Ref).clear(); } - -void EnzymeLogicErasePreprocessedFunctions(EnzymeLogicRef Ref) { - auto &Logic = eunwrap(Ref); - for (const auto &pair : Logic.PPC.cache) - pair.second->eraseFromParent(); -} - -void FreeEnzymeLogic(EnzymeLogicRef Ref) { delete (EnzymeLogic *)Ref; } - -void FreeTraceInterface(EnzymeTraceInterfaceRef Ref) { - delete (TraceInterface *)Ref; -} - -EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log, - char **customRuleNames, - CustomRuleType *customRules, - size_t 
numRules) { - TypeAnalysis *TA = new TypeAnalysis(((EnzymeLogic *)Log)->PPC.FAM); - for (size_t i = 0; i < numRules; i++) { - CustomRuleType rule = customRules[i]; - TA->CustomRules[customRuleNames[i]] = - [=](int direction, TypeTree &returnTree, ArrayRef argTrees, - ArrayRef> knownValues, CallBase *call, - TypeAnalyzer *TA) -> uint8_t { - CTypeTreeRef creturnTree = (CTypeTreeRef)(&returnTree); - CTypeTreeRef *cargs = new CTypeTreeRef[argTrees.size()]; - IntList *kvs = new IntList[argTrees.size()]; - for (size_t i = 0; i < argTrees.size(); ++i) { - cargs[i] = (CTypeTreeRef)(&(argTrees[i])); - kvs[i].size = knownValues[i].size(); - kvs[i].data = new int64_t[kvs[i].size]; - size_t j = 0; - for (auto val : knownValues[i]) { - kvs[i].data[j] = val; - j++; - } - } - uint8_t result = rule(direction, creturnTree, cargs, kvs, argTrees.size(), - wrap(call), TA); - delete[] cargs; - for (size_t i = 0; i < argTrees.size(); ++i) { - delete[] kvs[i].data; - } - delete[] kvs; - return result; - }; - } - return (EnzymeTypeAnalysisRef)TA; -} - -void ClearTypeAnalysis(EnzymeTypeAnalysisRef TAR) { eunwrap(TAR).clear(); } - -void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) { - TypeAnalysis *TA = (TypeAnalysis *)TAR; - delete TA; -} - -void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, - LLVMValueRef F) { - FnTypeInfo FTI(eunwrap(CTI, cast(unwrap(F)))); - return (void *)((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; -} - -void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) { - return (void *)&G->TR.analyzer; -} - -void EnzymeGradientUtilsErase(GradientUtils *G, LLVMValueRef I) { - return G->erase(cast(unwrap(I))); -} -void EnzymeGradientUtilsEraseWithPlaceholder(GradientUtils *G, LLVMValueRef I, - LLVMValueRef orig, uint8_t erase) { - return G->eraseWithPlaceholder(cast(unwrap(I)), - cast(unwrap(orig)), - "_replacementABI", erase != 0); -} - -void EnzymeGradientUtilsReplaceAWithB(GradientUtils *G, LLVMValueRef A, - LLVMValueRef B) { - return G->replaceAWithB(unwrap(A), unwrap(B)); -} - -void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, - CustomShadowFree FHandle) { - shadowHandlers[Name] = [=](IRBuilder<> &B, CallInst *CI, - ArrayRef Args, - GradientUtils *gutils) -> llvm::Value * { - SmallVector refs; - for (auto a : Args) - refs.push_back(wrap(a)); - return unwrap( - AHandle(wrap(&B), wrap(CI), Args.size(), refs.data(), gutils)); - }; - if (FHandle) - shadowErasers[Name] = [=](IRBuilder<> &B, - Value *ToFree) -> llvm::CallInst * { - return cast_or_null(unwrap(FHandle(wrap(&B), wrap(ToFree)))); - }; -} - -void EnzymeRegisterCallHandler(const char *Name, - CustomAugmentedFunctionForward FwdHandle, - CustomFunctionReverse RevHandle) { - auto &pair = customCallHandlers[Name]; - pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, - Value *&normalReturn, Value *&shadowReturn, - Value *&tape) -> bool { - LLVMValueRef normalR = wrap(normalReturn); - LLVMValueRef shadowR = wrap(shadowReturn); - LLVMValueRef tapeR = wrap(tape); - uint8_t noMod = - FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR); - normalReturn = unwrap(normalR); - shadowReturn = unwrap(shadowR); - tape = unwrap(tapeR); - return noMod != 0; - }; - pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils &gutils, - Value *tape) { - RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape)); - }; -} - -void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { - auto &pair = customFwdCallHandlers[Name]; - pair = [=](IRBuilder<> &B, 
CallInst *CI, GradientUtils &gutils, - Value *&normalReturn, Value *&shadowReturn) -> bool { - LLVMValueRef normalR = wrap(normalReturn); - LLVMValueRef shadowR = wrap(shadowReturn); - uint8_t noMod = FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR); - normalReturn = unwrap(normalR); - shadowReturn = unwrap(shadowR); - return noMod != 0; - }; -} - -void EnzymeRegisterDiffUseCallHandler(char *Name, - CustomFunctionDiffUse Handle) { - auto &pair = customDiffUseHandlers[Name]; - pair = [=](const CallInst *CI, const GradientUtils *gutils, const Value *arg, - bool isshadow, DerivativeMode mode, bool &useDefault) -> bool { - uint8_t useDefaultC = 0; - uint8_t noMod = Handle(wrap(CI), gutils, wrap(arg), isshadow, - (CDerivativeMode)(mode), &useDefaultC); - useDefault = useDefaultC != 0; - return noMod != 0; - }; -} - -uint8_t EnzymeGradientUtilsGetRuntimeActivity(GradientUtils *gutils) { - return gutils->runtimeActivity; -} - -uint64_t EnzymeGradientUtilsGetWidth(GradientUtils *gutils) { - return gutils->getWidth(); -} - -LLVMTypeRef EnzymeGradientUtilsGetShadowType(GradientUtils *gutils, - LLVMTypeRef T) { - return wrap(gutils->getShadowType(unwrap(T))); -} - -LLVMTypeRef EnzymeGetShadowType(uint64_t width, LLVMTypeRef T) { - return wrap(GradientUtils::getShadowType(unwrap(T), width)); -} - -LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, - LLVMValueRef val) { - return wrap(gutils->getNewFromOriginal(unwrap(val))); -} - -CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils) { - return (CDerivativeMode)gutils->mode; -} - -CDIFFE_TYPE -EnzymeGradientUtilsGetDiffeType(GradientUtils *G, LLVMValueRef oval, - uint8_t foreignFunction) { - return (CDIFFE_TYPE)(G->getDiffeType(unwrap(oval), foreignFunction != 0)); -} - -CDIFFE_TYPE -EnzymeGradientUtilsGetReturnDiffeType(GradientUtils *G, LLVMValueRef oval, - uint8_t *needsPrimal, - uint8_t *needsShadow, - CDerivativeMode mode) { - bool needsPrimalB; - bool needsShadowB; - auto res = (CDIFFE_TYPE)(G->getReturnDiffeType( - unwrap(oval), &needsPrimalB, &needsShadowB, (DerivativeMode)mode)); - if (needsPrimal) - *needsPrimal = needsPrimalB; - if (needsShadow) - *needsShadow = needsShadowB; - return res; -} - -void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils, - LLVMValueRef val, - LLVMValueRef orig) { - return cast(unwrap(val)) - ->setDebugLoc(gutils->getNewFromOriginal( - cast(unwrap(orig))->getDebugLoc())); -} - -LLVMValueRef EnzymeInsertValue(LLVMBuilderRef B, LLVMValueRef val, - LLVMValueRef val2, unsigned *sz, int64_t length, - const char *name) { - return wrap(unwrap(B)->CreateInsertValue( - unwrap(val), unwrap(val2), ArrayRef(sz, sz + length), name)); -} - -LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils *gutils, LLVMValueRef val, - LLVMBuilderRef B) { - return wrap(gutils->lookupM(unwrap(val), *unwrap(B))); -} - -LLVMValueRef EnzymeGradientUtilsInvertPointer(GradientUtils *gutils, - LLVMValueRef val, - LLVMBuilderRef B) { - return wrap(gutils->invertPointerM(unwrap(val), *unwrap(B))); -} - -LLVMValueRef EnzymeGradientUtilsDiffe(DiffeGradientUtils *gutils, - LLVMValueRef val, LLVMBuilderRef B) { - return wrap(gutils->diffe(unwrap(val), *unwrap(B))); -} - -void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, - LLVMValueRef diffe, LLVMBuilderRef B, - LLVMTypeRef T) { - gutils->addToDiffe(unwrap(val), unwrap(diffe), *unwrap(B), unwrap(T)); -} - -void EnzymeGradientUtilsAddToInvertedPointerDiffe( - DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef 
origVal, - LLVMTypeRef addingType, unsigned start, unsigned size, LLVMValueRef origptr, - LLVMValueRef dif, LLVMBuilderRef BuilderM, unsigned align, - LLVMValueRef mask) { - MaybeAlign align2; - if (align) - align2 = MaybeAlign(align); - auto inst = cast_or_null(unwrap(orig)); - gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), unwrap(addingType), - start, size, unwrap(origptr), unwrap(dif), - *unwrap(BuilderM), align2, unwrap(mask)); -} - -void EnzymeGradientUtilsAddToInvertedPointerDiffeTT( - DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, - CTypeTreeRef vd, unsigned LoadSize, LLVMValueRef origptr, - LLVMValueRef prediff, LLVMBuilderRef BuilderM, unsigned align, - LLVMValueRef premask) { - MaybeAlign align2; - if (align) - align2 = MaybeAlign(align); - auto inst = cast_or_null(unwrap(orig)); - gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), *(TypeTree *)vd, - LoadSize, unwrap(origptr), unwrap(prediff), - *unwrap(BuilderM), align2, unwrap(premask)); -} - -void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, - LLVMValueRef diffe, LLVMBuilderRef B) { - gutils->setDiffe(unwrap(val), unwrap(diffe), *unwrap(B)); -} - -uint8_t EnzymeGradientUtilsIsConstantValue(GradientUtils *gutils, - LLVMValueRef val) { - return gutils->isConstantValue(unwrap(val)); -} - -uint8_t EnzymeGradientUtilsIsConstantInstruction(GradientUtils *gutils, - LLVMValueRef val) { - return gutils->isConstantInstruction(cast(unwrap(val))); -} - -LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) { - return wrap(gutils->inversionAllocs); -} - -uint8_t EnzymeGradientUtilsGetUncacheableArgs(GradientUtils *gutils, - LLVMValueRef orig, uint8_t *data, - uint64_t size) { - if (gutils->mode == DerivativeMode::ForwardMode || - gutils->mode == DerivativeMode::ForwardModeError) - return 0; - - if (!gutils->overwritten_args_map_ptr) - return 0; - - CallInst *call = cast(unwrap(orig)); - - assert(gutils->overwritten_args_map_ptr); - auto found = gutils->overwritten_args_map_ptr->find(call); - if (found == gutils->overwritten_args_map_ptr->end()) { - llvm::errs() << " oldFunc " << *gutils->oldFunc << "\n"; - for (auto &pair : *gutils->overwritten_args_map_ptr) { - llvm::errs() << " + " << *pair.first << "\n"; - } - llvm::errs() << " could not find call orig in overwritten_args_map_ptr " - << *call << "\n"; - } - assert(found != gutils->overwritten_args_map_ptr->end()); - - const std::vector &overwritten_args = found->second.second; - - if (size != overwritten_args.size()) { - llvm::errs() << " orig: " << *call << "\n"; - llvm::errs() << " size: " << size - << " overwritten_args.size(): " << overwritten_args.size() - << "\n"; - } - assert(size == overwritten_args.size()); - for (uint64_t i = 0; i < size; i++) { - data[i] = overwritten_args[i]; - } - return 1; -} - -CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils, - LLVMValueRef val) { - auto v = unwrap(val); - TypeTree TT = gutils->TR.query(v); - TypeTree *pTT = new TypeTree(TT); - return (CTypeTreeRef)pTT; -} - -void EnzymeGradientUtilsDumpTypeResults(GradientUtils *gutils) { - gutils->TR.dump(); -} - -void EnzymeGradientUtilsSubTransferHelper( - GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty, - uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset, - uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant, - LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile, - LLVMValueRef MTI, uint8_t allowForward, uint8_t 
shadowsLookedUp) { - auto orig = unwrap(MTI); - assert(orig); - SubTransferHelper(gutils, (DerivativeMode)mode, unwrap(secretty), - (Intrinsic::ID)intrinsic, (unsigned)dstAlign, - (unsigned)srcAlign, (unsigned)offset, (bool)dstConstant, - unwrap(shadow_dst), (bool)srcConstant, unwrap(shadow_src), - unwrap(length), unwrap(isVolatile), cast(orig), - (bool)allowForward, (bool)shadowsLookedUp); -} - -LLVMValueRef EnzymeCreateForwardDiff( - EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, - LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, - size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, - CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity, - unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, - uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, - size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented) { - SmallVector nconstant_args((DIFFE_TYPE *)constant_args, - (DIFFE_TYPE *)constant_args + - constant_args_size); - std::vector overwritten_args; - assert(overwritten_args_size == cast(unwrap(todiff))->arg_size()); - for (uint64_t i = 0; i < overwritten_args_size; i++) { - overwritten_args.push_back(_overwritten_args[i]); - } - return wrap(eunwrap(Logic).CreateForwardDiff( - RequestContext(cast_or_null(unwrap(request_req)), - unwrap(request_ip)), - cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, - eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, - runtimeActivity, width, unwrap(additionalArg), - eunwrap(typeInfo, cast(unwrap(todiff))), - subsequent_calls_may_write, overwritten_args, eunwrap(augmented))); -} -LLVMValueRef EnzymeCreatePrimalAndGradient( - EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, - LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, - size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, - uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, - unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg, - uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, - uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, - size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, - uint8_t AtomicAdd) { - std::vector nconstant_args((DIFFE_TYPE *)constant_args, - (DIFFE_TYPE *)constant_args + - constant_args_size); - std::vector overwritten_args; - assert(overwritten_args_size == cast(unwrap(todiff))->arg_size()); - for (uint64_t i = 0; i < overwritten_args_size; i++) { - overwritten_args.push_back(_overwritten_args[i]); - } - return wrap(eunwrap(Logic).CreatePrimalAndGradient( - RequestContext(cast_or_null(unwrap(request_req)), - unwrap(request_ip)), - (ReverseCacheKey){ - .todiff = cast(unwrap(todiff)), - .retType = (DIFFE_TYPE)retType, - .constant_args = nconstant_args, - .subsequent_calls_may_write = (bool)subsequent_calls_may_write, - .overwritten_args = overwritten_args, - .returnUsed = (bool)returnValue, - .shadowReturnUsed = (bool)dretUsed, - .mode = (DerivativeMode)mode, - .width = width, - .freeMemory = (bool)freeMemory, - .AtomicAdd = (bool)AtomicAdd, - .additionalType = unwrap(additionalArg), - .forceAnonymousTape = (bool)forceAnonymousTape, - .typeInfo = eunwrap(typeInfo, cast(unwrap(todiff))), - .runtimeActivity = (bool)runtimeActivity}, - eunwrap(TA), eunwrap(augmented))); -} -EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( - EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, - LLVMValueRef todiff, 
CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, - size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed, - uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, - uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, - size_t overwritten_args_size, uint8_t forceAnonymousTape, - uint8_t runtimeActivity, unsigned width, uint8_t AtomicAdd) { - - SmallVector nconstant_args((DIFFE_TYPE *)constant_args, - (DIFFE_TYPE *)constant_args + - constant_args_size); - std::vector overwritten_args; - assert(overwritten_args_size == cast(unwrap(todiff))->arg_size()); - for (uint64_t i = 0; i < overwritten_args_size; i++) { - overwritten_args.push_back(_overwritten_args[i]); - } - return ewrap(eunwrap(Logic).CreateAugmentedPrimal( - RequestContext(cast_or_null(unwrap(request_req)), - unwrap(request_ip)), - cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, - eunwrap(TA), returnUsed, shadowReturnUsed, - eunwrap(typeInfo, cast(unwrap(todiff))), - subsequent_calls_may_write, overwritten_args, forceAnonymousTape, - runtimeActivity, width, AtomicAdd)); -} - -LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req, - LLVMBuilderRef request_ip, LLVMValueRef tobatch, - unsigned width, CBATCH_TYPE *arg_types, - size_t arg_types_size, CBATCH_TYPE retType) { - - return wrap(eunwrap(Logic).CreateBatch( - RequestContext(cast_or_null(unwrap(request_req)), - unwrap(request_ip)), - cast(unwrap(tobatch)), width, - ArrayRef((BATCH_TYPE *)arg_types, - (BATCH_TYPE *)arg_types + arg_types_size), - (BATCH_TYPE)retType)); -} - -LLVMValueRef EnzymeCreateTrace( - EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, - LLVMValueRef totrace, LLVMValueRef *sample_functions, - size_t sample_functions_size, LLVMValueRef *observe_functions, - size_t observe_functions_size, const char *active_random_variables[], - size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff, - EnzymeTraceInterfaceRef interface) { - - SmallPtrSet SampleFunctions; - for (size_t i = 0; i < sample_functions_size; i++) { - SampleFunctions.insert(cast(unwrap(sample_functions[i]))); - } - - SmallPtrSet ObserveFunctions; - for (size_t i = 0; i < observe_functions_size; i++) { - ObserveFunctions.insert(cast(unwrap(observe_functions[i]))); - } - - StringSet<> ActiveRandomVariables; - for (size_t i = 0; i < active_random_variables_size; i++) { - ActiveRandomVariables.insert(active_random_variables[i]); - } - - return wrap(eunwrap(Logic).CreateTrace( - RequestContext(cast_or_null(unwrap(request_req)), - unwrap(request_ip)), - cast(unwrap(totrace)), SampleFunctions, ObserveFunctions, - ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff, - eunwrap(interface))); -} - -LLVMValueRef -EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret) { - auto AR = (AugmentedReturn *)ret; - return wrap(AR->fn); -} - -LLVMTypeRef -EnzymeExtractUnderlyingTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret) { - auto AR = (AugmentedReturn *)ret; - return wrap(AR->tapeType); -} - -LLVMTypeRef -EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret) { - auto AR = (AugmentedReturn *)ret; - auto found = AR->returns.find(AugmentedStruct::Tape); - if (found == AR->returns.end()) { - return wrap((Type *)nullptr); - } - if (found->second == -1) { - return wrap(AR->fn->getReturnType()); - } - return wrap( - cast(AR->fn->getReturnType())->getTypeAtIndex(found->second)); -} -void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, - uint8_t *existed, size_t len) { 
- assert(len == 3); - auto AR = (AugmentedReturn *)ret; - AugmentedStruct todo[] = {AugmentedStruct::Tape, AugmentedStruct::Return, - AugmentedStruct::DifferentialReturn}; - for (size_t i = 0; i < len; i++) { - auto found = AR->returns.find(todo[i]); - if (found != AR->returns.end()) { - existed[i] = true; - data[i] = (int64_t)found->second; - } else { - existed[i] = false; - } - } -} - -static MDNode *extractMDNode(MetadataAsValue *MAV) { - Metadata *MD = MAV->getMetadata(); - assert((isa(MD) || isa(MD)) && - "Expected a metadata node or a canonicalized constant"); - - if (MDNode *N = dyn_cast(MD)) - return N; - - return MDNode::get(MAV->getContext(), MD); -} - -CTypeTreeRef EnzymeTypeTreeFromMD(LLVMValueRef Val) { - TypeTree *Ret = new TypeTree(); - MDNode *N = Val ? extractMDNode(unwrap(Val)) : nullptr; - Ret->insertFromMD(N); - return (CTypeTreeRef)N; -} - -LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx) { - auto MD = ((TypeTree *)CTR)->toMD(*unwrap(ctx)); - return wrap(MetadataAsValue::get(MD->getContext(), MD)); -} - -CTypeTreeRef EnzymeNewTypeTree() { return (CTypeTreeRef)(new TypeTree()); } -CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx) { - return (CTypeTreeRef)(new TypeTree(eunwrap(CT, *unwrap(ctx)))); -} -CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef CTR) { - return (CTypeTreeRef)(new TypeTree(*(TypeTree *)(CTR))); -} -void EnzymeFreeTypeTree(CTypeTreeRef CTT) { delete (TypeTree *)CTT; } -uint8_t EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src) { - return *(TypeTree *)dst = *(TypeTree *)src; -} -uint8_t EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src) { - return ((TypeTree *)dst)->orIn(*(TypeTree *)src, /*PointerIntSame*/ false); -} -uint8_t EnzymeCheckedMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src, - uint8_t *legalP) { - bool legal = true; - bool res = - ((TypeTree *)dst) - ->checkedOrIn(*(TypeTree *)src, /*PointerIntSame*/ false, legal); - *legalP = legal; - return res; -} - -void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x) { - // TODO only inst - *(TypeTree *)CTT = ((TypeTree *)CTT)->Only(x, nullptr); -} -void EnzymeTypeTreeData0Eq(CTypeTreeRef CTT) { - *(TypeTree *)CTT = ((TypeTree *)CTT)->Data0(); -} - -void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl) { - *(TypeTree *)CTT = ((TypeTree *)CTT)->Lookup(size, DataLayout(dl)); -} -void EnzymeTypeTreeCanonicalizeInPlace(CTypeTreeRef CTT, int64_t size, - const char *dl) { - ((TypeTree *)CTT)->CanonicalizeInPlace(size, DataLayout(dl)); -} - -CConcreteType EnzymeTypeTreeInner0(CTypeTreeRef CTT) { - return ewrap(((TypeTree *)CTT)->Inner0()); -} - -void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout, - int64_t offset, int64_t maxSize, - uint64_t addOffset) { - DataLayout DL(datalayout); - *(TypeTree *)CTT = - ((TypeTree *)CTT)->ShiftIndices(DL, offset, maxSize, addOffset); -} -const char *EnzymeTypeTreeToString(CTypeTreeRef src) { - std::string tmp = ((TypeTree *)src)->str(); - char *cstr = new char[tmp.length() + 1]; - std::strcpy(cstr, tmp.c_str()); - - return cstr; -} - -// TODO deprecated -void EnzymeTypeTreeToStringFree(const char *cstr) { delete[] cstr; } - -const char *EnzymeTypeAnalyzerToString(void *src) { - auto TA = (TypeAnalyzer *)src; - std::string str; - raw_string_ostream ss(str); - TA->dump(ss); - ss.str(); - char *cstr = new char[str.length() + 1]; - std::strcpy(cstr, str.c_str()); - return cstr; -} - -const char *EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils, - void *src) { - 
std::string str; - raw_string_ostream ss(str); - for (auto z : gutils->invertedPointers) { - ss << "available inversion for " << *z.first << " of " << *z.second << "\n"; - } - ss.str(); - char *cstr = new char[str.length() + 1]; - std::strcpy(cstr, str.c_str()); - return cstr; -} - -LLVMValueRef EnzymeGradientUtilsCallWithInvertedBundles( - GradientUtils *gutils, LLVMValueRef func, LLVMTypeRef funcTy, - LLVMValueRef *args_vr, uint64_t args_size, LLVMValueRef orig_vr, - CValueType *valTys, uint64_t valTys_size, LLVMBuilderRef B, - uint8_t lookup) { - auto orig = cast(unwrap(orig_vr)); - - ArrayRef ar((ValueType *)valTys, valTys_size); - - IRBuilder<> &BR = *unwrap(B); - - auto Defs = gutils->getInvertedBundles(orig, ar, BR, lookup != 0); - - SmallVector args; - for (size_t i = 0; i < args_size; i++) { - args.push_back(unwrap(args_vr[i])); - } - - auto callval = unwrap(func); - - auto res = - BR.CreateCall(cast(unwrap(funcTy)), callval, args, Defs); - return wrap(res); -} - -void EnzymeStringFree(const char *cstr) { delete[] cstr; } - -void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2, - LLVMBuilderRef B) { - Instruction *I1 = cast(unwrap(inst1)); - Instruction *I2 = cast(unwrap(inst2)); - if (I1 != I2) { - if (B != nullptr) { - IRBuilder<> &BR = *unwrap(B); - if (I1->getIterator() == BR.GetInsertPoint()) { - if (I2->getNextNode() == nullptr) - BR.SetInsertPoint(I1->getParent()); - else - BR.SetInsertPoint(I1->getNextNode()); - } - } - I1->moveBefore(I2); - } -} - -void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val) { - MDNode *N = Val ? extractMDNode(unwrap(Val)) : nullptr; - Value *V = unwrap(Inst); - if (auto I = dyn_cast(V)) - I->setMetadata(Kind, N); - else - cast(V)->setMetadata(Kind, N); -} - -LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind) { - auto *I = unwrap(Inst); - assert(I && "Expected instruction"); - if (auto *MD = I->getMetadata(Kind)) - return wrap(MetadataAsValue::get(I->getContext(), MD)); - return nullptr; -} - -void EnzymeSetMustCache(LLVMValueRef inst1) { - Instruction *I1 = cast(unwrap(inst1)); - I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {})); -} - -uint8_t EnzymeHasFromStack(LLVMValueRef inst1) { - Instruction *I1 = cast(unwrap(inst1)); - return hasMetadata(I1, "enzyme_fromstack") != 0; -} - -void EnzymeCloneFunctionDISubprogramInto(LLVMValueRef NF, LLVMValueRef F) { - auto &OldFunc = *cast(unwrap(F)); - auto &NewFunc = *cast(unwrap(NF)); - auto OldSP = OldFunc.getSubprogram(); - if (!OldSP) - return; - DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false, - OldSP->getUnit()); - auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray({})); - DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition | - DISubprogram::SPFlagOptimized | - DISubprogram::SPFlagLocalToUnit; - auto NewSP = DIB.createFunction( - OldSP->getUnit(), NewFunc.getName(), NewFunc.getName(), OldSP->getFile(), - /*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags); - NewFunc.setSubprogram(NewSP); - DIB.finalizeSubprogram(NewSP); - return; -} - -void EnzymeReplaceFunctionImplementation(LLVMModuleRef M) { - ReplaceFunctionImplementation(*unwrap(M)); -} - -void EnzymeDumpModuleRef(LLVMModuleRef M) { - llvm::errs() << *unwrap(M) << "\n"; -} - -static bool runAttributorOnFunctions(InformationCache &InfoCache, - SetVector &Functions, - AnalysisGetter &AG, - CallGraphUpdater &CGUpdater, - bool DeleteFns, bool IsModulePass) { - if (Functions.empty()) - return false; - - // Create an Attributor 
and initially empty information cache that is filled - // while we identify default attribute opportunities. - AttributorConfig AC(CGUpdater); - AC.RewriteSignatures = false; - AC.IsModulePass = IsModulePass; - AC.DeleteFns = DeleteFns; - Attributor A(Functions, InfoCache, AC); - - for (Function *F : Functions) { - // Populate the Attributor with abstract attribute opportunities in the - // function and the information cache with IR information. - A.identifyDefaultAbstractAttributes(*F); - } - - ChangeStatus Changed = A.run(); - - return Changed == ChangeStatus::CHANGED; -} -struct MyAttributorLegacyPass : public ModulePass { - static char ID; - - MyAttributorLegacyPass() : ModulePass(ID) {} - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - AnalysisGetter AG; - SetVector Functions; - for (Function &F : M) - Functions.insert(&F); - - CallGraphUpdater CGUpdater; - BumpPtrAllocator Allocator; - InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); - return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns*/ true, - /* IsModulePass */ true); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - // FIXME: Think about passes we will preserve and add them here. - AU.addRequired(); - } -}; -extern "C++" char MyAttributorLegacyPass::ID = 0; -void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(new MyAttributorLegacyPass()); -} - -LLVMMetadataRef EnzymeMakeNonConstTBAA(LLVMMetadataRef MD) { - auto M = cast(unwrap(MD)); - if (M->getNumOperands() != 4) - return MD; - auto CAM = dyn_cast(M->getOperand(3)); - if (!CAM) - return MD; - if (!CAM->getValue()->isOneValue()) - return MD; - SmallVector MDs; - for (auto &M : M->operands()) - MDs.push_back(M); - MDs[3] = - ConstantAsMetadata::get(ConstantInt::get(CAM->getValue()->getType(), 0)); - return wrap(MDNode::get(M->getContext(), MDs)); -} -void EnzymeCopyMetadata(LLVMValueRef inst1, LLVMValueRef inst2) { - cast(unwrap(inst1)) - ->copyMetadata(*cast(unwrap(inst2))); -} -LLVMMetadataRef EnzymeAnonymousAliasScopeDomain(const char *str, - LLVMContextRef ctx) { - MDBuilder MDB(*unwrap(ctx)); - MDNode *scope = MDB.createAnonymousAliasScopeDomain(str); - return wrap(scope); -} -LLVMMetadataRef EnzymeAnonymousAliasScope(LLVMMetadataRef domain, - const char *str) { - auto dom = cast(unwrap(domain)); - MDBuilder MDB(dom->getContext()); - MDNode *scope = MDB.createAnonymousAliasScope(dom, str); - return wrap(scope); -} -uint8_t EnzymeLowerSparsification(LLVMValueRef F, uint8_t replaceAll) { - return LowerSparsification(cast(unwrap(F)), replaceAll != 0); -} - -void EnzymeAttributeKnownFunctions(LLVMValueRef FC) { - attributeKnownFunctions(*cast(unwrap(FC))); -} - -void EnzymeSetCalledFunction(LLVMValueRef C_CI, LLVMValueRef C_F, - uint64_t *argrem, uint64_t num_argrem) { - auto CI = cast(unwrap(C_CI)); - auto F = cast(unwrap(C_F)); - auto Attrs = CI->getAttributes(); - AttributeList NewAttrs; - - if (CI->getType() == F->getReturnType()) { - for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::ReturnIndex, attr); - } - for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FunctionIndex, attr); - - size_t argremsz = 0; - size_t nexti = 0; - SmallVector vals; - for (size_t i = 0, end = CI->arg_size(); i < end; i++) { - if (argremsz < num_argrem) { - if (i == argrem[argremsz]) { - argremsz++; - 
continue; - } - } - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + nexti, attr); - vals.push_back(CI->getArgOperand(i)); - nexti++; - } - assert(argremsz == num_argrem); - - IRBuilder<> B(CI); - SmallVector Bundles; - for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) - Bundles.emplace_back(CI->getOperandBundleAt(I)); - auto NC = B.CreateCall(F, vals, Bundles); - NC->setAttributes(NewAttrs); - NC->copyMetadata(*CI); - - if (CI->getType() == F->getReturnType()) - CI->replaceAllUsesWith(NC); - - if (!NC->getType()->isVoidTy()) - NC->takeName(CI); - NC->setCallingConv(CI->getCallingConv()); - CI->eraseFromParent(); -} - -// clones a function to now miss the return or args -LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, - uint8_t keepReturnU, - uint64_t *argrem, - uint64_t num_argrem) { - auto F = cast(unwrap(FC)); - auto FT = F->getFunctionType(); - bool keepReturn = keepReturnU != 0; - - size_t argremsz = 0; - size_t nexti = 0; - SmallVector types; - auto Attrs = F->getAttributes(); - AttributeList NewAttrs; - for (size_t i = 0, end = FT->getNumParams(); i < end; i++) { - if (argremsz < num_argrem) { - if (i == argrem[argremsz]) { - argremsz++; - continue; - } - } - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + nexti, attr); - types.push_back(F->getFunctionType()->getParamType(i)); - nexti++; - } - if (keepReturn) { - for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::ReturnIndex, attr); - } - for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FunctionIndex, attr); - - FunctionType *FTy = FunctionType::get( - keepReturn ? F->getReturnType() : Type::getVoidTy(F->getContext()), types, - FT->isVarArg()); - - // Create the new function - Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), - F->getName(), F->getParent()); - - ValueToValueMapTy VMap; - // Loop over the arguments, copying the names of the mapped arguments over... - nexti = 0; - argremsz = 0; - Function::arg_iterator DestI = NewF->arg_begin(); - for (const Argument &I : F->args()) { - if (argremsz < num_argrem) { - if (I.getArgNo() == argrem[argremsz]) { - VMap[&I] = UndefValue::get(I.getType()); - argremsz++; - continue; - } - } - DestI->setName(I.getName()); // Copy the name over... - VMap[&I] = &*DestI++; // Add mapping to VMap - } - - SmallVector Returns; // Ignore returns cloned. 
- CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - - if (!keepReturn) { - for (auto &B : *NewF) { - if (auto RI = dyn_cast(B.getTerminator())) { - IRBuilder<> B(RI); - auto NRI = B.CreateRetVoid(); - NRI->copyMetadata(*RI); - RI->eraseFromParent(); - } - } - } - NewF->setAttributes(NewAttrs); - if (!keepReturn) - for (auto &Arg : NewF->args()) - Arg.removeAttr(Attribute::Returned); - SmallVector, 1> MD; - F->getAllMetadata(MD); - for (auto pair : MD) - if (pair.first != LLVMContext::MD_dbg) - NewF->addMetadata(pair.first, *pair.second); - NewF->takeName(F); - NewF->setCallingConv(F->getCallingConv()); - if (!keepReturn) - NewF->addFnAttr("enzyme_retremove", ""); - - if (num_argrem) { - SmallVector previdx; - if (Attrs.hasAttribute(AttributeList::FunctionIndex, "enzyme_parmremove")) { - auto attr = - Attrs.getAttribute(AttributeList::FunctionIndex, "enzyme_parmremove"); - auto prevstr = attr.getValueAsString(); - SmallVector sub; - prevstr.split(sub, ","); - for (auto s : sub) { - uint64_t ival; - bool b = s.getAsInteger(10, ival); - (void)b; - assert(!b); - previdx.push_back(ival); - } - } - SmallVector nextidx; - for (size_t i = 0; i < num_argrem; i++) { - auto val = argrem[i]; - nextidx.push_back(val); - } - - size_t prevcnt = 0; - size_t nextcnt = 0; - SmallVector out; - while (prevcnt < previdx.size() && nextcnt < nextidx.size()) { - if (previdx[prevcnt] <= nextidx[nextcnt] + prevcnt) { - out.push_back(previdx[prevcnt]); - prevcnt++; - } else { - out.push_back(nextidx[nextcnt] + prevcnt); - nextcnt++; - } - } - while (prevcnt < previdx.size()) { - out.push_back(previdx[prevcnt]); - prevcnt++; - } - while (nextcnt < nextidx.size()) { - out.push_back(nextidx[nextcnt] + prevcnt); - nextcnt++; - } - - std::string remstr; - for (auto arg : out) { - if (remstr.size()) - remstr += ","; - remstr += std::to_string(arg); - } - - NewF->addFnAttr("enzyme_parmremove", remstr); - } - return wrap(NewF); -} -LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) { - return wrap(cast(unwrap(V))->getAllocatedType()); -} -LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, - LLVMTypeRef T_r) { - IRBuilder<> &B = *unwrap(B_r); - auto T = cast(unwrap(T_r)); - auto width = T->getBitWidth(); - auto uw = unwrap(V_r); - GEPOperator *gep = isa(uw) - ? 
cast(cast(uw)) - : cast(cast(uw)); - auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout(); - -#if LLVM_VERSION_MAJOR >= 20 - SmallMapVector VariableOffsets; -#else - MapVector VariableOffsets; -#endif - APInt Offset(width, 0); - bool success = collectOffset(gep, DL, width, VariableOffsets, Offset); - (void)success; - assert(success); - Value *start = ConstantInt::get(T, Offset); - for (auto &pair : VariableOffsets) - start = B.CreateAdd( - start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second))); - return wrap(start); -} -} - -static size_t num_rooting(llvm::Type *T, llvm::Function *F) { - CountTrackedPointers tracked(T); - if (tracked.derived) { - llvm::errs() << *F << "\n"; - llvm::errs() << "Invalid Derived Type: " << *T << "\n"; - } - assert(!tracked.derived); - if (tracked.count != 0 && !tracked.all) - return tracked.count; - return 0; -} - -extern "C" { - -void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) { - auto F = cast(unwrap(F_C)); - if (F->empty()) - return; - auto RT = F->getReturnType(); - auto FT = F->getFunctionType(); - auto Attrs = F->getAttributes(); - - AttributeList NewAttrs; - SmallVector types; - SmallSet changed; - for (auto pair : llvm::enumerate(FT->params())) { - auto T = pair.value(); - auto i = pair.index(); - bool sretv = false; - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) { - if (attr.isStringAttribute() && - attr.getKindAsString() == "enzyme_sret_v") { - sretv = true; - } else { - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + types.size(), attr); - } - } - if (auto AT = dyn_cast(T)) { - if (auto PT = dyn_cast(AT->getElementType())) { - auto AS = PT->getAddressSpace(); - if (AS == 11 || AS == 12 || AS == 13 || sretv) { - for (unsigned i = 0; i < AT->getNumElements(); i++) { - if (sretv) { - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + types.size(), - Attribute::get(F->getContext(), "enzyme_sret")); - } - types.push_back(PT); - } - changed.insert(i); - continue; - } - } - } - types.push_back(T); - } - if (changed.size() == 0) - return; - - for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FunctionIndex, attr); - - for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::ReturnIndex, attr); - - FunctionType *FTy = - FunctionType::get(FT->getReturnType(), types, FT->isVarArg()); - - // Create the new function - Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), - F->getName(), F->getParent()); - - ValueToValueMapTy VMap; - // Loop over the arguments, copying the names of the mapped arguments over... - Function::arg_iterator DestI = NewF->arg_begin(); - - // To handle the deleted args, it needs to be replaced by a non-arg operand. - // This map contains the temporary phi nodes corresponding - SmallVector toInsert; - for (Argument &I : F->args()) { - auto T = I.getType(); - if (auto AT = dyn_cast(T)) { - if (changed.count(I.getArgNo())) { - Value *V = UndefValue::get(T); - for (unsigned i = 0; i < AT->getNumElements(); i++) { - DestI->setName(I.getName() + "." + - std::to_string(i)); // Copy the name over... 
- unsigned idx[1] = {i}; - auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx); - toInsert.push_back(IV); - V = IV; - } - VMap[&I] = V; - continue; - } - } - DestI->setName(I.getName()); // Copy the name over... - VMap[&I] = &*DestI++; // Add mapping to VMap - } - - SmallVector Returns; // Ignore returns cloned. - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - - { - IRBuilder<> EB(&*NewF->getEntryBlock().begin()); - for (auto I : toInsert) - EB.Insert(I); - } - - SmallVector callers; - for (auto U : F->users()) { - auto CI = dyn_cast(U); - assert(CI); - assert(CI->getCalledFunction() == F); - callers.push_back(CI); - } - - for (auto CI : callers) { - auto Attrs = CI->getAttributes(); - AttributeList NewAttrs; - IRBuilder<> B(CI); - - for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FunctionIndex, attr); - - for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::ReturnIndex, attr); - - SmallVector vals; - for (size_t j = 0, end = CI->arg_size(); j < end; j++) { - - auto T = CI->getArgOperand(j)->getType(); - if (auto AT = dyn_cast(T)) { - if (isa(AT->getElementType())) { - if (changed.count(j)) { - bool sretv = false; - for (auto attr : - Attrs.getAttributes(AttributeList::FirstArgIndex + j)) { - if (attr.isStringAttribute() && - attr.getKindAsString() == "enzyme_sret_v") { - sretv = true; - } - } - for (unsigned i = 0; i < AT->getNumElements(); i++) { - if (sretv) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + vals.size(), - Attribute::get(F->getContext(), "enzyme_sret")); - vals.push_back( - GradientUtils::extractMeta(B, CI->getArgOperand(j), i)); - } - continue; - } - } - } - - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) { - if (attr.isStringAttribute() && - attr.getKindAsString() == "enzyme_sret_v") { - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + vals.size(), - Attribute::get(F->getContext(), "enzyme_sret")); - } else { - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + vals.size(), - attr); - } - } - - vals.push_back(CI->getArgOperand(j)); - } - - SmallVector Bundles; - for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) - Bundles.emplace_back(CI->getOperandBundleAt(I)); - auto NC = B.CreateCall(NewF, vals, Bundles); - NC->setAttributes(NewAttrs); - - SmallVector, 4> TheMDs; - CI->getAllMetadataOtherThanDebugLoc(TheMDs); - SmallVector toCopy; - for (auto pair : TheMDs) - toCopy.push_back(pair.first); - if (!toCopy.empty()) - NC->copyMetadata(*CI, toCopy); - NC->setDebugLoc(CI->getDebugLoc()); - - if (!RT->isVoidTy()) { - NC->takeName(CI); - CI->replaceAllUsesWith(NC); - } - - NC->setCallingConv(CI->getCallingConv()); - CI->eraseFromParent(); - } - NewF->setAttributes(NewAttrs); - SmallVector, 1> MD; - F->getAllMetadata(MD); - for (auto pair : MD) - if (pair.first != LLVMContext::MD_dbg) - NewF->addMetadata(pair.first, *pair.second); - NewF->takeName(F); - NewF->setCallingConv(F->getCallingConv()); - F->eraseFromParent(); -} - -void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { - auto F = cast(unwrap(F_C)); - if (F->empty()) - return; - auto RT = F->getReturnType(); - std::set srets; - std::set enzyme_srets; - std::set enzyme_srets_v; - std::set rroots; - std::set rroots_v; - - auto FT = 
F->getFunctionType(); - auto Attrs = F->getAttributes(); - for (size_t i = 0, end = FT->getNumParams(); i < end; i++) { - if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, - Attribute::StructRet)) - srets.insert(i); - if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret")) - enzyme_srets.insert(i); - if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret_v")) - enzyme_srets_v.insert(i); - if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, - "enzymejl_returnRoots")) - rroots.insert(i); - if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, - "enzymejl_returnRoots_v")) - rroots_v.insert(i); - } - // Regular julia function, needing no intervention - if (srets.size() == 1) { - assert(*srets.begin() == 0); - assert(enzyme_srets.size() == 0); - assert(enzyme_srets_v.size() == 0); - assert(rroots_v.size() == 0); - if (rroots.size()) { - assert(rroots.size() == 1); - assert(*rroots.begin() == 1); - } - return; - } - // No sret/rooting, no intervention needed. - if (srets.size() == 0 && enzyme_srets.size() == 0 && - enzyme_srets_v.size() == 0 && rroots.size() == 0 && - rroots_v.size() == 0) { - return; - } - - assert(srets.size() == 0); - - SmallVector Types; - if (!RT->isVoidTy()) { - Types.push_back(RT); - } - - for (auto idx : enzyme_srets) { - llvm::Type *T = nullptr; -#if LLVM_VERSION_MAJOR >= 17 - (void)idx; - llvm_unreachable("Unhandled"); - // T = F->getParamAttribute(idx, Attribute::AttrKind::ElementType) - // .getValueAsType(); -#else - T = FT->getParamType(idx)->getPointerElementType(); -#endif - Types.push_back(T); - } - for (auto idx : enzyme_srets_v) { - llvm::Type *T = nullptr; - auto AT = cast(FT->getParamType(idx)); -#if LLVM_VERSION_MAJOR >= 17 - llvm_unreachable("Unhandled"); - // T = F->getParamAttribute(idx, Attribute::AttrKind::ElementType) - // .getValueAsType(); -#else - T = AT->getElementType()->getPointerElementType(); -#endif - for (size_t i = 0; i < AT->getNumElements(); i++) - Types.push_back(T); - } - - StructType *ST = - Types.size() <= 1 ? nullptr : StructType::get(F->getContext(), Types); - Type *sretTy = nullptr; - if (Types.size()) - sretTy = Types.size() == 1 ? Types[0] : ST; - size_t numRooting = sretTy ? num_rooting(sretTy, F) : 0; - - auto T_jlvalue = StructType::get(F->getContext(), {}); - auto T_prjlvalue = PointerType::get(T_jlvalue, AddressSpace::Tracked); - ArrayType *roots_AT = - numRooting ? 
ArrayType::get(T_prjlvalue, numRooting) : nullptr; - - AttributeList NewAttrs; - SmallVector types; - size_t nexti = 0; - if (sretTy) { - types.push_back(PointerType::getUnqual(sretTy)); - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + nexti, - Attribute::get(F->getContext(), Attribute::StructRet, sretTy)); - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FirstArgIndex + nexti, - Attribute::NoAlias); - nexti++; - } - if (roots_AT) { - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FirstArgIndex + nexti, - "enzymejl_returnRoots"); - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FirstArgIndex + nexti, - Attribute::NoAlias); - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FirstArgIndex + nexti, - Attribute::WriteOnly); - types.push_back(PointerType::getUnqual(roots_AT)); - nexti++; - } - for (size_t i = 0, end = FT->getNumParams(); i < end; i++) { - if (enzyme_srets.count(i) || enzyme_srets_v.count(i) || rroots.count(i) || - rroots_v.count(i)) - continue; - - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + nexti, attr); - types.push_back(F->getFunctionType()->getParamType(i)); - nexti++; - } - for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FunctionIndex, attr); - - FunctionType *FTy = FunctionType::get(Type::getVoidTy(F->getContext()), types, - FT->isVarArg()); - - // Create the new function - auto &M = *F->getParent(); - Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), - F->getName(), &M); - - ValueToValueMapTy VMap; - // Loop over the arguments, copying the names of the mapped arguments over... - Function::arg_iterator DestI = NewF->arg_begin(); - Argument *sret = nullptr; - if (sretTy) { - sret = &*DestI; - DestI++; - } - Argument *roots = nullptr; - if (roots_AT) { - roots = &*DestI; - DestI++; - } - // To handle the deleted args, it needs to be replaced by a non-arg operand. - // This map contains the temporary phi nodes corresponding - // - - std::map delArgMap; - for (Argument &I : F->args()) { - auto i = I.getArgNo(); - if (enzyme_srets.count(i) || enzyme_srets_v.count(i) || rroots.count(i) || - rroots_v.count(i)) { - VMap[&I] = delArgMap[i] = PHINode::Create(I.getType(), 0); - continue; - } - assert(DestI != NewF->arg_end()); - DestI->setName(I.getName()); // Copy the name over... - VMap[&I] = &*DestI++; // Add mapping to VMap - } - - SmallVector Returns; // Ignore returns cloned. 
- CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - - SmallVector callers; - for (auto U : F->users()) { - auto CI = dyn_cast(U); - assert(CI); - assert(CI->getCalledFunction() == F); - callers.push_back(CI); - } - - size_t curOffset = 0; - - std::function &, Value *, size_t)> recur = - [&](IRBuilder<> &B, Value *V, size_t offset) -> size_t { - auto T = V->getType(); - if (CountTrackedPointers(T).count == 0) - return offset; - if (roots_AT == nullptr) - return offset; - if (isa(T)) { - if (isSpecialPtr(T)) { - if (!roots_AT) { - llvm::errs() << *V << " \n"; - llvm::errs() << *cast(V)->getParent()->getParent() - << " \n"; - } - assert(roots_AT); - assert(roots); - auto gep = B.CreateConstInBoundsGEP2_32(roots_AT, roots, 0, offset); - if (T != T_prjlvalue) - V = B.CreatePointerCast(V, T_prjlvalue); - B.CreateStore(V, gep); - offset++; - } - return offset; - } else if (auto ST = dyn_cast(T)) { - for (size_t i = 0; i < ST->getNumElements(); i++) { - offset = recur(B, GradientUtils::extractMeta(B, V, i), offset); - } - return offset; - } else if (auto AT = dyn_cast(T)) { - for (size_t i = 0; i < AT->getNumElements(); i++) { - offset = recur(B, GradientUtils::extractMeta(B, V, i), offset); - } - return offset; - } else if (auto VT = dyn_cast(T)) { - size_t count = VT->getElementCount().getKnownMinValue(); - for (size_t i = 0; i < count; i++) { - offset = recur(B, B.CreateExtractElement(V, i), offset); - } - return offset; - } - return offset; - }; - - size_t sretCount = 0; - if (!RT->isVoidTy()) { - for (auto &RT : Returns) { - IRBuilder<> B(RT); - Value *gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret; - Value *rval = RT->getReturnValue(); - B.CreateStore(rval, gep); - recur(B, rval, 0); - auto NR = B.CreateRetVoid(); - RT->eraseFromParent(); - RT = NR; - } - if (roots_AT) - curOffset = CountTrackedPointers(RT).count; - sretCount++; - } - - for (auto i : enzyme_srets) { - auto arg = delArgMap[i]; - assert(arg); - SmallVector uses; - SmallVector op; - for (auto &U : arg->uses()) { - auto I = cast(U.getUser()); - uses.push_back(I); - op.push_back(U.getOperandNo()); - } - IRBuilder<> EB(&NewF->getEntryBlock().front()); - auto gep = - ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret; - for (size_t i = 0; i < uses.size(); i++) { - uses[i]->setOperand(op[i], gep); - } - for (auto &RT : Returns) { - IRBuilder<> B(RT); - auto val = B.CreateLoad(Types[sretCount], gep); - recur(B, val, curOffset); - } - if (roots_AT) - curOffset += CountTrackedPointers(Types[sretCount]).count; - sretCount++; - delete arg; - } - for (auto i : enzyme_srets_v) { - auto AT = cast(FT->getParamType(i)); - auto arg = delArgMap[i]; - assert(arg); - SmallVector uses; - SmallVector op; - for (auto &U : arg->uses()) { - auto I = cast(U.getUser()); - uses.push_back(I); - op.push_back(U.getOperandNo()); - } - IRBuilder<> EB(&NewF->getEntryBlock().front()); - Value *val = UndefValue::get(AT); - for (size_t j = 0; j < AT->getNumElements(); j++) { - auto gep = - ST ? 
EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j) : sret; - val = EB.CreateInsertValue(val, gep, j); - } - for (size_t i = 0; i < uses.size(); i++) { - uses[i]->setOperand(op[i], val); - } - for (auto &RT : Returns) { - IRBuilder<> B(RT); - for (size_t j = 0; j < AT->getNumElements(); j++) { - Value *em = GradientUtils::extractMeta(B, val, j); - em = B.CreateLoad(Types[sretCount + j], em); - recur(B, em, curOffset); - } - } - if (roots_AT) - curOffset += - CountTrackedPointers(Types[sretCount]).count * AT->getNumElements(); - sretCount += AT->getNumElements(); - delete arg; - } - - for (auto i : rroots) { - auto arg = delArgMap[i]; - assert(arg); - llvm::Type *T = nullptr; -#if LLVM_VERSION_MAJOR >= 17 - llvm_unreachable("Unhandled"); - // T = F->getParamAttribute(i, Attribute::AttrKind::ElementType) - // .getValueAsType(); -#else - T = FT->getParamType(i)->getPointerElementType(); -#endif - IRBuilder<> EB(&NewF->getEntryBlock().front()); - auto AL = EB.CreateAlloca(T, 0, "stack_roots"); - arg->replaceAllUsesWith(AL); - delete arg; - } - for (auto i : rroots_v) { - auto arg = delArgMap[i]; - assert(arg); - auto AT = cast(FT->getParamType(i)); - llvm::Type *T = nullptr; -#if LLVM_VERSION_MAJOR >= 17 - llvm_unreachable("Unhandled"); - // T = F->getParamAttribute(i, Attribute::AttrKind::ElementType) - // .getValueAsType(); -#else - T = AT->getElementType()->getPointerElementType(); -#endif - IRBuilder<> EB(&NewF->getEntryBlock().front()); - Value *val = UndefValue::get(AT); - for (size_t j = 0; j < AT->getNumElements(); j++) { - auto AL = EB.CreateAlloca(T, 0, "stack_roots_v"); - val = EB.CreateInsertValue(val, AL, j); - } - arg->replaceAllUsesWith(val); - delete arg; - } - assert(curOffset == numRooting); - assert(sretCount == Types.size()); - - for (auto CI : callers) { - auto Attrs = CI->getAttributes(); - AttributeList NewAttrs; - IRBuilder<> B(CI); - IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front()); - SmallVector vals; - size_t nexti = 0; - Value *sret = nullptr; - if (sretTy) { - sret = EB.CreateAlloca(sretTy, 0, "stack_sret"); - vals.push_back(sret); - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + nexti, - Attribute::get(F->getContext(), Attribute::StructRet, sretTy)); - nexti++; - } - AllocaInst *roots = nullptr; - if (roots_AT) { - roots = EB.CreateAlloca(roots_AT, 0, "stack_roots_AT"); - vals.push_back(roots); - NewAttrs = NewAttrs.addAttribute( - - F->getContext(), AttributeList::FirstArgIndex + nexti, - "enzymejl_returnRoots"); - nexti++; - } - - for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) - NewAttrs = NewAttrs.addAttribute(F->getContext(), - AttributeList::FunctionIndex, attr); - - SmallVector sret_vals; - SmallVector sretv_vals; - for (size_t i = 0, end = CI->arg_size(); i < end; i++) { - if (rroots.count(i) || rroots_v.count(i)) { - continue; - } - if (enzyme_srets.count(i)) { - sret_vals.push_back(CI->getArgOperand(i)); - continue; - } - if (enzyme_srets_v.count(i)) { - sretv_vals.push_back(CI->getArgOperand(i)); - continue; - } - - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + nexti, attr); - vals.push_back(CI->getArgOperand(i)); - nexti++; - } - - sretCount = 0; - if (!RT->isVoidTy()) { - sretCount++; - } - - std::function, int, Type *, - bool)> - copyNonJLValue = [&](Type *curType, Value *out, Value *in, - ArrayRef inds, int outPrefix, Type *ptrTy, - bool shouldZero) { - if 
(auto PT = dyn_cast(curType)) { - if (PT->getAddressSpace() == 10) { - if (shouldZero) { - SmallVector outinds; - auto c0 = ConstantInt::get(B.getInt64Ty(), 0); - outinds.push_back(c0); - if (outPrefix >= 0) - outinds.push_back( - ConstantInt::get(B.getInt32Ty(), outPrefix)); - for (auto v : inds) { - outinds.push_back(ConstantInt::get(B.getInt32Ty(), v)); - } - if (outinds.size() > 1) - out = B.CreateInBoundsGEP(sretTy, out, outinds); - B.CreateStore(getUndefinedValueForType(M, PT), out); - } - return; - } - } - - if (auto AT = dyn_cast(curType)) { - for (size_t i = 0; i < AT->getNumElements(); i++) { - SmallVector next(inds.begin(), inds.end()); - next.push_back(i); - copyNonJLValue(AT->getElementType(), out, in, next, outPrefix, - ptrTy, shouldZero); - } - return; - } - if (auto ST = dyn_cast(curType)) { - for (size_t i = 0; i < ST->getNumElements(); i++) { - SmallVector next(inds.begin(), inds.end()); - next.push_back(i); - copyNonJLValue(ST->getElementType(i), out, in, next, outPrefix, - ptrTy, shouldZero); - } - return; - } - - SmallVector ininds; - SmallVector outinds; - auto c0 = ConstantInt::get(B.getInt64Ty(), 0); - ininds.push_back(c0); - outinds.push_back(c0); - if (outPrefix >= 0) - outinds.push_back(ConstantInt::get(B.getInt32Ty(), outPrefix)); - for (auto v : inds) { - ininds.push_back(ConstantInt::get(B.getInt32Ty(), v)); - outinds.push_back(ConstantInt::get(B.getInt32Ty(), v)); - } - - if (outinds.size() > 1) - out = B.CreateInBoundsGEP(sretTy, out, outinds); - if (ininds.size() > 1) - in = B.CreateInBoundsGEP(ptrTy, in, ininds); - - auto ld = B.CreateLoad(curType, in); - B.CreateStore(ld, out); - }; - - for (Value *ptr : sret_vals) { - copyNonJLValue(Types[sretCount], sret, ptr, {}, ST ? sretCount : -1, - Types[sretCount], true); - sretCount++; - } - for (Value *ptr_v : sretv_vals) { - auto AT = cast(ptr_v->getType()); - for (size_t j = 0; j < AT->getNumElements(); j++) { - auto ptr = GradientUtils::extractMeta(B, ptr_v, j); - copyNonJLValue(Types[sretCount], sret, ptr, {}, - ST ? (sretCount + j) : -1, Types[sretCount], true); - } - sretCount += AT->getNumElements(); - } - - SmallVector Bundles; - for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) - Bundles.emplace_back(CI->getOperandBundleAt(I)); - auto NC = B.CreateCall(NewF, vals, Bundles); - NC->setAttributes(NewAttrs); - - SmallVector, 4> TheMDs; - CI->getAllMetadataOtherThanDebugLoc(TheMDs); - SmallVector toCopy; - for (auto pair : TheMDs) - if (pair.first != LLVMContext::MD_range) { - toCopy.push_back(pair.first); - } - if (!toCopy.empty()) - NC->copyMetadata(*CI, toCopy); - NC->setDebugLoc(CI->getDebugLoc()); - - sretCount = 0; - if (!RT->isVoidTy()) { - auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret; - auto ld = B.CreateLoad(RT, gep); - if (auto MD = CI->getMetadata(LLVMContext::MD_range)) - ld->setMetadata(LLVMContext::MD_range, MD); - ld->takeName(CI); - CI->replaceAllUsesWith(ld); - sretCount++; - } - - for (auto ptr : sret_vals) { - if (!isa(ptr) && !isa(ptr)) { - auto gep = - ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret; - auto ld = B.CreateLoad(Types[sretCount], gep); - auto SI = B.CreateStore(ld, ptr); - PostCacheStore(SI, B); - } - sretCount++; - } - for (auto ptr_v : sretv_vals) { - auto AT = cast(ptr_v->getType()); - for (size_t j = 0; j < AT->getNumElements(); j++) { - auto gep = ST ? 
B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j) - : sret; - auto ptr = GradientUtils::extractMeta(B, ptr_v, j); - if (!isa(ptr) && !isa(ptr)) { - auto ld = B.CreateLoad(Types[sretCount], gep); - auto SI = B.CreateStore(ld, ptr); - PostCacheStore(SI, B); - } - } - sretCount += AT->getNumElements(); - } - - NC->setCallingConv(CI->getCallingConv()); - CI->eraseFromParent(); - } - NewF->setAttributes(NewAttrs); - SmallVector, 1> MD; - F->getAllMetadata(MD); - for (auto pair : MD) - if (pair.first != LLVMContext::MD_dbg) - NewF->addMetadata(pair.first, *pair.second); - NewF->takeName(F); - NewF->setCallingConv(F->getCallingConv()); - F->eraseFromParent(); -} - -LLVMValueRef EnzymeBuildExtractValue(LLVMBuilderRef B, LLVMValueRef AggVal, - unsigned *Index, unsigned Size, - const char *Name) { - return wrap(unwrap(B)->CreateExtractValue( - unwrap(AggVal), ArrayRef(Index, Size), Name)); -} - -LLVMValueRef EnzymeBuildInsertValue(LLVMBuilderRef B, LLVMValueRef AggVal, - LLVMValueRef EltVal, unsigned *Index, - unsigned Size, const char *Name) { - return wrap(unwrap(B)->CreateInsertValue( - unwrap(AggVal), unwrap(EltVal), ArrayRef(Index, Size), Name)); -} -} diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h deleted file mode 100644 index f1f55f6178f1..000000000000 --- a/enzyme/Enzyme/CApi.h +++ /dev/null @@ -1,237 +0,0 @@ -//===- CApi.h - Enzyme API exported to C for external use -----------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. 
and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares various utility functions of Enzyme for access via C -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_CAPI_H -#define ENZYME_CAPI_H - -#include "llvm-c/Core.h" -#include "llvm-c/DataTypes.h" -// #include "llvm-c/Initialization.h" -#include "llvm-c/Target.h" -#include - -#ifdef __cplusplus -extern "C" { -#endif - -struct EnzymeOpaqueTypeAnalysis; -typedef struct EnzymeOpaqueTypeAnalysis *EnzymeTypeAnalysisRef; - -struct EnzymeOpaqueLogic; -typedef struct EnzymeOpaqueLogic *EnzymeLogicRef; - -struct EnzymeOpaqueAugmentedReturn; -typedef struct EnzymeOpaqueAugmentedReturn *EnzymeAugmentedReturnPtr; - -struct EnzymeOpaqueTraceInterface; -typedef struct EnzymeOpaqueTraceInterface *EnzymeTraceInterfaceRef; - -struct IntList { - int64_t *data; - size_t size; -}; - -typedef enum { - DT_Anything = 0, - DT_Integer = 1, - DT_Pointer = 2, - DT_Half = 3, - DT_Float = 4, - DT_Double = 5, - DT_Unknown = 6, - DT_X86_FP80 = 7, - DT_BFloat16 = 8, -} CConcreteType; - -struct CDataPair { - struct IntList offsets; - CConcreteType datatype; -}; - -/* -struct CTypeTree { - struct CDataPair *data; - size_t size; -}; -*/ - -typedef enum { - VT_None = 0, - VT_Primal = 1, - VT_Shadow = 2, - VT_Both = VT_Primal | VT_Shadow, -} CValueType; - -struct EnzymeTypeTree; -typedef struct EnzymeTypeTree *CTypeTreeRef; -CTypeTreeRef EnzymeNewTypeTree(); -CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType, LLVMContextRef ctx); -CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef); -void EnzymeFreeTypeTree(CTypeTreeRef CTT); -uint8_t EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src); -uint8_t EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src); -void EnzymeTypeTreeOnlyEq(CTypeTreeRef dst, int64_t x); -void EnzymeTypeTreeData0Eq(CTypeTreeRef dst); -void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef dst, const char *datalayout, - int64_t offset, int64_t maxSize, - uint64_t addOffset); -const char *EnzymeTypeTreeToString(CTypeTreeRef src); -void EnzymeTypeTreeToStringFree(const char *cstr); - -void EnzymeSetCLBool(void *, uint8_t); -void EnzymeSetCLInteger(void *, int64_t); - -struct CFnTypeInfo { - /// Types of arguments, assumed of size len(Arguments) - CTypeTreeRef *Arguments; - - /// Type of return - CTypeTreeRef Return; - - /// The specific constant(s) known to represented by an argument, if constant - // map is [arg number] => list - struct IntList *KnownValues; -}; - -typedef enum { - DFT_OUT_DIFF = 0, // add differential to an output struct. Only for scalar - // values in ReverseMode variants. - DFT_DUP_ARG = 1, // duplicate the argument and store differential inside. - // For references, pointers, or integers in ReverseMode - // variants. For all types in ForwardMode variants. - DFT_CONSTANT = 2, // no differential. Usable everywhere. - DFT_DUP_NONEED = 3 // duplicate this argument and store differential inside, - // but don't need the forward. Same as DUP_ARG otherwise. 
-} CDIFFE_TYPE; - -typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE; - -typedef enum { - DEM_ForwardMode = 0, - DEM_ReverseModePrimal = 1, - DEM_ReverseModeGradient = 2, - DEM_ReverseModeCombined = 3, - DEM_ForwardModeSplit = 4, - DEM_ForwardModeError = 5 -} CDerivativeMode; - -typedef enum { - DEM_Trace = 0, - DEM_Condition = 1, -} CProbProgMode; - -typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, - CTypeTreeRef * /*args*/, - struct IntList * /*knownValues*/, - size_t /*numArgs*/, LLVMValueRef, - void * /*TA*/); -EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log, - char **customRuleNames, - CustomRuleType *customRules, - size_t numRules); -void ClearTypeAnalysis(EnzymeTypeAnalysisRef); -void FreeTypeAnalysis(EnzymeTypeAnalysisRef); - -EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M); -EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface( - LLVMContextRef C, LLVMValueRef getTraceFunction, - LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, - LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, - LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, - LLVMValueRef insertChoiceGradientFunction, - LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, - LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, - LLVMValueRef hasChoiceFunction); -EnzymeTraceInterfaceRef -CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F); -EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt); -void ClearEnzymeLogic(EnzymeLogicRef); -void FreeEnzymeLogic(EnzymeLogicRef); - -void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, - uint8_t *existed, size_t len); - -LLVMValueRef -EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret); -LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret); - -class GradientUtils; -class DiffeGradientUtils; - -typedef LLVMValueRef (*CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef, - size_t /*numArgs*/, LLVMValueRef *, - GradientUtils *); -typedef LLVMValueRef (*CustomShadowFree)(LLVMBuilderRef, LLVMValueRef); - -void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, - CustomShadowFree FHandle); - -typedef uint8_t (*CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef, - GradientUtils *, LLVMValueRef *, - LLVMValueRef *); - -typedef uint8_t (*CustomFunctionDiffUse)(LLVMValueRef, const GradientUtils *, - LLVMValueRef, uint8_t, CDerivativeMode, - uint8_t *); - -typedef uint8_t (*CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef, - GradientUtils *, - LLVMValueRef *, - LLVMValueRef *, - LLVMValueRef *); - -typedef void (*CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef, - DiffeGradientUtils *, LLVMValueRef); - -LLVMValueRef EnzymeCreateForwardDiff( - EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, - LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, - size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, - CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity, - unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, - uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, - size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented); - -LLVMValueRef EnzymeCreatePrimalAndGradient( - EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, - LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, 
- size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, - uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, - unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg, - uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, - uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, - size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, - uint8_t AtomicAdd); - -void EnzymeRegisterCallHandler(const char *Name, - CustomAugmentedFunctionForward FwdHandle, - CustomFunctionReverse RevHandle); - -LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, - LLVMValueRef val); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 629cecdc578f..647c8d6b5b01 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -44,12 +44,11 @@ set(LLVM_LINK_COMPONENTS Demangle) file(GLOB ENZYME_SRC CONFIGURE_DEPENDS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp" ) +list(REMOVE_ITEM ENZYME_SRC "api.cpp") list(REMOVE_ITEM ENZYME_SRC "eopt.cpp") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp TypeAnalysis/RustDebugInfo.cpp) - if (ENZYME_ENABLE_PLUGINS) # on windows `PLUGIN_TOOL` doesn't link against LLVM.dll if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB) diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp deleted file mode 100644 index 052284b9d54b..000000000000 --- a/enzyme/Enzyme/CacheUtility.cpp +++ /dev/null @@ -1,1627 +0,0 @@ -//===- CacheUtility.cpp - Caching values in the forward pass for later use -//-===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file defines a base helper class CacheUtility that manages the cache -// of values from the forward pass for later use. 
-// -//===----------------------------------------------------------------------===// - -#include "CacheUtility.h" -#include "FunctionUtils.h" - -using namespace llvm; - -/// Pack 8 bools together in a single byte -extern "C" { -llvm::cl::opt - EfficientBoolCache("enzyme-smallbool", cl::init(false), cl::Hidden, - cl::desc("Place 8 bools together in a single byte")); - -llvm::cl::opt EnzymeZeroCache("enzyme-zero-cache", cl::init(false), - cl::Hidden, - cl::desc("Zero initialize the cache")); - -llvm::cl::opt - EnzymePrintPerf("enzyme-print-perf", cl::init(false), cl::Hidden, - cl::desc("Enable Enzyme to print performance info")); - -llvm::cl::opt EfficientMaxCache( - "enzyme-max-cache", cl::init(false), cl::Hidden, - cl::desc( - "Avoid reallocs when possible by potentially overallocating cache")); -} - -CacheUtility::~CacheUtility() {} - -/// Erase this instruction both from LLVM modules and any local data-structures -void CacheUtility::erase(Instruction *I) { - assert(I); - - if (auto found = findInMap(scopeMap, (Value *)I)) { - scopeFrees.erase(found->first); - scopeAllocs.erase(found->first); - scopeInstructions.erase(found->first); - } - if (auto AI = dyn_cast(I)) { - scopeFrees.erase(AI); - scopeAllocs.erase(AI); - scopeInstructions.erase(AI); - } - scopeMap.erase(I); - SE.eraseValueFromMap(I); - - if (!I->use_empty()) { - std::string str; - raw_string_ostream ss(str); - ss << "Erased value with a use:\n"; - ss << *newFunc->getParent() << "\n"; - ss << *newFunc << "\n"; - ss << *I << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(I), ErrorType::InternalError, - nullptr, nullptr, nullptr); - } else { - EmitFailure("GetIndexError", I->getDebugLoc(), I, ss.str()); - } - I->replaceAllUsesWith(UndefValue::get(I->getType())); - } - assert(I->use_empty()); - I->eraseFromParent(); -} - -/// Replace this instruction both in LLVM modules and any local data-structures -void CacheUtility::replaceAWithB(Value *A, Value *B, bool storeInCache) { - auto found = scopeMap.find(A); - if (found != scopeMap.end()) { - insert_or_assign2(scopeMap, B, found->second); - - llvm::AllocaInst *cache = found->second.first; - if (storeInCache) { - assert(isa(B)); - auto stfound = scopeInstructions.find(cache); - if (stfound != scopeInstructions.end()) { - SmallVector tmpInstructions(stfound->second.begin(), - stfound->second.end()); - scopeInstructions.erase(stfound); - for (auto st : tmpInstructions) - cast(&*st)->eraseFromParent(); - MDNode *TBAA = nullptr; - if (auto I = dyn_cast(A)) - TBAA = I->getMetadata(LLVMContext::MD_tbaa); - storeInstructionInCache(found->second.second, cast(B), - cache, TBAA); - } - } - - scopeMap.erase(A); - } - A->replaceAllUsesWith(B); -} - -// Create a new canonical induction variable of Type Ty for Loop L -// Return the variable and the increment instruction -std::pair -InsertNewCanonicalIV(Loop *L, Type *Ty, const llvm::Twine &Name) { - assert(L); - assert(Ty); - - BasicBlock *Header = L->getHeader(); - assert(Header); - IRBuilder<> B(&Header->front()); - PHINode *CanonicalIV = B.CreatePHI(Ty, 1, Name); - - B.SetInsertPoint(Header->getFirstNonPHIOrDbg()); - Instruction *Inc = cast( - B.CreateAdd(CanonicalIV, ConstantInt::get(Ty, 1), Name + ".next", - /*NUW*/ true, /*NSW*/ true)); - - for (BasicBlock *Pred : predecessors(Header)) { - assert(Pred); - if (L->contains(Pred)) { - CanonicalIV->addIncoming(Inc, Pred); - } else { - CanonicalIV->addIncoming(ConstantInt::get(Ty, 0), Pred); - } - } - assert(L->getCanonicalInductionVariable() == CanonicalIV); - return 
std::pair(CanonicalIV, Inc); -} - -// Create a new canonical induction variable of Type Ty for Loop L -// Return the variable and the increment instruction -std::pair FindCanonicalIV(Loop *L, Type *Ty) { - assert(L); - assert(Ty); - - BasicBlock *Header = L->getHeader(); - assert(Header); - for (BasicBlock::iterator II = Header->begin(); isa(II); ++II) { - PHINode *PN = cast(II); - if (PN->getType() != Ty) - continue; - - Instruction *Inc = nullptr; - bool legal = true; - for (BasicBlock *Pred : predecessors(Header)) { - assert(Pred); - if (L->contains(Pred)) { - auto Inc2 = - dyn_cast(PN->getIncomingValueForBlock(Pred)); - if (!Inc2 || Inc2->getOpcode() != Instruction::Add || - Inc2->getOperand(0) != PN) { - legal = false; - break; - } - auto CI = dyn_cast(Inc2->getOperand(1)); - if (!CI || !CI->isOne()) { - legal = false; - break; - } - if (Inc) { - if (Inc2 != Inc) { - legal = false; - break; - } - } else - Inc = Inc2; - } else { - auto CI = dyn_cast(PN->getIncomingValueForBlock(Pred)); - if (!CI || !CI->isZero()) { - legal = false; - break; - } - } - } - if (!legal) - continue; - if (!Inc) - continue; - if (Inc != getFirstNonPHIOrDbg(Header)) - Inc->moveBefore(getFirstNonPHIOrDbg(Header)); - return std::make_pair(PN, Inc); - } - llvm::errs() << *Header << "\n"; - assert(0 && "Could not find canonical IV"); - return std::pair(nullptr, nullptr); -} - -// Attempt to rewrite all phinode's in the loop in terms of the -// induction variable -void RemoveRedundantIVs( - BasicBlock *Header, PHINode *CanonicalIV, Instruction *Increment, - MustExitScalarEvolution &SE, - llvm::function_ref replacer, - llvm::function_ref eraser) { - assert(Header); - assert(CanonicalIV); - SmallVector IVsToRemove; - - auto CanonicalSCEV = SE.getSCEV(CanonicalIV); - - for (BasicBlock::iterator II = Header->begin(); isa(II);) { - PHINode *PN = cast(II); - ++II; - if (PN == CanonicalIV) - continue; - if (!SE.isSCEVable(PN->getType())) - continue; - const SCEV *S = SE.getSCEV(PN); - if (SE.getCouldNotCompute() == S || isa(S)) - continue; - // we may expand code for phi where not legal (computing with - // subloop expressions). Check that this isn't the case - if (!SE.dominates(S, Header)) - continue; - - if (S == CanonicalSCEV) { - replacer(PN, CanonicalIV); - eraser(PN); - continue; - } - - IRBuilder<> B(PN); - auto Tmp = B.CreatePHI(PN->getType(), 0); - for (auto Pred : predecessors(Header)) - Tmp->addIncoming(UndefValue::get(Tmp->getType()), Pred); - replacer(PN, Tmp); - eraser(PN); - - // This scope is necessary to ensure scevexpander cleans up before we erase - // things - SCEVExpander Exp(SE, Header->getParent()->getParent()->getDataLayout(), - "enzyme"); - - // We place that at first non phi as it may produce a non-phi instruction - // and must thus be expanded after all phi's - Value *NewIV = - Exp.expandCodeFor(S, Tmp->getType(), Header->getFirstNonPHI()); - - // Explicity preserve wrap behavior from original iv. 
This is necessary - // until this PR in llvm is merged: - // https://github.com/llvm/llvm-project/pull/78199 - if (auto addrec = dyn_cast(S)) { - if (addrec->getLoop()->getHeader() == Header) { - if (auto add_or_mul = dyn_cast(NewIV)) { - if (addrec->getNoWrapFlags(llvm::SCEV::FlagNUW)) - add_or_mul->setHasNoUnsignedWrap(true); - if (addrec->getNoWrapFlags(llvm::SCEV::FlagNSW)) - add_or_mul->setHasNoSignedWrap(true); - } - } - } - replacer(Tmp, NewIV); - eraser(Tmp); - } - - // Replace existing increments with canonical Increment - Increment->moveAfter(CanonicalIV->getParent()->getFirstNonPHI()); - SmallVector toErase; - for (auto use : CanonicalIV->users()) { - auto BO = dyn_cast(use); - if (BO == nullptr) - continue; - if (BO->getOpcode() != BinaryOperator::Add) - continue; - if (use == Increment) - continue; - - Value *toadd = nullptr; - if (BO->getOperand(0) == CanonicalIV) { - toadd = BO->getOperand(1); - } else { - assert(BO->getOperand(1) == CanonicalIV); - toadd = BO->getOperand(0); - } - if (auto CI = dyn_cast(toadd)) { - if (!CI->isOne()) - continue; - BO->replaceAllUsesWith(Increment); - toErase.push_back(BO); - } else { - continue; - } - } - for (auto BO : toErase) - eraser(BO); -} - -void CanonicalizeLatches(const Loop *L, BasicBlock *Header, - BasicBlock *Preheader, PHINode *CanonicalIV, - MustExitScalarEvolution &SE, CacheUtility &gutils, - Instruction *Increment, - ArrayRef latches) { - // Attempt to explicitly rewrite the latch - if (latches.size() == 1 && isa(latches[0]->getTerminator()) && - cast(latches[0]->getTerminator())->isConditional()) - for (auto use : CanonicalIV->users()) { - if (auto cmp = dyn_cast(use)) { - if (cast(latches[0]->getTerminator())->getCondition() != - cmp) - continue; - // Force i to be on LHS - if (cmp->getOperand(0) != CanonicalIV) { - // Below also swaps predicate correctly - cmp->swapOperands(); - } - assert(cmp->getOperand(0) == CanonicalIV); - - auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L); - if (cmp->isUnsigned() || - (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv))) { - - // valid replacements (since unsigned comparison and i starts at 0 - // counting up) - - // * i < n => i != n, valid since first time i >= n occurs at i == n - if (cmp->getPredicate() == ICmpInst::ICMP_ULT || - cmp->getPredicate() == ICmpInst::ICMP_SLT) { - cmp->setPredicate(ICmpInst::ICMP_NE); - goto cend; - } - - // * i <= n => i != n+1, valid since first time i > n occurs at i == - // n+1 [ which we assert is in bitrange as not infinite loop ] - if (cmp->getPredicate() == ICmpInst::ICMP_ULE || - cmp->getPredicate() == ICmpInst::ICMP_SLE) { - IRBuilder<> builder(Preheader->getTerminator()); - if (auto inst = dyn_cast(cmp->getOperand(1))) { - builder.SetInsertPoint(inst->getNextNode()); - } - cmp->setOperand( - 1, - builder.CreateNUWAdd( - cmp->getOperand(1), - ConstantInt::get(cmp->getOperand(1)->getType(), 1, false))); - cmp->setPredicate(ICmpInst::ICMP_NE); - goto cend; - } - - // * i >= n => i == n, valid since first time i >= n occurs at i == n - if (cmp->getPredicate() == ICmpInst::ICMP_UGE || - cmp->getPredicate() == ICmpInst::ICMP_SGE) { - cmp->setPredicate(ICmpInst::ICMP_EQ); - goto cend; - } - - // * i > n => i == n+1, valid since first time i > n occurs at i == - // n+1 [ which we assert is in bitrange as not infinite loop ] - if (cmp->getPredicate() == ICmpInst::ICMP_UGT || - cmp->getPredicate() == ICmpInst::ICMP_SGT) { - IRBuilder<> builder(Preheader->getTerminator()); - if (auto inst = dyn_cast(cmp->getOperand(1))) { - 
builder.SetInsertPoint(inst->getNextNode()); - } - cmp->setOperand( - 1, - builder.CreateNUWAdd( - cmp->getOperand(1), - ConstantInt::get(cmp->getOperand(1)->getType(), 1, false))); - cmp->setPredicate(ICmpInst::ICMP_EQ); - goto cend; - } - } - cend:; - if (cmp->getPredicate() == ICmpInst::ICMP_NE) { - } - } - } - - // Replace previous increment usage with new increment value - if (Increment) { - Increment->moveAfter(CanonicalIV->getParent()->getFirstNonPHI()); - - if (latches.size() == 1 && isa(latches[0]->getTerminator()) && - cast(latches[0]->getTerminator())->isConditional()) - for (auto use : Increment->users()) { - if (auto cmp = dyn_cast(use)) { - if (cast(latches[0]->getTerminator())->getCondition() != - cmp) - continue; - - // Force i+1 to be on LHS - if (cmp->getOperand(0) != Increment) { - // Below also swaps predicate correctly - cmp->swapOperands(); - } - assert(cmp->getOperand(0) == Increment); - - auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L); - if (cmp->isUnsigned() || - (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv))) { - - // valid replacements (since unsigned comparison and i starts at 0 - // counting up) - - // * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs at - // i+1 == n - if (cmp->getPredicate() == ICmpInst::ICMP_ULT || - cmp->getPredicate() == ICmpInst::ICMP_SLT) { - cmp->setPredicate(ICmpInst::ICMP_NE); - continue; - } - - // * i+1 <= n => i != n, valid since first time i+1 > n occurs at - // i+1 == n+1 => i == n - if (cmp->getPredicate() == ICmpInst::ICMP_ULE || - cmp->getPredicate() == ICmpInst::ICMP_SLE) { - cmp->setOperand(0, CanonicalIV); - cmp->setPredicate(ICmpInst::ICMP_NE); - continue; - } - - // * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs at - // i+1 == n - if (cmp->getPredicate() == ICmpInst::ICMP_UGE || - cmp->getPredicate() == ICmpInst::ICMP_SGE) { - cmp->setPredicate(ICmpInst::ICMP_EQ); - continue; - } - - // * i+1 > n => i == n, valid since first time i+1 > n occurs at i+1 - // == n+1 => i == n - if (cmp->getPredicate() == ICmpInst::ICMP_UGT || - cmp->getPredicate() == ICmpInst::ICMP_SGT) { - cmp->setOperand(0, CanonicalIV); - cmp->setPredicate(ICmpInst::ICMP_EQ); - continue; - } - } - } - } - } -} - -llvm::AllocaInst *CacheUtility::getDynamicLoopLimit(llvm::Loop *L, - bool ReverseLimit) { - assert(L); - assert(loopContexts.find(L) != loopContexts.end()); - auto &found = loopContexts[L]; - assert(found.dynamic); - if (found.trueLimit) - return cast(&*found.trueLimit); - - LimitContext lctx(ReverseLimit, - ReverseLimit ? 
found.preheader : &newFunc->getEntryBlock()); - AllocaInst *LimitVar = - createCacheForScope(lctx, found.var->getType(), "loopLimit", - /*shouldfree*/ true); - - for (auto ExitBlock : found.exitBlocks) { - IRBuilder<> B(&ExitBlock->front()); - auto Limit = B.CreatePHI(found.var->getType(), 1); - - for (BasicBlock *Pred : predecessors(ExitBlock)) { - if (L->contains(Pred)) { - Limit->addIncoming(found.var, Pred); - } else { - Limit->addIncoming(UndefValue::get(found.var->getType()), Pred); - } - } - - storeInstructionInCache(lctx, Limit, LimitVar); - } - found.trueLimit = LimitVar; - return LimitVar; -} - -bool CacheUtility::getContext(BasicBlock *BB, LoopContext &loopContext, - bool ReverseLimit) { - assert(BB->getParent() == newFunc); - Loop *L = LI.getLoopFor(BB); - - // Not inside a loop - if (L == nullptr) - return false; - - // Previously handled this loop - if (auto found = findInMap(loopContexts, L)) { - loopContext = *found; - return true; - } - - // Need to canonicalize - loopContexts[L].parent = L->getParentLoop(); - - loopContexts[L].header = L->getHeader(); - assert(loopContexts[L].header && "loop must have header"); - - loopContexts[L].preheader = L->getLoopPreheader(); - if (!L->getLoopPreheader()) { - llvm::errs() << "fn: " << *L->getHeader()->getParent() << "\n"; - llvm::errs() << "L: " << *L << "\n"; - } - assert(loopContexts[L].preheader && "loop must have preheader"); - getExitBlocks(L, loopContexts[L].exitBlocks); - - loopContexts[L].offset = nullptr; - loopContexts[L].allocLimit = nullptr; - // A precisely matching canonical IV shouldve been run during preprocessing. - auto pair = FindCanonicalIV(L, Type::getInt64Ty(BB->getContext())); - PHINode *CanonicalIV = pair.first; - auto incVar = pair.second; - assert(CanonicalIV); - loopContexts[L].var = CanonicalIV; - loopContexts[L].incvar = incVar; - CanonicalizeLatches(L, loopContexts[L].header, loopContexts[L].preheader, - CanonicalIV, SE, *this, incVar, - getLatches(L, loopContexts[L].exitBlocks)); - loopContexts[L].antivaralloc = - IRBuilder<>(inversionAllocs) - .CreateAlloca(CanonicalIV->getType(), nullptr, - CanonicalIV->getName() + "'ac"); - loopContexts[L].antivaralloc->setAlignment( - Align(cast(CanonicalIV->getType())->getBitWidth() / 8)); - - const SCEV *Limit = nullptr; - const SCEV *MaxIterations = nullptr; - { - const SCEV *MayExitMaxBECount = nullptr; - - SmallVector ExitingBlocks; - L->getExitingBlocks(ExitingBlocks); - - // Remove all exiting blocks that are guaranteed - // to result in unreachable - for (auto &ExitingBlock : ExitingBlocks) { - BasicBlock *Exit = nullptr; - for (auto *SBB : successors(ExitingBlock)) { - if (!L->contains(SBB)) { - if (SE.GuaranteedUnreachable.count(SBB)) - continue; - Exit = SBB; - break; - } - } - if (!Exit) - ExitingBlock = nullptr; - } - ExitingBlocks.erase( - std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr), - ExitingBlocks.end()); - - // Compute the exit in the scenarios where an unreachable - // is not hit - for (BasicBlock *ExitingBlock : ExitingBlocks) { - assert(L->contains(ExitingBlock)); - - ScalarEvolution::ExitLimit EL = - SE.computeExitLimit(L, ExitingBlock, /*AllowPredicates*/ true); - - bool seenHeaders = false; - SmallPtrSet Seen; - std::deque Todo = {ExitingBlock}; - while (Todo.size()) { - auto cur = Todo.front(); - Todo.pop_front(); - if (Seen.count(cur)) - continue; - if (!L->contains(cur)) - continue; - if (cur == loopContexts[L].header) { - seenHeaders = true; - break; - } - for (auto S : successors(cur)) { - Todo.push_back(S); - } - } 
- if (seenHeaders) { - if (MaxIterations == nullptr || - MaxIterations == SE.getCouldNotCompute()) { - MaxIterations = EL.ExactNotTaken; - } - if (MaxIterations != SE.getCouldNotCompute()) { - if (EL.ExactNotTaken != SE.getCouldNotCompute()) { - MaxIterations = - SE.getUMaxFromMismatchedTypes(MaxIterations, EL.ExactNotTaken); - } - } - - if (MayExitMaxBECount == nullptr || - EL.ExactNotTaken == SE.getCouldNotCompute()) - MayExitMaxBECount = EL.ExactNotTaken; - - if (EL.ExactNotTaken != MayExitMaxBECount) { - MayExitMaxBECount = SE.getCouldNotCompute(); - } - } - } - if (MayExitMaxBECount == nullptr) { - MayExitMaxBECount = SE.getCouldNotCompute(); - } - if (MaxIterations == nullptr) { - MaxIterations = SE.getCouldNotCompute(); - } - Limit = MayExitMaxBECount; - } - assert(Limit); - Value *LimitVar = nullptr; - - if (SE.getCouldNotCompute() != Limit) { - - if (CanonicalIV == nullptr) { - report_fatal_error("Couldn't get canonical IV."); - } - - SmallPtrSet PotentialMins; - SmallVector Todo = {Limit}; - while (Todo.size()) { - auto S = Todo.back(); - Todo.pop_back(); - if (auto SA = dyn_cast(S)) { - for (auto op : SA->operands()) - Todo.push_back(op); - } else if (auto SA = dyn_cast(S)) { - for (auto op : SA->operands()) - Todo.push_back(op); - } else if (auto SA = dyn_cast(S)) { - for (auto op : SA->operands()) - Todo.push_back(op); - } else - PotentialMins.insert(S); - } - for (auto op : PotentialMins) { - auto SM = dyn_cast(op); - if (!SM) - continue; - if (SM->getNumOperands() != 2) - continue; - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(SM->getOperand(i))) { - // is minus 1 -#if LLVM_VERSION_MAJOR > 16 - if (C->getAPInt().isAllOnes()) -#else - if (C->getAPInt().isAllOnesValue()) -#endif - { - const SCEV *prev = SM->getOperand(1 - i); - while (true) { - if (auto ext = dyn_cast(prev)) { - prev = ext->getOperand(); - continue; - } - if (auto ext = dyn_cast(prev)) { - prev = ext->getOperand(); - continue; - } - break; - } - if (auto V = dyn_cast(prev)) { - if (auto omp_lb_post = dyn_cast(V->getValue())) { - auto AI = - dyn_cast(omp_lb_post->getPointerOperand()); - if (AI) { - for (auto u : AI->users()) { - CallInst *call = dyn_cast(u); - if (!call) - continue; - Function *F = call->getCalledFunction(); - if (!F) - continue; - if (F->getName() == "__kmpc_for_static_init_4" || - F->getName() == "__kmpc_for_static_init_4u" || - F->getName() == "__kmpc_for_static_init_8" || - F->getName() == "__kmpc_for_static_init_8u") { - Value *lb = nullptr; - for (auto u : call->getArgOperand(4)->users()) { - if (auto si = dyn_cast(u)) { - lb = si->getValueOperand(); - break; - } - } - assert(lb); - Value *ub = nullptr; - for (auto u : call->getArgOperand(5)->users()) { - if (auto si = dyn_cast(u)) { - ub = si->getValueOperand(); - break; - } - } - assert(ub); - IRBuilder<> post(omp_lb_post->getNextNode()); - loopContexts[L].allocLimit = post.CreateZExtOrTrunc( - post.CreateSub(ub, lb), CanonicalIV->getType()); - loopContexts[L].offset = post.CreateZExtOrTrunc( - post.CreateSub(omp_lb_post, lb, "", true, true), - CanonicalIV->getType()); - goto endOMP; - } - } - } - } - } - } - } - } - endOMP:; - - if (Limit->getType() != CanonicalIV->getType()) - Limit = SE.getZeroExtendExpr(Limit, CanonicalIV->getType()); - - SCEVExpander Exp(SE, BB->getParent()->getParent()->getDataLayout(), - "enzyme"); - LimitVar = Exp.expandCodeFor(Limit, CanonicalIV->getType(), - loopContexts[L].preheader->getTerminator()); - loopContexts[L].dynamic = false; - loopContexts[L].maxLimit = LimitVar; - } else { - // TODO if 
assumeDynamicLoopOfSizeOne(L), only lazily allocate the scope - // cache - DebugLoc loc = L->getHeader()->begin()->getDebugLoc(); - for (auto &I : *L->getHeader()) { - if (loc) - break; - loc = I.getDebugLoc(); - } - EmitWarning("NoLimit", loc, L->getHeader(), - "SE could not compute loop limit of ", - L->getHeader()->getName(), " of ", - L->getHeader()->getParent()->getName(), "lim: ", *Limit, - " maxlim: ", *MaxIterations); - - loopContexts[L].dynamic = true; - loopContexts[L].maxLimit = nullptr; - - if (assumeDynamicLoopOfSizeOne(L)) { - LimitVar = nullptr; - } else { - LimitVar = getDynamicLoopLimit(L, ReverseLimit); - } - } - loopContexts[L].trueLimit = LimitVar; - if (EfficientMaxCache && loopContexts[L].dynamic && - SE.getCouldNotCompute() != MaxIterations) { - if (MaxIterations->getType() != CanonicalIV->getType()) - MaxIterations = - SE.getZeroExtendExpr(MaxIterations, CanonicalIV->getType()); - - SCEVExpander Exp(SE, BB->getParent()->getParent()->getDataLayout(), - "enzyme"); - - loopContexts[L].maxLimit = - Exp.expandCodeFor(MaxIterations, CanonicalIV->getType(), - loopContexts[L].preheader->getTerminator()); - } - loopContext = loopContexts.find(L)->second; - return true; -} - -/// Caching mechanism: creates a cache of type T in a scope given by ctx -/// (where if ctx is in a loop there will be a corresponding number of slots) -AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T, - StringRef name, bool shouldFree, - bool allocateInternal, - Value *extraSize) { - assert(ctx.Block); - assert(T); - - auto sublimits = - getSubLimits(/*inForwardPass*/ true, nullptr, ctx, extraSize); - - auto i64 = Type::getInt64Ty(T->getContext()); - - // List of types stored in the cache for each Loop-Chunk - // This is stored from innner-most chunk to outermost - // Thus it begins with the underlying type, and adds pointers - // to the previous type. 
- SmallVector types = {T}; - SmallVector malloctypes; - bool isi1 = T->isIntegerTy() && cast(T)->getBitWidth() == 1; - if (EfficientBoolCache && isi1 && sublimits.size() != 0) - types[0] = Type::getInt8Ty(T->getContext()); - for (size_t i = 0; i < sublimits.size(); ++i) { - Type *allocType; - { - BasicBlock *BB = - BasicBlock::Create(newFunc->getContext(), "entry", newFunc); - IRBuilder<> B(BB); - auto P = B.CreatePHI(i64, 1); - - CallInst *malloccall; - Instruction *Zero; - allocType = cast(CreateAllocation(B, types.back(), P, - "tmpfortypecalc", - &malloccall, &Zero) - ->getType()); - malloctypes.push_back(cast(malloccall->getType())); - for (auto &I : make_early_inc_range(reverse(*BB))) - I.eraseFromParent(); - - BB->eraseFromParent(); - } - types.push_back(allocType); - } - - // Allocate the outermost type on the stack - IRBuilder<> entryBuilder(inversionAllocs); - entryBuilder.setFastMathFlags(getFast()); - AllocaInst *alloc = - entryBuilder.CreateAlloca(types.back(), nullptr, name + "_cache"); - { - ConstantInt *byteSizeOfType = ConstantInt::get( - i64, newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits( - types.back()) / - 8); - unsigned align = - getCacheAlignment((unsigned)byteSizeOfType->getZExtValue()); - alloc->setAlignment(Align(align)); - } - if (sublimits.size() == 0) { - auto val = getUndefinedValueForType(*newFunc->getParent(), types.back()); - if (!isa(val)) - scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc)); - } - - Value *storeInto = alloc; - - // Iterating from outermost chunk to innermost chunk - // Allocate and store the requisite memory if needed - // and lookup the next level pointer of the cache - for (int i = sublimits.size() - 1; i >= 0; i--) { - const auto &containedloops = sublimits[i].second; - - Type *myType = types[i]; - - ConstantInt *byteSizeOfType = ConstantInt::get( - Type::getInt64Ty(T->getContext()), - newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(myType) / - 8); - - unsigned bsize = (unsigned)byteSizeOfType->getZExtValue(); - unsigned alignSize = getCacheAlignment(bsize); - - CallInst *malloccall = nullptr; - - // Allocate and store the required memory - if (allocateInternal) { - - IRBuilder<> allocationBuilder( - &containedloops.back().first.preheader->back()); - - Value *size = sublimits[i].first; - if (EfficientBoolCache && isi1 && i == 0) { - size = allocationBuilder.CreateLShr( - allocationBuilder.CreateAdd( - size, ConstantInt::get(Type::getInt64Ty(T->getContext()), 7), - "", true), - ConstantInt::get(Type::getInt64Ty(T->getContext()), 3)); - } - if (extraSize && i == 0) { - ValueToValueMapTy available; - for (auto &sl : sublimits) { - for (auto &cl : sl.second) { - if (cl.first.var) - available[cl.first.var] = cl.first.var; - } - } - Value *es = unwrapM(extraSize, allocationBuilder, available, - UnwrapMode::AttemptFullUnwrapWithLookup); - assert(es); - size = allocationBuilder.CreateMul(size, es, "", /*NUW*/ true, - /*NSW*/ true); - } - - StoreInst *storealloc = nullptr; - // Statically allocate memory for all iterations if possible - if (sublimits[i].second.back().first.maxLimit) { - Instruction *ZeroInst = nullptr; - Value *firstallocation = CreateAllocation( - allocationBuilder, myType, size, name + "_malloccache", &malloccall, - /*ZeroMem*/ EnzymeZeroCache ? 
&ZeroInst : nullptr); - - scopeInstructions[alloc].push_back(malloccall); - if (firstallocation != malloccall) - scopeInstructions[alloc].push_back( - cast(firstallocation)); - - for (auto &actx : sublimits[i].second) { - if (actx.first.offset) { - malloccall->setMetadata("enzyme_ompfor", - MDNode::get(malloccall->getContext(), {})); - break; - } - } - - if (ZeroInst) { - if (ZeroInst->getOperand(0) != malloccall) { - scopeInstructions[alloc].push_back( - cast(ZeroInst->getOperand(0))); - } - scopeInstructions[alloc].push_back(ZeroInst); - } - storealloc = allocationBuilder.CreateStore(firstallocation, storeInto); - - scopeAllocs[alloc].push_back(malloccall); - - // Mark the store as invariant since the allocation is static and - // will not be changed - if (CachePointerInvariantGroups.find(std::make_pair( - (Value *)alloc, i)) == CachePointerInvariantGroups.end()) { - MDNode *invgroup = MDNode::getDistinct(alloc->getContext(), {}); - CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)] = - invgroup; - } - storealloc->setMetadata( - LLVMContext::MD_invariant_group, - CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)]); - scopeInstructions[alloc].push_back(storealloc); - for (auto post : PostCacheStore(storealloc, allocationBuilder)) { - scopeInstructions[alloc].push_back(post); - } - } else { - llvm::PointerType *allocType = cast(types[i + 1]); - llvm::PointerType *mallocType = malloctypes[i]; - - // Reallocate memory dynamically as a fallback - // TODO change this to a power-of-two allocation strategy - - auto zerostore = allocationBuilder.CreateStore( - getUndefinedValueForType(*newFunc->getParent(), allocType, - /*forceZero*/ true), - storeInto); - scopeInstructions[alloc].push_back(zerostore); - - IRBuilder<> build(containedloops.back().first.incvar->getNextNode()); - Value *allocation = build.CreateLoad(allocType, storeInto); - - if (allocation->getType() != mallocType) { - auto I = - cast(build.CreateBitCast(allocation, mallocType)); - scopeInstructions[alloc].push_back(I); - allocation = I; - } - - CallInst *realloccall = nullptr; - auto reallocation = CreateReAllocation( - build, allocation, myType, containedloops.back().first.incvar, size, - name + "_realloccache", &realloccall, EnzymeZeroCache && i == 0); - - scopeInstructions[alloc].push_back(cast(reallocation)); - - if (reallocation->getType() != allocType) { - auto I = - cast(build.CreateBitCast(reallocation, allocType)); - scopeInstructions[alloc].push_back(I); - reallocation = I; - } - - scopeAllocs[alloc].push_back(realloccall); - - storealloc = build.CreateStore(reallocation, storeInto); - // Unlike the static case we can not mark the memory as invariant - // since we are reloading/storing based off the number of loop - // iterations - scopeInstructions[alloc].push_back(storealloc); - for (auto post : PostCacheStore(storealloc, build)) { - scopeInstructions[alloc].push_back(post); - } - } - - // Regardless of how allocated (dynamic vs static), mark it - // as having the requisite alignment - storealloc->setAlignment(Align(alignSize)); - } - - // Free the memory, if requested - if (shouldFree) { - if (CachePointerInvariantGroups.find(std::make_pair((Value *)alloc, i)) == - CachePointerInvariantGroups.end()) { - MDNode *invgroup = MDNode::getDistinct(alloc->getContext(), {}); - CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)] = - invgroup; - } - auto freecall = freeCache( - containedloops.back().first.preheader, sublimits, i, alloc, - byteSizeOfType, storeInto, - 
CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)]); - if (freecall && malloccall) { - auto ident = MDNode::getDistinct(malloccall->getContext(), {}); - malloccall->setMetadata("enzyme_cache_alloc", - MDNode::get(malloccall->getContext(), {ident})); - freecall->setMetadata("enzyme_cache_free", - MDNode::get(freecall->getContext(), {ident})); - } - } - - // If we are not the final iteration, lookup the next pointer by indexing - // into the relevant location of the current chunk allocation - if (i != 0) { - IRBuilder<> v(&sublimits[i - 1].second.back().first.preheader->back()); - - Value *idx = computeIndexOfChunk( - /*inForwardPass*/ true, v, containedloops, - /*available*/ ValueToValueMapTy()); - - storeInto = v.CreateLoad(types[i + 1], storeInto); - cast(storeInto)->setAlignment(Align(alignSize)); - storeInto = v.CreateGEP(types[i], storeInto, idx); - cast(storeInto)->setIsInBounds(true); - } - } - return alloc; -} - -Value *CacheUtility::computeIndexOfChunk( - bool inForwardPass, IRBuilder<> &v, - ArrayRef> containedloops, - const ValueToValueMapTy &available) { - // List of loop indices in chunk from innermost to outermost - SmallVector indices; - // List of cumulative indices in chunk from innermost to outermost - // where limit[i] = prod(loop limit[0..i]) - SmallVector limits; - - // Iterate from innermost loop to outermost loop within a chunk - for (size_t i = 0; i < containedloops.size(); ++i) { - const auto &pair = containedloops[i]; - - const auto &idx = pair.first; - Value *var = idx.var; - - // In the SingleIteration, var may be null (since there's no legal phinode) - // In that case the current iteration is simply the constnat Zero - if (idx.var == nullptr) - var = ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 0); - else if (available.count(var)) { - var = available.find(var)->second; - } else if (!inForwardPass) { - var = v.CreateLoad(idx.var->getType(), idx.antivaralloc); - } else { - var = idx.var; - } - if (idx.offset) { - var = v.CreateAdd(var, lookupM(idx.offset, v), "", /*NUW*/ true, - /*NSW*/ true); - } - - indices.push_back(var); - Value *lim = pair.second; - assert(lim); - if (limits.size() == 0) { - limits.push_back(lim); - } else { - limits.push_back(v.CreateMul(limits.back(), lim, "", - /*NUW*/ true, /*NSW*/ true)); - } - } - - assert(indices.size() > 0); - - // Compute the index into the pointer - Value *idx = indices[0]; - for (unsigned ind = 1; ind < indices.size(); ++ind) { - idx = v.CreateAdd(idx, - v.CreateMul(indices[ind], limits[ind - 1], "", - /*NUW*/ true, /*NSW*/ true), - "", /*NUW*/ true, /*NSW*/ true); - } - return idx; -} - -/// Given a LimitContext ctx, representing a location inside a loop nest, -/// break each of the loops up into chunks of loops where each chunk's number -/// of iterations can be computed at the chunk preheader. Every dynamic loop -/// defines the start of a chunk. SubLimitType is a vector of chunk objects. -/// More specifically it is a vector of { # iters in a Chunk (sublimit), Chunk } -/// Each chunk object is a vector of loops contained within the chunk. -/// For every loop, this returns pair of the LoopContext and the limit of that -/// loop Both the vector of Chunks and vector of Loops within a Chunk go from -/// innermost loop to outermost loop. 
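-// A scalar model (illustrative only) of the address arithmetic emitted by
-// computeIndexOfChunk above: for a chunk of loops with per-loop limits
-// n[0..k-1] and iteration counts i[0..k-1], both ordered innermost-first, the
-// flat cache index is i[0] + i[1]*n[0] + i[2]*n[0]*n[1] + ... .
-#include <cstddef>
-#include <cstdint>
-
-static std::uint64_t flatChunkIndex(const std::uint64_t *iters,
-                                    const std::uint64_t *limits,
-                                    std::size_t numLoops) {
-  std::uint64_t idx = iters[0];
-  std::uint64_t stride = 1;
-  for (std::size_t k = 1; k < numLoops; ++k) {
-    stride *= limits[k - 1]; // cumulative product of the inner loop limits
-    idx += iters[k] * stride;
-  }
-  return idx;
-}
-// e.g. with limits {4, 5} and iteration counts {2, 3}: 2 + 3 * 4 == 14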
-CacheUtility::SubLimitType CacheUtility::getSubLimits(bool inForwardPass, - IRBuilder<> *RB, - LimitContext ctx, - Value *extraSize) { - // Store the LoopContext's in InnerMost => Outermost order - SmallVector contexts; - - // Given a ``SingleIteration'' Limit Context, return a chunking of - // one loop with size 1, and header/preheader of the BasicBlock - // This is done to create a context for a block outside a loop - // and is part of an experimental mechanism for merging stores - // into a unified memcpy - if (ctx.ForceSingleIteration) { - LoopContext idx; - auto subctx = ctx.Block; - auto zero = ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 0); - // The iteration count is always zero so we can set it as such - idx.var = nullptr; // = zero; - idx.incvar = nullptr; - idx.antivaralloc = nullptr; - idx.trueLimit = zero; - idx.maxLimit = zero; - idx.header = subctx; - idx.preheader = subctx; - idx.dynamic = false; - idx.parent = nullptr; - idx.exitBlocks = {}; - idx.offset = nullptr; - idx.allocLimit = nullptr; - contexts.push_back(idx); - } - - for (BasicBlock *blk = ctx.Block; blk != nullptr;) { - LoopContext idx; - if (!getContext(blk, idx, ctx.ReverseLimit)) { - break; - } - contexts.emplace_back(std::move(idx)); - blk = idx.preheader; - } - - // Legal preheaders for loop i (indexed from inner => outer) - SmallVector allocationPreheaders(contexts.size(), nullptr); - // Limit of loop i (indexed from inner => outer) - SmallVector limits(contexts.size(), nullptr); - - // Iterate from outermost loop to innermost loop - for (int i = contexts.size() - 1; i >= 0; --i) { - // The outermost loop's preheader is the preheader directly - // outside the loop nest - if ((unsigned)i == contexts.size() - 1) { - allocationPreheaders[i] = contexts[i].preheader; - } else if (!contexts[i].maxLimit) { - // For dynamic loops, the preheader is now forced to be the preheader - // of that loop - allocationPreheaders[i] = contexts[i].preheader; - } else { - // Otherwise try to use the preheader of the loop just outside this - // one to allocate all iterations across both loops together - allocationPreheaders[i] = allocationPreheaders[i + 1]; - } - - // Dynamic loops are considered to have a limit of one for allocation - // purposes This is because we want to allocate 1 x (# of iterations inside - // chunk) inside every dynamic iteration - if (!contexts[i].maxLimit) { - limits[i] = - ConstantInt::get(Type::getInt64Ty(ctx.Block->getContext()), 1); - } else { - // Map of previous induction variables we are allowed to use as part - // of the computation of the number of iterations in this chunk - ValueToValueMapTy prevMap; - - // Iterate from outermost loop down - for (int j = contexts.size() - 1;; --j) { - // If the preheader allocating memory for loop i - // is distinct from this preheader, we are therefore allocating - // memory in a different chunk. We can use induction variables - // from chunks outside us to compute loop bounds so add it to the - // map - if (allocationPreheaders[i] != contexts[j].preheader) { - prevMap[contexts[j].var] = contexts[j].var; - } else { - break; - } - } - - IRBuilder<> allocationBuilder(&allocationPreheaders[i]->back()); - Value *limitMinus1 = nullptr; - - Value *limit = contexts[i].maxLimit; - if (contexts[i].allocLimit) - limit = contexts[i].allocLimit; - - // Attempt to compute the limit of this loop at the corresponding - // allocation preheader. 
This is null if it was not legal to compute - limitMinus1 = unwrapM(limit, allocationBuilder, prevMap, - UnwrapMode::AttemptFullUnwrap); - - // We have a loop with static bounds, but whose limit is not available - // to be computed at the current loop preheader (such as the innermost - // loop of triangular iteration domain) Handle this case like a dynamic - // loop and create a new chunk. - if (limitMinus1 == nullptr) { - EmitWarning("NoOuterLimit", *cast(&*limit), - "Could not compute outermost loop limit by moving value ", - *limit, " computed at block", contexts[i].header->getName(), - " function ", contexts[i].header->getParent()->getName()); - allocationPreheaders[i] = contexts[i].preheader; - allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back()); - limitMinus1 = unwrapM(limit, allocationBuilder, prevMap, - UnwrapMode::AttemptFullUnwrap); - if (limitMinus1 == nullptr) { - llvm::errs() << *contexts[i].preheader->getParent() << "\n"; - llvm::errs() << "block: " << *allocationPreheaders[i] << "\n"; - llvm::errs() << "limit: " << *limit << "\n"; - } - assert(limitMinus1 != nullptr); - } else if (i == 0 && extraSize && - unwrapM(extraSize, allocationBuilder, prevMap, - UnwrapMode::AttemptFullUnwrap) == nullptr) { - EmitWarning( - "NoOuterLimit", *cast(extraSize), newFunc, - cast(extraSize)->getParent(), - "Could not compute outermost loop limit by moving extraSize value ", - *extraSize, " computed at block", contexts[i].header->getName(), - " function ", contexts[i].header->getParent()->getName()); - allocationPreheaders[i] = contexts[i].preheader; - allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back()); - } - assert(limitMinus1 != nullptr); - - ValueToValueMapTy reverseMap; - // Iterate from outermost loop down - for (int j = contexts.size() - 1;; --j) { - // If the preheader allocating memory for loop i - // is distinct from this preheader, we are therefore allocating - // memory in a different chunk. We can use induction variables - // from chunks outside us to compute loop bounds so add it to the - // map - if (allocationPreheaders[i] != contexts[j].preheader) { - if (!inForwardPass) { - reverseMap[contexts[j].var] = RB->CreateLoad( - contexts[j].var->getType(), contexts[j].antivaralloc); - } - } else { - break; - } - } - - // We now need to compute the actual limit as opposed to the limit - // minus one. - if (inForwardPass) { - // For efficiency, avoid doing this multiple times for - // the same pair by caching inside - // of LimitCache. 
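-        // A source-level illustration (hedged, not taken from the Enzyme test
-        // suite) of the NoOuterLimit fallback handled just above: a triangular
-        // nest whose inner trip count depends on the outer induction variable
-        // cannot have its limit hoisted to the outer allocation preheader, so
-        // the inner loop starts its own chunk.
-        auto triangular = [](const double *x, double *out, int n) {
-          for (int i = 0; i < n; ++i)   // outer limit n is known before the nest
-            for (int j = 0; j < i; ++j) // inner limit i varies per outer iteration
-              out[i] += x[i] * x[j];    // values a reverse pass would need cached
-        };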
- auto &map = LimitCache[limitMinus1]; - auto found = map.find(allocationPreheaders[i]); - if (found != map.end() && found->second != nullptr) { - limits[i] = found->second; - } else { - limits[i] = map[allocationPreheaders[i]] = - allocationBuilder.CreateNUWAdd( - limitMinus1, ConstantInt::get(limitMinus1->getType(), 1)); - } - } else { - Value *lim = unwrapM(limitMinus1, *RB, reverseMap, - UnwrapMode::AttemptFullUnwrapWithLookup, - allocationPreheaders[i]); - if (!lim) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *limitMinus1 << "\n"; - } - assert(lim); - limits[i] = RB->CreateNUWAdd(lim, ConstantInt::get(lim->getType(), 1)); - } - } - } - - SubLimitType sublimits; - - // Total number of iterations of current chunk of loops - Value *size = nullptr; - // Loops inside current chunk (stored innermost to outermost) - SmallVector, 3> lims; - - // Iterate from innermost to outermost loops - for (unsigned i = 0; i < contexts.size(); ++i) { - IRBuilder<> allocationBuilder(&allocationPreheaders[i]->back()); - lims.push_back(std::make_pair(contexts[i], limits[i])); - // Compute the cumulative size - if (size == nullptr) { - // If starting with no cumulative size, this is the cumulative size - size = limits[i]; - } else if (!inForwardPass) { - size = RB->CreateMul(size, limits[i], "", - /*NUW*/ true, /*NSW*/ true); - } else { - // Otherwise new size = old size * limits[i]; - auto cidx = std::make_tuple(size, limits[i], allocationPreheaders[i]); - if (SizeCache.find(cidx) == SizeCache.end()) { - SizeCache[cidx] = - allocationBuilder.CreateMul(size, limits[i], "", - /*NUW*/ true, /*NSW*/ true); - } - size = SizeCache[cidx]; - } - - // If we are starting a new chunk in the next iteration - // push this chunk to sublimits and clear the cumulative calculations - if ((i + 1 < contexts.size()) && - (allocationPreheaders[i] != allocationPreheaders[i + 1])) { - sublimits.push_back(std::make_pair(size, lims)); - size = nullptr; - lims.clear(); - } - } - - // For any remaining loop chunks, add them to the list - if (size != nullptr) { - sublimits.push_back(std::make_pair(size, lims)); - lims.clear(); - } - - return sublimits; -} - -/// Given an allocation defined at a particular ctx, store the value val -/// in the cache at the location defined in the given builder -void CacheUtility::storeInstructionInCache(LimitContext ctx, - IRBuilder<> &BuilderM, Value *val, - AllocaInst *cache, MDNode *TBAA) { - assert(BuilderM.GetInsertBlock()->getParent() == newFunc); -#ifndef NDEBUG - if (auto inst = dyn_cast(val)) - assert(inst->getParent()->getParent() == newFunc); -#endif - IRBuilder<> v(BuilderM.GetInsertBlock()); - v.SetInsertPoint(BuilderM.GetInsertBlock(), BuilderM.GetInsertPoint()); - v.setFastMathFlags(getFast()); - - // Note for dynamic loops where the allocation is stored somewhere inside - // the loop, we must ensure that we load the allocation after actually - // storing the allocation itself. - // To simplify things and ensure we always store after a - // potential realloc occurs in this loop, we put our store after - // any existing stores in the loop. 
- // This is okay as there should be no load to the cache in the same block - // where this instruction is defined as we will just use this instruction - // TODO check that the store is actually aliasing/related - if (BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { - for (auto I = BuilderM.GetInsertBlock()->rbegin(), - E = BuilderM.GetInsertBlock()->rend(); - I != E; ++I) { - if (&*I == &*BuilderM.GetInsertPoint()) - break; - if (auto si = dyn_cast(&*I)) { - auto ni = getNextNonDebugInstructionOrNull(si); - if (ni != nullptr) { - v.SetInsertPoint(ni); - } else { - v.SetInsertPoint(si->getParent()); - } - } - } - } - - bool isi1 = val->getType()->isIntegerTy() && - cast(val->getType())->getBitWidth() == 1; - Value *loc = getCachePointer(val->getType(), - /*inForwardPass*/ true, v, ctx, cache, - /*storeInInstructionsMap*/ true, - /*available*/ llvm::ValueToValueMapTy(), - /*extraSize*/ nullptr); - - Value *tostore = val; - - // If we are doing the efficient bool cache, the actual value - // we want to store needs to have the existing surrounding bits - // set appropriately - if (EfficientBoolCache && isi1) { - if (auto gep = dyn_cast(loc)) { - auto bo = cast(*gep->idx_begin()); - assert(bo->getOpcode() == BinaryOperator::LShr); - auto subidx = v.CreateAnd( - v.CreateTrunc(bo->getOperand(0), - Type::getInt8Ty(cache->getContext())), - ConstantInt::get(Type::getInt8Ty(cache->getContext()), 7)); - auto mask = v.CreateNot(v.CreateShl( - ConstantInt::get(Type::getInt8Ty(cache->getContext()), 1), subidx)); - - Value *loadChunk = v.CreateLoad(mask->getType(), loc); - auto cleared = v.CreateAnd(loadChunk, mask); - - auto toset = v.CreateShl( - v.CreateZExt(val, Type::getInt8Ty(cache->getContext())), subidx); - tostore = v.CreateOr(cleared, toset); - assert(tostore->getType() == mask->getType()); - } - } - -#if LLVM_VERSION_MAJOR < 17 - if (tostore->getContext().supportsTypedPointers()) { - if (tostore->getType() != loc->getType()->getPointerElementType()) { - llvm::errs() << "val: " << *val << "\n"; - llvm::errs() << "tostore: " << *tostore << "\n"; - llvm::errs() << "loc: " << *loc << "\n"; - } - assert(tostore->getType() == loc->getType()->getPointerElementType()); - } -#endif - - StoreInst *storeinst = v.CreateStore(tostore, loc); - - // If the value stored doesnt change (per efficient bool cache), - // mark it as invariant - if (tostore == val) { - if (ValueInvariantGroups.find(cache) == ValueInvariantGroups.end()) { - MDNode *invgroup = MDNode::getDistinct(cache->getContext(), {}); - ValueInvariantGroups[cache] = invgroup; - } - storeinst->setMetadata(LLVMContext::MD_invariant_group, - ValueInvariantGroups[cache]); - } - - // Set alignment - ConstantInt *byteSizeOfType = - ConstantInt::get(Type::getInt64Ty(cache->getContext()), - ctx.Block->getParent() - ->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(val->getType()) / - 8); - unsigned align = getCacheAlignment((unsigned)byteSizeOfType->getZExtValue()); - storeinst->setMetadata(LLVMContext::MD_tbaa, TBAA); - storeinst->setAlignment(Align(align)); - scopeInstructions[cache].push_back(storeinst); - for (auto post : PostCacheStore(storeinst, v)) { - scopeInstructions[cache].push_back(post); - } -} - -/// Given an allocation defined at a particular ctx, store the instruction -/// in the cache right after the instruction is executed -void CacheUtility::storeInstructionInCache(LimitContext ctx, - llvm::Instruction *inst, - llvm::AllocaInst *cache, - llvm::MDNode *TBAA) { - assert(ctx.Block); - assert(inst); - 
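-  // Scalar model (illustrative only, no extra headers assumed) of the
-  // EfficientBoolCache layout used by the packed-i1 store path above and the
-  // matching load path later in this file: eight booleans share one byte, the
-  // backing allocation is rounded up to (n + 7) >> 3 bytes, and element i
-  // lives in byte (i >> 3) at bit (i & 7).
-  auto packedStoreBit = [](unsigned char *bytes, unsigned long long i, bool v) {
-    unsigned char mask = (unsigned char)(1u << (i & 7));
-    if (v)
-      bytes[i >> 3] |= mask; // set the bit, leaving the other seven intact
-    else
-      bytes[i >> 3] &= (unsigned char)~mask; // clear just this bit
-  };
-  auto packedLoadBit = [](const unsigned char *bytes, unsigned long long i) {
-    return ((bytes[i >> 3] >> (i & 7)) & 1u) != 0;
-  };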
assert(cache); - - // Find the correct place to issue the store - IRBuilder<> v(inst->getParent()); - // If this is a PHINode, we need to store after all phinodes, - // otherwise just after inst sufficies - if (&*inst->getParent()->rbegin() != inst) { - auto pn = dyn_cast(inst); - Instruction *putafter = (pn && pn->getNumIncomingValues() > 0) - ? (inst->getParent()->getFirstNonPHI()) - : getNextNonDebugInstruction(inst); - assert(putafter); - v.SetInsertPoint(putafter); - } - v.setFastMathFlags(getFast()); - storeInstructionInCache(ctx, v, inst, cache, TBAA); -} - -/// Given an allocation specified by the LimitContext ctx and cache, compute a -/// pointer that can hold the underlying type being cached. This value should be -/// computed at BuilderM. Optionally, instructions needed to generate this -/// pointer can be stored in scopeInstructions -Value *CacheUtility::getCachePointer(llvm::Type *T, bool inForwardPass, - IRBuilder<> &BuilderM, LimitContext ctx, - Value *cache, bool storeInInstructionsMap, - const ValueToValueMapTy &available, - Value *extraSize) { - assert(ctx.Block); - assert(cache); - auto sublimits = getSubLimits(inForwardPass, &BuilderM, ctx, extraSize); - - Value *next = cache; - assert(next->getType()->isPointerTy()); - - SmallVector types = {T}; - bool isi1 = T->isIntegerTy() && cast(T)->getBitWidth() == 1; - if (EfficientBoolCache && isi1 && sublimits.size() != 0) - types[0] = Type::getInt8Ty(T->getContext()); - auto i64 = Type::getInt64Ty(T->getContext()); - for (size_t i = 0; i < sublimits.size(); ++i) { - Type *allocType; - { - BasicBlock *BB = - BasicBlock::Create(newFunc->getContext(), "entry", newFunc); - IRBuilder<> B(BB); - auto P = B.CreatePHI(i64, 1); - - CallInst *malloccall; - Instruction *Zero; - allocType = cast(CreateAllocation(B, types.back(), P, - "tmpfortypecalc", - &malloccall, &Zero) - ->getType()); - for (auto &I : make_early_inc_range(reverse(*BB))) - I.eraseFromParent(); - - BB->eraseFromParent(); - } - types.push_back(allocType); - } - - // Iterate from outermost loop to innermost loop - for (int i = sublimits.size() - 1; i >= 0; i--) { - // Lookup the next allocation pointer - next = BuilderM.CreateLoad(types[i + 1], next); - if (storeInInstructionsMap && isa(cache)) - scopeInstructions[cast(cache)].push_back( - cast(next)); - - if (!next->getType()->isPointerTy()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << "cache: " << *cache << "\n"; - llvm::errs() << "next: " << *next << "\n"; - assert(next->getType()->isPointerTy()); - } - - // Set appropriate invairant lookup flags - if (CachePointerInvariantGroups.find(std::make_pair(cache, i)) == - CachePointerInvariantGroups.end()) { - MDNode *invgroup = MDNode::getDistinct(cache->getContext(), {}); - CachePointerInvariantGroups[std::make_pair(cache, i)] = invgroup; - } - cast(next)->setMetadata( - LLVMContext::MD_invariant_group, - CachePointerInvariantGroups[std::make_pair(cache, i)]); - - // Set dereferenceable and alignment flags - ConstantInt *byteSizeOfType = ConstantInt::get( - Type::getInt64Ty(cache->getContext()), - newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits( - next->getType()) / - 8); - cast(next)->setMetadata( - LLVMContext::MD_dereferenceable, - MDNode::get( - cache->getContext(), - ArrayRef(ConstantAsMetadata::get(byteSizeOfType)))); - unsigned align = - getCacheAlignment((unsigned)byteSizeOfType->getZExtValue()); - cast(next)->setAlignment(Align(align)); - - const auto &containedloops = sublimits[i].second; - - if (containedloops.size() > 0) { - Value *idx = 
computeIndexOfChunk(inForwardPass, BuilderM, containedloops, - available); - if (EfficientBoolCache && isi1 && i == 0) - idx = BuilderM.CreateLShr( - idx, ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 3)); - if (i == 0 && extraSize) { - Value *es = lookupM(extraSize, BuilderM); - assert(es); - idx = BuilderM.CreateMul(idx, es, "", /*NUW*/ true, /*NSW*/ true); - } - next = BuilderM.CreateGEP(types[i], next, idx); - cast(next)->setIsInBounds(true); - if (storeInInstructionsMap && isa(cache)) - scopeInstructions[cast(cache)].push_back( - cast(next)); - } - assert(next->getType()->isPointerTy()); - } - return next; -} - -/// Perform the final load from the cache, applying requisite invariant -/// group and alignment -llvm::Value *CacheUtility::loadFromCachePointer(Type *T, - llvm::IRBuilder<> &BuilderM, - llvm::Value *cptr, - llvm::Value *cache) { - // Retrieve the actual result - auto result = BuilderM.CreateLoad(T, cptr); - - // Apply requisite invariant, alignment, etc - if (ValueInvariantGroups.find(cache) == ValueInvariantGroups.end()) { - MDNode *invgroup = MDNode::getDistinct(cache->getContext(), {}); - ValueInvariantGroups[cache] = invgroup; - } - CacheLookups.insert(result); - result->setMetadata(LLVMContext::MD_invariant_group, - ValueInvariantGroups[cache]); - ConstantInt *byteSizeOfType = ConstantInt::get( - Type::getInt64Ty(cache->getContext()), - newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits( - result->getType()) / - 8); - unsigned align = getCacheAlignment((unsigned)byteSizeOfType->getZExtValue()); - result->setAlignment(Align(align)); - - return result; -} - -/// Given an allocation specified by the LimitContext ctx and cache, lookup the -/// underlying cached value. -Value *CacheUtility::lookupValueFromCache( - Type *T, bool inForwardPass, IRBuilder<> &BuilderM, LimitContext ctx, - Value *cache, bool isi1, const ValueToValueMapTy &available, - Value *extraSize, Value *extraOffset) { - // Get the underlying cache pointer - auto cptr = - getCachePointer(T, inForwardPass, BuilderM, ctx, cache, - /*storeInInstructionsMap*/ false, available, extraSize); - - // Optionally apply the additional offset - if (extraOffset) { - cptr = BuilderM.CreateGEP(T, cptr, extraOffset); - cast(cptr)->setIsInBounds(true); - } - - Value *result = loadFromCachePointer(T, BuilderM, cptr, cache); - - // If using the efficient bool cache, do the corresponding - // mask and shift to retrieve the actual value - if (EfficientBoolCache && isi1) { - if (auto gep = dyn_cast(cptr)) { - auto bo = cast(*gep->idx_begin()); - assert(bo->getOpcode() == BinaryOperator::LShr); - Value *res = BuilderM.CreateLShr( - result, - BuilderM.CreateAnd( - BuilderM.CreateTrunc(bo->getOperand(0), - Type::getInt8Ty(cache->getContext())), - ConstantInt::get(Type::getInt8Ty(cache->getContext()), 7))); - return BuilderM.CreateTrunc(res, Type::getInt1Ty(result->getContext())); - } - } - return result; -} diff --git a/enzyme/Enzyme/CacheUtility.h b/enzyme/Enzyme/CacheUtility.h deleted file mode 100644 index 28450c2360f5..000000000000 --- a/enzyme/Enzyme/CacheUtility.h +++ /dev/null @@ -1,433 +0,0 @@ -//===- CacheUtility.h - Caching values in the forward pass for later use -//---===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares a base helper class CacheUtility that manages the cache -// of values from the forward pass for later use. -// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_CACHE_UTILITY_H -#define ENZYME_CACHE_UTILITY_H - -#include -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/IR/Instructions.h" - -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/Dominators.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Transforms/Utils/ValueMapper.h" - -#include "FunctionUtils.h" -#include "MustExitScalarEvolution.h" - -extern "C" { -/// Pack 8 bools together in a single byte -extern llvm::cl::opt EfficientBoolCache; - -extern llvm::cl::opt EnzymeZeroCache; -} - -/// Container for all loop information to synthesize gradients -struct LoopContext { - /// Canonical induction variable of the loop - llvm::AssertingVH var; - - /// Increment of the induction - llvm::AssertingVH incvar; - - /// Allocation of induction variable of reverse pass - llvm::AssertingVH antivaralloc; - - /// Header of this loop - llvm::BasicBlock *header; - - /// Preheader of this loop - llvm::BasicBlock *preheader; - - /// Whether this loop has a statically analyzable number of iterations - bool dynamic; - - /// limit is last value of a canonical induction variable - /// iters is number of times loop is run (thus iters = limit + 1) - AssertingReplacingVH maxLimit; - - AssertingReplacingVH trueLimit; - - /// An offset to add to the index when getting the cache pointer. - AssertingReplacingVH offset; - - /// An overriding allocation limit size. - AssertingReplacingVH allocLimit; - - /// All blocks this loop exits too - llvm::SmallPtrSet exitBlocks; - - /// Parent loop of this loop - llvm::Loop *parent; -}; -static inline bool operator==(const LoopContext &lhs, const LoopContext &rhs) { - return lhs.parent == rhs.parent; -} - -/// Modes of potential unwraps -enum class UnwrapMode { - // It is already known that it is legal to fully unwrap - // this instruction. This means unwrap this instruction, - // its operands, etc. However, this will stop at known - // cached available from a tape. 
- LegalFullUnwrap, - // Unlike LegalFullUnwrap, this will unwrap through a tape - LegalFullUnwrapNoTapeReplace, - // Attempt to fully unwrap this, looking up whenever it - // is not legal to unwrap - AttemptFullUnwrapWithLookup, - // Attempt to fully unwrap this - AttemptFullUnwrap, - // Unwrap the current instruction but not its operand - AttemptSingleUnwrap, -}; - -static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - UnwrapMode mode) { - switch (mode) { - case UnwrapMode::LegalFullUnwrap: - os << "LegalFullUnwrap"; - break; - case UnwrapMode::LegalFullUnwrapNoTapeReplace: - os << "LegalFullUnwrapNoTapeReplace"; - break; - case UnwrapMode::AttemptFullUnwrapWithLookup: - os << "AttemptFullUnwrapWithLookup"; - break; - case UnwrapMode::AttemptFullUnwrap: - os << "AttemptFullUnwrap"; - break; - case UnwrapMode::AttemptSingleUnwrap: - os << "AttemptSingleUnwrap"; - break; - } - return os; -} - -class CacheUtility { -public: - /// The function whose instructions we are caching - llvm::Function *const newFunc; - - /// Various analysis results of newFunc - llvm::TargetLibraryInfo &TLI; - llvm::DominatorTree DT; - -protected: - llvm::LoopInfo LI; - llvm::AssumptionCache AC; - MustExitScalarEvolution SE; - -public: - // Helper basicblock where all new allocations will be added to - // This includes allocations for cache variables - llvm::BasicBlock *inversionAllocs; - -protected: - CacheUtility(llvm::TargetLibraryInfo &TLI, llvm::Function *newFunc) - : newFunc(newFunc), TLI(TLI), DT(*newFunc), LI(DT), AC(*newFunc), - SE(*newFunc, TLI, AC, DT, LI) { - inversionAllocs = llvm::BasicBlock::Create(newFunc->getContext(), - "allocsForInversion", newFunc); - } - -public: - virtual ~CacheUtility(); - -protected: - /// Map of Loop to requisite loop information needed for AD (forward/reverse - /// induction/etc) - std::map loopContexts; - -public: - /// Given a BasicBlock BB in newFunc, set loopContext to the relevant - /// contained loop and return true. 
If BB is not in a loop, return false - bool getContext(llvm::BasicBlock *BB, LoopContext &loopContext, - bool ReverseLimit); - /// Return whether the given instruction is used as necessary as part of a - /// loop context This includes as the canonical induction variable or - /// increment - bool isInstructionUsedInLoopInduction(llvm::Instruction &I) { - for (auto &context : loopContexts) { - if (context.second.var == &I || context.second.incvar == &I || - context.second.maxLimit == &I || context.second.trueLimit == &I) { - return true; - } - } - return false; - } - - llvm::AllocaInst *getDynamicLoopLimit(llvm::Loop *L, - bool ReverseLimit = true); - - /// Print out all currently cached values - void dumpScope() { - llvm::errs() << "scope:\n"; - for (auto a : scopeMap) { - llvm::errs() << " scopeMap[" << *a.first << "] = " << *a.second.first - << " ctx:" << a.second.second.Block->getName() << "\n"; - } - llvm::errs() << "end scope\n"; - } - - unsigned getCacheAlignment(unsigned bsize) const { - if ((bsize & (bsize - 1)) == 0) { - if (bsize > 8) - return 8; - else - return bsize; - } else if (bsize > 0 && bsize % 8 == 0) { - return 8; - } else if (bsize > 0 && bsize % 4 == 0) { - return 4; - } else if (bsize > 0 && bsize % 2 == 0) { - return 2; - } else - return 1; - } - - /// Erase this instruction both from LLVM modules and any local - /// data-structures - virtual void erase(llvm::Instruction *I); - /// Replace this instruction both in LLVM modules and any local - /// data-structures - virtual void replaceAWithB(llvm::Value *A, llvm::Value *B, - bool storeInCache = false); - - // Context information to request calculation of loop limit information - struct LimitContext { - // Whether the limit needs to be accessible for a reverse pass - bool ReverseLimit; - - // A block inside of the loop, defining the location - llvm::BasicBlock *Block; - // Instead of getting the actual limits, return a limit of one - bool ForceSingleIteration; - - LimitContext(bool ReverseLimit, llvm::BasicBlock *Block, - bool ForceSingleIteration = false) - : ReverseLimit(ReverseLimit), Block(Block), - ForceSingleIteration(ForceSingleIteration) {} - }; - - /// Given a LimitContext ctx, representing a location inside a loop nest, - /// break each of the loops up into chunks of loops where each chunk's number - /// of iterations can be computed at the chunk preheader. Every dynamic loop - /// defines the start of a chunk. SubLimitType is a vector of chunk objects. - /// More specifically it is a vector of { # iters in a Chunk (sublimit), Chunk - /// } Each chunk object is a vector of loops contained within the chunk. For - /// every loop, this returns pair of the LoopContext and the limit of that - /// loop Both the vector of Chunks and vector of Loops within a Chunk go from - /// innermost loop to outermost loop. - typedef llvm::SmallVector, 4>>, - 0> - SubLimitType; - SubLimitType getSubLimits(bool inForwardPass, llvm::IRBuilder<> *RB, - LimitContext ctx, llvm::Value *extraSize = nullptr); - -private: - /// Internal data structure used by getSubLimit to avoid computing the same - /// loop limit multiple times if possible. Map's a desired limitMinus1 (see - /// getSubLimits) and the block the true limit requested to the value of the - /// limit accessible at that block - llvm::ValueMap> - LimitCache; - /// Internal data structure used by getSubLimit to avoid computing the - /// cumulative loop limit multiple times if possible. 
Map's a desired pair of - /// operands to be multiplied (see getSubLimits) and the block the cumulative - /// limit requested to the value of the limit accessible at that block This - /// cache is also shared with computeIndexOfChunk - std::map, - llvm::Value *> - SizeCache; - - /// Given a loop context, compute the corresponding index into said loop at - /// the IRBuilder<> - llvm::Value *computeIndexOfChunk( - bool inForwardPass, llvm::IRBuilder<> &v, - llvm::ArrayRef> containedloops, - const llvm::ValueToValueMapTy &available); - -private: - /// Given a cache allocation and an index denoting how many Chunks deep the - /// allocation is being indexed into, return the invariant metadata describing - /// used to describe loads/stores to the indexed pointer - /// Note that the cache allocation should either be an allocainst (if in - /// fwd/both) or an extraction from the tape - std::map, llvm::MDNode *> - CachePointerInvariantGroups; - /// Given a value being cached, return the invariant metadata of any - /// loads/stores to memory storing that value - std::map ValueInvariantGroups; - -protected: - /// A map of values being cached to their underlying allocation/limit context - std::map, LimitContext>> - scopeMap; - - /// A map of allocations to a vector of instruction used to create by the - /// allocation Keeping track of these values is useful for deallocation. This - /// is stored as a vector explicitly to order theses instructions in such a - /// way that they can be erased by iterating in reverse order. - std::map, 4>> - scopeInstructions; - - /// A map of allocations to a set of instructions which free memory as part of - /// the cache. - std::map>> - scopeFrees; - - /// A map of allocations to a set of instructions which allocate memory as - /// part of the cache - std::map, 4>> - scopeAllocs; - - /// Perform the final load from the cache, applying requisite invariant - /// group and alignment - llvm::Value *loadFromCachePointer(llvm::Type *T, llvm::IRBuilder<> &BuilderM, - llvm::Value *cptr, llvm::Value *cache); - -public: - /// Create a cache of Type T at the given LimitContext. If allocateInternal is - /// set this will allocate the requesite memory. If extraSize is set, - /// allocations will be a factor of extraSize larger - llvm::AllocaInst *createCacheForScope(LimitContext ctx, llvm::Type *T, - llvm::StringRef name, bool shouldFree, - bool allocateInternal = true, - llvm::Value *extraSize = nullptr); - - /// High-level utility to "unwrap" an instruction at a new location specified - /// by BuilderM. Depending on the mode, it will either just unwrap this - /// instruction, all of its instructions operands, and optionally lookup - /// values when it is not legal to unwrap. If a value cannot be unwrap'd at a - /// given location, this will null. This high-level utility should be - /// implemented based off the low-level caching infrastructure provided in - /// this class. - virtual llvm::Value * - unwrapM(llvm::Value *const val, llvm::IRBuilder<> &BuilderM, - const llvm::ValueToValueMapTy &available, UnwrapMode mode, - llvm::BasicBlock *scope = nullptr, bool permitCache = true) = 0; - - /// High-level utility to get the value an instruction at a new location - /// specified by BuilderM. Unlike unwrap, this function can never fail -- - /// falling back to creating a cache if necessary. This function is - /// prepopulated with a set of values that are already known to be available - /// and may contain optimizations for getting the value in more efficient ways - /// (e.g. 
unwrap'ing when legal, looking up equivalent values, etc). This - /// high-level utility should be implemented based off the low-level caching - /// infrastructure provided in this class. - virtual llvm::Value * - lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, - const llvm::ValueToValueMapTy &incoming_availalble = - llvm::ValueToValueMapTy(), - bool tryLegalityCheck = true, llvm::BasicBlock *scope = nullptr) = 0; - - virtual bool assumeDynamicLoopOfSizeOne(llvm::Loop *L) const = 0; - - /// If an allocation is requested to be freed, this subclass will be called to - /// chose how and where to free it. It is by default not implemented, falling - /// back to an error. Subclasses who want to free memory should implement this - /// function. - virtual llvm::CallInst * - freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &antimap, - int i, llvm::AllocaInst *alloc, llvm::ConstantInt *byteSizeOfType, - llvm::Value *storeInto, llvm::MDNode *InvariantMD) { - assert(0 && "freeing cache not handled in this scenario"); - llvm_unreachable("freeing cache not handled in this scenario"); - } - - /// Given an allocation defined at a particular ctx, store the value val - /// in the cache at the location defined in the given builder - void storeInstructionInCache(LimitContext ctx, llvm::IRBuilder<> &BuilderM, - llvm::Value *val, llvm::AllocaInst *cache, - llvm::MDNode *TBAA = nullptr); - - /// Given an allocation defined at a particular ctx, store the instruction - /// in the cache right after the instruction is executed - void storeInstructionInCache(LimitContext ctx, llvm::Instruction *inst, - llvm::AllocaInst *cache, - llvm::MDNode *TBAA = nullptr); - - /// Given an allocation specified by the LimitContext ctx and cache, compute a - /// pointer that can hold the underlying type being cached. This value should - /// be computed at BuilderM. Optionally, instructions needed to generate this - /// pointer can be stored in scopeInstructions - llvm::Value *getCachePointer(llvm::Type *T, bool inForwardPass, - llvm::IRBuilder<> &BuilderM, LimitContext ctx, - llvm::Value *cache, bool storeInInstructionsMap, - const llvm::ValueToValueMapTy &available, - llvm::Value *extraSize); - - /// Given an allocation specified by the LimitContext ctx and cache, lookup - /// the underlying cached value. 
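-  // A hedged usage sketch (not actual Enzyme code) tying createCacheForScope
-  // and storeInstructionInCache above together with lookupValueFromCache
-  // declared just below: cache a value while emitting the forward pass, then
-  // replay it while emitting the reverse pass. The subclass instance `utils`,
-  // the cached instruction `inst`, and the reverse-pass builder `RevB` are
-  // assumed to be provided by the caller.
-  //
-  //   CacheUtility::LimitContext ctx(/*ReverseLimit*/ true, inst->getParent());
-  //   llvm::AllocaInst *cache = utils.createCacheForScope(
-  //       ctx, inst->getType(), inst->getName(), /*shouldFree*/ true);
-  //   utils.storeInstructionInCache(ctx, inst, cache);          // forward pass
-  //
-  //   llvm::ValueToValueMapTy available;
-  //   llvm::Value *replayed = utils.lookupValueFromCache(
-  //       inst->getType(), /*inForwardPass*/ false, RevB, ctx, cache,
-  //       /*isi1*/ inst->getType()->isIntegerTy(1), available); // reverse pass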
- llvm::Value *lookupValueFromCache(llvm::Type *T, bool inForwardPass, - llvm::IRBuilder<> &BuilderM, - LimitContext ctx, llvm::Value *cache, - bool isi1, - const llvm::ValueToValueMapTy &available, - llvm::Value *extraSize = nullptr, - llvm::Value *extraOffset = nullptr); - -protected: - // List of values loaded from the cache - llvm::SmallPtrSet CacheLookups; -}; - -// Create a new canonical induction variable of Type Ty for Loop L -// Return the variable and the increment instruction -std::pair -InsertNewCanonicalIV(llvm::Loop *L, llvm::Type *Ty, - const llvm::Twine &Name = "iv"); - -// Attempt to rewrite all phinode's in the loop in terms of the -// induction variable -void RemoveRedundantIVs( - llvm::BasicBlock *Header, llvm::PHINode *CanonicalIV, - llvm::Instruction *Increment, MustExitScalarEvolution &SE, - llvm::function_ref replacer, - llvm::function_ref eraser); -#endif diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp deleted file mode 100644 index 61933ca960b5..000000000000 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ /dev/null @@ -1,4245 +0,0 @@ -//===- CallDerivatives.cpp - Implementation of known call derivatives --===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation of functions in instruction visitor -// AdjointGenerator that generate corresponding augmented forward pass code, -// and adjoints for certain known functions. 
-// -//===----------------------------------------------------------------------===// - -#include "AdjointGenerator.h" - -using namespace llvm; - -extern "C" { -void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t, - LLVMValueRef, uint8_t) = nullptr; -} - -void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, - llvm::StringRef funcName) { - using namespace llvm; - - assert(called); - assert(gutils->getWidth() == 1); - - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call)); - BuilderZ.setFastMathFlags(getFast()); - - // MPI send / recv can only send float/integers - if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" || - funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") { - if (!gutils->isConstantInstruction(&call)) { - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - assert(!gutils->isConstantValue(call.getOperand(0))); - assert(!gutils->isConstantValue(call.getOperand(6))); - Value *d_req = gutils->invertPointerM(call.getOperand(6), BuilderZ); - if (d_req->getType()->isIntegerTy()) { - d_req = BuilderZ.CreateIntToPtr( - d_req, PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - auto i64 = Type::getInt64Ty(call.getContext()); - auto impi = getMPIHelper(call.getContext()); - - Value *impialloc = - CreateAllocation(BuilderZ, impi, ConstantInt::get(i64, 1)); - BuilderZ.SetInsertPoint(gutils->getNewFromOriginal(&call)); - - d_req = BuilderZ.CreateBitCast( - d_req, PointerType::getUnqual(impialloc->getType())); - Value *d_req_prev = BuilderZ.CreateLoad(impialloc->getType(), d_req); - BuilderZ.CreateStore( - BuilderZ.CreatePointerCast(d_req_prev, - getInt8PtrTy(call.getContext())), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - BuilderZ.CreateStore(impialloc, d_req); - - if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") { - Value *tysize = - MPI_TYPE_SIZE(gutils->getNewFromOriginal(call.getOperand(2)), - BuilderZ, call.getType()); - - auto len_arg = BuilderZ.CreateZExtOrTrunc( - gutils->getNewFromOriginal(call.getOperand(1)), - Type::getInt64Ty(call.getContext())); - len_arg = BuilderZ.CreateMul( - len_arg, - BuilderZ.CreateZExtOrTrunc(tysize, - Type::getInt64Ty(call.getContext())), - "", true, true); - - Value *firstallocation = - CreateAllocation(BuilderZ, Type::getInt8Ty(call.getContext()), - len_arg, "mpirecv_malloccache"); - BuilderZ.CreateStore(firstallocation, getMPIMemberPtr( - BuilderZ, impialloc, impi)); - BuilderZ.SetInsertPoint(gutils->getNewFromOriginal(&call)); - } else { - Value *ibuf = gutils->invertPointerM(call.getOperand(0), BuilderZ); - if (ibuf->getType()->isIntegerTy()) - ibuf = - BuilderZ.CreateIntToPtr(ibuf, getInt8PtrTy(call.getContext())); - BuilderZ.CreateStore( - ibuf, getMPIMemberPtr(BuilderZ, impialloc, impi)); - } - - BuilderZ.CreateStore( - BuilderZ.CreateZExtOrTrunc( - gutils->getNewFromOriginal(call.getOperand(1)), i64), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - - Value *dataType = gutils->getNewFromOriginal(call.getOperand(2)); - if (dataType->getType()->isIntegerTy()) - dataType = BuilderZ.CreateIntToPtr( - dataType, getInt8PtrTy(dataType->getContext())); - BuilderZ.CreateStore( - BuilderZ.CreatePointerCast(dataType, - getInt8PtrTy(call.getContext())), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - - BuilderZ.CreateStore( - BuilderZ.CreateZExtOrTrunc( - gutils->getNewFromOriginal(call.getOperand(3)), i64), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - - BuilderZ.CreateStore( - BuilderZ.CreateZExtOrTrunc( - 
gutils->getNewFromOriginal(call.getOperand(4)), i64), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - - Value *comm = gutils->getNewFromOriginal(call.getOperand(5)); - if (comm->getType()->isIntegerTy()) - comm = BuilderZ.CreateIntToPtr(comm, - getInt8PtrTy(dataType->getContext())); - BuilderZ.CreateStore( - BuilderZ.CreatePointerCast(comm, getInt8PtrTy(call.getContext())), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - - BuilderZ.CreateStore( - ConstantInt::get( - Type::getInt8Ty(impialloc->getContext()), - (funcName == "MPI_Isend" || funcName == "PMPI_Isend") - ? (int)MPI_CallType::ISEND - : (int)MPI_CallType::IRECV), - getMPIMemberPtr(BuilderZ, impialloc, impi)); - // TODO old - } - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - Type *statusType = nullptr; -#if LLVM_VERSION_MAJOR < 17 - if (Function *recvfn = called->getParent()->getFunction("PMPI_Wait")) { - auto statusArg = recvfn->arg_end(); - statusArg--; - if (auto PT = dyn_cast(statusArg->getType())) - statusType = PT->getPointerElementType(); - } - if (Function *recvfn = called->getParent()->getFunction("MPI_Wait")) { - auto statusArg = recvfn->arg_end(); - statusArg--; - if (auto PT = dyn_cast(statusArg->getType())) - statusType = PT->getPointerElementType(); - } -#endif - if (statusType == nullptr) { - statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24); - llvm::errs() << " warning could not automatically determine mpi " - "status type, assuming [24 x i8]\n"; - } - Value *req = - lookup(gutils->getNewFromOriginal(call.getOperand(6)), Builder2); - Value *d_req = lookup( - gutils->invertPointerM(call.getOperand(6), Builder2), Builder2); - if (d_req->getType()->isIntegerTy()) { - d_req = - Builder2.CreateIntToPtr(d_req, getInt8PtrTy(call.getContext())); - } - auto impi = getMPIHelper(call.getContext()); - Type *helperTy = llvm::PointerType::getUnqual(impi); - Value *helper = - Builder2.CreatePointerCast(d_req, PointerType::getUnqual(helperTy)); - helper = Builder2.CreateLoad(helperTy, helper); - - auto i64 = Type::getInt64Ty(call.getContext()); - - Value *firstallocation; - firstallocation = Builder2.CreateLoad( - getInt8PtrTy(call.getContext()), - getMPIMemberPtr(Builder2, helper, impi)); - Value *len_arg = nullptr; - if (auto C = dyn_cast( - gutils->getNewFromOriginal(call.getOperand(1)))) { - len_arg = Builder2.CreateZExtOrTrunc(C, i64); - } else { - len_arg = Builder2.CreateLoad( - i64, getMPIMemberPtr(Builder2, helper, impi)); - } - Value *tysize = nullptr; - if (auto C = dyn_cast( - gutils->getNewFromOriginal(call.getOperand(2)))) { - tysize = C; - } else { - tysize = Builder2.CreateLoad( - getInt8PtrTy(call.getContext()), - getMPIMemberPtr(Builder2, helper, impi)); - } - - Value *prev; - prev = Builder2.CreateLoad( - getInt8PtrTy(call.getContext()), - getMPIMemberPtr(Builder2, helper, impi)); - - Builder2.CreateStore( - prev, Builder2.CreatePointerCast( - d_req, PointerType::getUnqual(prev->getType()))); - - assert(shouldFree()); - - assert(tysize); - tysize = MPI_TYPE_SIZE(tysize, Builder2, call.getType()); - - Value *args[] = {/*req*/ req, - /*status*/ IRBuilder<>(gutils->inversionAllocs) - .CreateAlloca(statusType)}; - FunctionCallee waitFunc = nullptr; - for (auto name : {"PMPI_Wait", "MPI_Wait"}) - if (Function *recvfn = called->getParent()->getFunction(name)) { - auto statusArg = recvfn->arg_end(); - statusArg--; - if (statusArg->getType()->isIntegerTy()) - args[1] = 
Builder2.CreatePtrToInt(args[1], statusArg->getType()); - else - args[1] = Builder2.CreateBitCast(args[1], statusArg->getType()); - waitFunc = recvfn; - break; - } - if (!waitFunc) { - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - FunctionType *FT = FunctionType::get(call.getType(), types, false); - waitFunc = called->getParent()->getOrInsertFunction("MPI_Wait", FT); - } - assert(waitFunc); - - // Need to preserve the shadow Request (operand 6 in isend/irecv), - // which becomes operand 0 for iwait. - auto ReqDefs = gutils->getInvertedBundles( - &call, - {ValueType::None, ValueType::None, ValueType::None, ValueType::None, - ValueType::None, ValueType::None, ValueType::Shadow}, - Builder2, /*lookup*/ true); - - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::None, ValueType::None, - ValueType::None, ValueType::None, ValueType::None, - ValueType::None}, - Builder2, /*lookup*/ true); - - auto fcall = Builder2.CreateCall(waitFunc, args, ReqDefs); - fcall->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - if (auto F = dyn_cast(waitFunc.getCallee())) - fcall->setCallingConv(F->getCallingConv()); - len_arg = Builder2.CreateMul( - len_arg, - Builder2.CreateZExtOrTrunc(tysize, - Type::getInt64Ty(Builder2.getContext())), - "", true, true); - if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") { - auto val_arg = - ConstantInt::get(Type::getInt8Ty(Builder2.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(Builder2.getContext()); - assert(!gutils->isConstantValue(call.getOperand(0))); - auto dbuf = firstallocation; - Value *nargs[] = {dbuf, val_arg, len_arg, volatile_arg}; - Type *tys[] = {dbuf->getType(), len_arg->getType()}; - - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(called->getParent(), Intrinsic::memset, - tys), - nargs, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - } else if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") { - assert(!gutils->isConstantValue(call.getOperand(0))); - Value *shadow = lookup( - gutils->invertPointerM(call.getOperand(0), Builder2), Builder2); - - // TODO add operand bundle (unless force inlined?) 
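-          // Hedged illustration (assuming <mpi.h> and <vector>; not the code
-          // Enzyme emits, which stays nonblocking and splits the work across
-          // the adjoints of Isend/Irecv/Wait) of the adjoint rules that the
-          // reverse-mode handling around this point implements:
-          //   primal MPI_Isend(buf)  ->  adjoint receives into a temporary,
-          //                              does shadow(buf) += tmp, frees tmp;
-          //   primal MPI_Irecv(buf)  ->  adjoint sends shadow(buf) back to the
-          //                              sender, then zeros shadow(buf).
-          auto adjointOfSend = [](double *shadow_buf, int count, int peer,
-                                  int tag, MPI_Comm comm) {
-            std::vector<double> tmp(count);
-            MPI_Recv(tmp.data(), count, MPI_DOUBLE, peer, tag, comm,
-                     MPI_STATUS_IGNORE);
-            for (int i = 0; i < count; ++i)
-              shadow_buf[i] += tmp[i]; // accumulate the incoming derivative
-          };
-          auto adjointOfRecv = [](double *shadow_buf, int count, int peer,
-                                  int tag, MPI_Comm comm) {
-            MPI_Send(shadow_buf, count, MPI_DOUBLE, peer, tag, comm);
-            for (int i = 0; i < count; ++i)
-              shadow_buf[i] = 0.0; // derivative has been handed to the sender
-          };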
- DifferentiableMemCopyFloats(call, call.getOperand(0), firstallocation, - shadow, len_arg, Builder2, BufferDefs); - - if (shouldFree()) { - CreateDealloc(Builder2, firstallocation); - } - } else - assert(0 && "illegal mpi"); - - CreateDealloc(Builder2, helper); - } - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - - assert(!gutils->isConstantValue(call.getOperand(0))); - assert(!gutils->isConstantValue(call.getOperand(6))); - - Value *buf = gutils->invertPointerM(call.getOperand(0), Builder2); - Value *count = gutils->getNewFromOriginal(call.getOperand(1)); - Value *datatype = gutils->getNewFromOriginal(call.getOperand(2)); - Value *source = gutils->getNewFromOriginal(call.getOperand(3)); - Value *tag = gutils->getNewFromOriginal(call.getOperand(4)); - Value *comm = gutils->getNewFromOriginal(call.getOperand(5)); - Value *request = gutils->invertPointerM(call.getOperand(6), Builder2); - - Value *args[] = { - /*buf*/ buf, - /*count*/ count, - /*datatype*/ datatype, - /*source*/ source, - /*tag*/ tag, - /*comm*/ comm, - /*request*/ request, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal, ValueType::Primal, - ValueType::Shadow}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") { - Value *d_reqp = nullptr; - auto impi = getMPIHelper(call.getContext()); - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - Value *req = gutils->getNewFromOriginal(call.getOperand(0)); - Value *d_req = gutils->invertPointerM(call.getOperand(0), BuilderZ); - - if (req->getType()->isIntegerTy()) { - req = BuilderZ.CreateIntToPtr( - req, PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - Value *isNull = nullptr; - if (auto GV = gutils->newFunc->getParent()->getNamedValue( - "ompi_request_null")) { - Value *reql = BuilderZ.CreatePointerCast( - req, PointerType::getUnqual(GV->getType())); - reql = BuilderZ.CreateLoad(GV->getType(), reql); - isNull = BuilderZ.CreateICmpEQ(reql, GV); - } - - if (d_req->getType()->isIntegerTy()) { - d_req = BuilderZ.CreateIntToPtr( - d_req, PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - d_reqp = BuilderZ.CreateLoad( - PointerType::getUnqual(impi), - BuilderZ.CreatePointerCast( - d_req, PointerType::getUnqual(PointerType::getUnqual(impi)))); - if (isNull) - d_reqp = - CreateSelect(BuilderZ, isNull, - Constant::getNullValue(d_reqp->getType()), d_reqp); - if (auto I = dyn_cast(d_reqp)) - gutils->TapesToPreventRecomputation.insert(I); - d_reqp = gutils->cacheForReverse( - BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ)); - } - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - assert(!gutils->isConstantValue(call.getOperand(0))); - Value *req = - lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2); - - if (Mode != DerivativeMode::ReverseModeCombined) { - d_reqp = BuilderZ.CreatePHI(PointerType::getUnqual(impi), 0); - d_reqp = gutils->cacheForReverse( - BuilderZ, d_reqp, 
getIndex(&call, CacheType::Tape, BuilderZ)); - } else - assert(d_reqp); - d_reqp = lookup(d_reqp, Builder2); - - Value *isNull = Builder2.CreateICmpEQ( - d_reqp, Constant::getNullValue(d_reqp->getType())); - - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *nonnullBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_nonnull"); - BasicBlock *endBlock = gutils->addReverseBlock( - nonnullBlock, currentBlock->getName() + "_end", - /*fork*/ true, /*push*/ false); - - Builder2.CreateCondBr(isNull, endBlock, nonnullBlock); - Builder2.SetInsertPoint(nonnullBlock); - - Value *cache = Builder2.CreateLoad(impi, d_reqp); - - Value *args[] = { - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - req}; - Type *types[sizeof(args) / sizeof(*args) - 1]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args) - 1; i++) - types[i] = args[i]->getType(); - Function *dwait = getOrInsertDifferentialMPI_Wait( - *called->getParent(), types, call.getOperand(0)->getType()); - - // Need to preserve the shadow Request (operand 0 in wait). - // However, this doesn't end up preserving - // the underlying buffers for the adjoint. To rememdy, force inline. - auto cal = - Builder2.CreateCall(dwait, args, - gutils->getInvertedBundles( - &call, {ValueType::Shadow, ValueType::None}, - Builder2, /*lookup*/ true)); - cal->setCallingConv(dwait->getCallingConv()); - cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - cal->addFnAttr(Attribute::AlwaysInline); - Builder2.CreateBr(endBlock); - { - auto found = gutils->reverseBlockToPrimal.find(endBlock); - assert(found != gutils->reverseBlockToPrimal.end()); - SmallVector &vec = - gutils->reverseBlocks[found->second]; - assert(vec.size()); - vec.push_back(endBlock); - } - Builder2.SetInsertPoint(endBlock); - } else if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - - assert(!gutils->isConstantValue(call.getOperand(0))); - - Value *request = gutils->invertPointerM(call.getArgOperand(0), Builder2); - Value *status = gutils->invertPointerM(call.getArgOperand(1), Builder2); - - if (request->getType()->isIntegerTy()) { - request = Builder2.CreateIntToPtr( - request, PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - Value *args[] = {/*request*/ request, - /*status*/ status}; - - auto Defs = gutils->getInvertedBundles( - &call, {ValueType::Shadow, ValueType::Shadow}, Builder2, - /*lookup*/ false); - - auto callval = call.getCalledOperand(); - - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - if (funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") { - Value *d_reqp = nullptr; - auto impi = getMPIHelper(call.getContext()); - PointerType *reqType = PointerType::getUnqual(impi); - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - Value *count = gutils->getNewFromOriginal(call.getOperand(0)); - Value *req = gutils->getNewFromOriginal(call.getOperand(1)); - Value *d_req = gutils->invertPointerM(call.getOperand(1), BuilderZ); - - if (req->getType()->isIntegerTy()) { - req = 
BuilderZ.CreateIntToPtr( - req, PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - if (d_req->getType()->isIntegerTy()) { - d_req = BuilderZ.CreateIntToPtr( - d_req, PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - Function *dsave = getOrInsertDifferentialWaitallSave( - *gutils->oldFunc->getParent(), - {count->getType(), req->getType(), d_req->getType()}, reqType); - - d_reqp = BuilderZ.CreateCall(dsave, {count, req, d_req}); - cast(d_reqp)->setCallingConv(dsave->getCallingConv()); - cast(d_reqp)->setDebugLoc( - gutils->getNewFromOriginal(call.getDebugLoc())); - d_reqp = gutils->cacheForReverse( - BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ)); - } - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - assert(!gutils->isConstantValue(call.getOperand(1))); - Value *count = - lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2); - Value *req_orig = - lookup(gutils->getNewFromOriginal(call.getOperand(1)), Builder2); - - if (Mode != DerivativeMode::ReverseModeCombined) { - d_reqp = BuilderZ.CreatePHI(PointerType::getUnqual(reqType), 0); - d_reqp = gutils->cacheForReverse( - BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ)); - } - - d_reqp = lookup(d_reqp, Builder2); - - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *loopBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_loop"); - BasicBlock *nonnullBlock = gutils->addReverseBlock( - loopBlock, currentBlock->getName() + "_nonnull"); - BasicBlock *eloopBlock = gutils->addReverseBlock( - nonnullBlock, currentBlock->getName() + "_eloop"); - BasicBlock *endBlock = - gutils->addReverseBlock(eloopBlock, currentBlock->getName() + "_end", - /*fork*/ true, /*push*/ false); - - Builder2.CreateCondBr( - Builder2.CreateICmpNE(count, - ConstantInt::get(count->getType(), 0, false)), - loopBlock, endBlock); - - Builder2.SetInsertPoint(loopBlock); - auto idx = Builder2.CreatePHI(count->getType(), 2); - idx->addIncoming(ConstantInt::get(count->getType(), 0, false), - currentBlock); - Value *inc = Builder2.CreateAdd( - idx, ConstantInt::get(count->getType(), 1, false), "", true, true); - idx->addIncoming(inc, eloopBlock); - - Value *idxs[] = {idx}; - Value *req = Builder2.CreateInBoundsGEP(reqType, req_orig, idxs); - Value *d_req = Builder2.CreateInBoundsGEP(reqType, d_reqp, idxs); - - d_req = Builder2.CreateLoad( - PointerType::getUnqual(impi), - Builder2.CreatePointerCast( - d_req, PointerType::getUnqual(PointerType::getUnqual(impi)))); - - Value *isNull = Builder2.CreateICmpEQ( - d_req, Constant::getNullValue(d_req->getType())); - - Builder2.CreateCondBr(isNull, eloopBlock, nonnullBlock); - Builder2.SetInsertPoint(nonnullBlock); - - Value *cache = Builder2.CreateLoad(impi, d_req); - - Value *args[] = { - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - getMPIMemberPtr(Builder2, cache, impi), - req}; - Type *types[sizeof(args) / sizeof(*args) - 1]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args) - 1; i++) - types[i] = args[i]->getType(); - Function *dwait = getOrInsertDifferentialMPI_Wait(*called->getParent(), - types, req->getType()); - // Need to preserve the shadow Request (operand 6 in isend/irecv), which - // becomes 
operand 0 for iwait. However, this doesn't end up preserving - // the underlying buffers for the adjoint. To remedy, force inline the - // function. - auto cal = Builder2.CreateCall( - dwait, args, - gutils->getInvertedBundles(&call, - {ValueType::None, ValueType::None, - ValueType::None, ValueType::None, - ValueType::None, ValueType::None, - ValueType::Shadow}, - Builder2, /*lookup*/ true)); - cal->setCallingConv(dwait->getCallingConv()); - cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - cal->addFnAttr(Attribute::AlwaysInline); - Builder2.CreateBr(eloopBlock); - - Builder2.SetInsertPoint(eloopBlock); - Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, count), endBlock, - loopBlock); - { - auto found = gutils->reverseBlockToPrimal.find(endBlock); - assert(found != gutils->reverseBlockToPrimal.end()); - SmallVector &vec = - gutils->reverseBlocks[found->second]; - assert(vec.size()); - vec.push_back(endBlock); - } - Builder2.SetInsertPoint(endBlock); - if (shouldFree()) { - CreateDealloc(Builder2, d_reqp); - } - } else if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - IRBuilder<> Builder2(&call); - - assert(!gutils->isConstantValue(call.getOperand(1))); - - Value *count = gutils->getNewFromOriginal(call.getOperand(0)); - Value *array_of_requests = - gutils->invertPointerM(call.getOperand(1), Builder2); - if (array_of_requests->getType()->isIntegerTy()) { - array_of_requests = Builder2.CreateIntToPtr( - array_of_requests, - PointerType::getUnqual(getInt8PtrTy(call.getContext()))); - } - - Value *args[] = { - /*count*/ count, - /*array_of_requests*/ array_of_requests, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::None, ValueType::None, ValueType::None, ValueType::None, - ValueType::None, ValueType::None, ValueType::Shadow}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - if (funcName == "MPI_Send" || funcName == "MPI_Ssend" || - funcName == "PMPI_Send" || funcName == "PMPI_Ssend") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? 
IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2); - if (!forwardMode) - shadow = lookup(shadow, Builder2); - if (shadow->getType()->isIntegerTy()) - shadow = - Builder2.CreateIntToPtr(shadow, getInt8PtrTy(call.getContext())); - - Type *statusType = nullptr; -#if LLVM_VERSION_MAJOR < 17 - if (called->getContext().supportsTypedPointers()) { - if (Function *recvfn = called->getParent()->getFunction("MPI_Recv")) { - auto statusArg = recvfn->arg_end(); - statusArg--; - if (auto PT = dyn_cast(statusArg->getType())) - statusType = PT->getPointerElementType(); - } else if (Function *recvfn = - called->getParent()->getFunction("PMPI_Recv")) { - auto statusArg = recvfn->arg_end(); - statusArg--; - if (auto PT = dyn_cast(statusArg->getType())) - statusType = PT->getPointerElementType(); - } - } -#endif - if (statusType == nullptr) { - statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24); - llvm::errs() << " warning could not automatically determine mpi " - "status type, assuming [24 x i8]\n"; - } - - Value *count = gutils->getNewFromOriginal(call.getOperand(1)); - if (!forwardMode) - count = lookup(count, Builder2); - - Value *datatype = gutils->getNewFromOriginal(call.getOperand(2)); - if (!forwardMode) - datatype = lookup(datatype, Builder2); - - Value *src = gutils->getNewFromOriginal(call.getOperand(3)); - if (!forwardMode) - src = lookup(src, Builder2); - - Value *tag = gutils->getNewFromOriginal(call.getOperand(4)); - if (!forwardMode) - tag = lookup(tag, Builder2); - - Value *comm = gutils->getNewFromOriginal(call.getOperand(5)); - if (!forwardMode) - comm = lookup(comm, Builder2); - - if (forwardMode) { - Value *args[] = { - /*buf*/ shadow, - /*count*/ count, - /*datatype*/ datatype, - /*dest*/ src, - /*tag*/ tag, - /*comm*/ comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - - Value *args[] = { - /*buf*/ NULL, - /*count*/ count, - /*datatype*/ datatype, - /*src*/ src, - /*tag*/ tag, - /*comm*/ comm, - /*status*/ - IRBuilder<>(gutils->inversionAllocs).CreateAlloca(statusType)}; - - Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType()); - - auto len_arg = Builder2.CreateZExtOrTrunc( - args[1], Type::getInt64Ty(call.getContext())); - len_arg = - Builder2.CreateMul(len_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - Value *firstallocation = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - len_arg, "mpirecv_malloccache"); - args[0] = firstallocation; - - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - FunctionType *FT = FunctionType::get(call.getType(), types, false); - - Builder2.SetInsertPoint(Builder2.GetInsertBlock()); - - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::None, ValueType::None, ValueType::None, - ValueType::None, ValueType::None, ValueType::None}, - Builder2, /*lookup*/ true); - - auto fcall = Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Recv", FT), args); - 
fcall->setCallingConv(call.getCallingConv()); - - DifferentiableMemCopyFloats(call, call.getOperand(0), firstallocation, - shadow, len_arg, Builder2, BufferDefs); - - if (shouldFree()) { - CreateDealloc(Builder2, firstallocation); - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - if (funcName == "MPI_Recv" || funcName == "PMPI_Recv") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2); - if (!forwardMode) - shadow = lookup(shadow, Builder2); - if (shadow->getType()->isIntegerTy()) - shadow = - Builder2.CreateIntToPtr(shadow, getInt8PtrTy(call.getContext())); - - Value *count = gutils->getNewFromOriginal(call.getOperand(1)); - if (!forwardMode) - count = lookup(count, Builder2); - - Value *datatype = gutils->getNewFromOriginal(call.getOperand(2)); - if (!forwardMode) - datatype = lookup(datatype, Builder2); - - Value *source = gutils->getNewFromOriginal(call.getOperand(3)); - if (!forwardMode) - source = lookup(source, Builder2); - - Value *tag = gutils->getNewFromOriginal(call.getOperand(4)); - if (!forwardMode) - tag = lookup(tag, Builder2); - - Value *comm = gutils->getNewFromOriginal(call.getOperand(5)); - if (!forwardMode) - comm = lookup(comm, Builder2); - - Value *args[] = { - shadow, count, datatype, source, tag, comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal, ValueType::Primal, - ValueType::None}, - Builder2, /*lookup*/ !forwardMode); - - if (forwardMode) { - auto callval = call.getCalledOperand(); - - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - FunctionType *FT = FunctionType::get(call.getType(), types, false); - - auto fcall = Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Send", FT), args, Defs); - fcall->setCallingConv(call.getCallingConv()); - - auto dst_arg = - Builder2.CreateBitCast(args[0], getInt8PtrTy(call.getContext())); - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto len_arg = Builder2.CreateZExtOrTrunc( - args[1], Type::getInt64Ty(call.getContext())); - auto tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType()); - len_arg = - Builder2.CreateMul(len_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - - Value *nargs[] = {dst_arg, val_arg, len_arg, volatile_arg}; - Type *tys[] = {dst_arg->getType(), len_arg->getType()}; - - auto MemsetDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::None, ValueType::None, ValueType::None, - ValueType::None, ValueType::None, ValueType::None}, - Builder2, /*lookup*/ true); - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - nargs)); - memset->addParamAttr(0, Attribute::NonNull); - } 
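The block above is the reverse-mode rule for a blocking receive: the adjoint of the received buffer is sent back to the original source rank, and the shadow receive buffer is then zeroed, since the primal MPI_Recv overwrote it and its adjoint has now been fully consumed. A minimal source-level sketch of the same rule, assuming a contiguous buffer of doubles; the function and variable names here are illustrative rather than Enzyme's:

    #include <mpi.h>
    #include <cstring>

    // Primal:  MPI_Recv(buf, count, MPI_DOUBLE, source, tag, comm, &status);
    // Reverse: ship d_buf back to the sender (whose reverse pass accumulates
    //          it into the adjoint of its send buffer), then zero d_buf.
    void adjoint_of_recv(double *d_buf, int count, int source, int tag,
                         MPI_Comm comm) {
      MPI_Send(d_buf, count, MPI_DOUBLE, source, tag, comm);
      std::memset(d_buf, 0, sizeof(double) * count);
    }

The matching send-side rule appears just above: the reverse of MPI_Send posts an MPI_Recv into a temporary buffer, accumulates the received adjoint into the shadow of the primal send buffer, and frees the temporary.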
- if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root, - // MPI_Comm comm ) - // 1. if root, malloc intermediate buffer - // 2. reduce sum diff(buffer) into intermediate - // 3. if root, set shadow(buffer) = intermediate [memcpy] then free - // 3-e. else, set shadow(buffer) = 0 [memset] - if (funcName == "MPI_Bcast") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2); - if (!forwardMode) - shadow = lookup(shadow, Builder2); - if (shadow->getType()->isIntegerTy()) - shadow = - Builder2.CreateIntToPtr(shadow, getInt8PtrTy(call.getContext())); - - ConcreteType CT = TR.firstPointer(1, call.getOperand(0), &call); - auto MPI_OP_type = getInt8PtrTy(call.getContext()); - Type *MPI_OP_Ptr_type = PointerType::getUnqual(MPI_OP_type); - - Value *count = gutils->getNewFromOriginal(call.getOperand(1)); - if (!forwardMode) - count = lookup(count, Builder2); - Value *datatype = gutils->getNewFromOriginal(call.getOperand(2)); - if (!forwardMode) - datatype = lookup(datatype, Builder2); - Value *root = gutils->getNewFromOriginal(call.getOperand(3)); - if (!forwardMode) - root = lookup(root, Builder2); - - Value *comm = gutils->getNewFromOriginal(call.getOperand(4)); - if (!forwardMode) - comm = lookup(comm, Builder2); - - if (forwardMode) { - Value *args[] = { - /*buffer*/ shadow, - /*count*/ count, - /*datatype*/ datatype, - /*root*/ root, - /*comm*/ comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - - Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType()); - Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType()); - - auto len_arg = Builder2.CreateZExtOrTrunc( - count, Type::getInt64Ty(call.getContext())); - len_arg = - Builder2.CreateMul(len_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - // 1. 
if root, malloc intermediate buffer, else undef - PHINode *buf; - - { - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - mergeBlock); - - Builder2.SetInsertPoint(rootBlock); - - Value *rootbuf = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - len_arg, "mpireduce_malloccache"); - Builder2.CreateBr(mergeBlock); - - Builder2.SetInsertPoint(mergeBlock); - - buf = Builder2.CreatePHI(rootbuf->getType(), 2); - buf->addIncoming(rootbuf, rootBlock); - buf->addIncoming(UndefValue::get(buf->getType()), currentBlock); - } - - // Need to preserve the shadow buffer. - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ true); - - // 2. reduce sum diff(buffer) into intermediate - { - // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, - // MPI_Datatype datatype, - // MPI_Op op, int root, MPI_Comm comm) - Value *args[] = { - /*sendbuf*/ shadow, - /*recvbuf*/ buf, - /*count*/ count, - /*datatype*/ datatype, - /*op (MPI_SUM)*/ - getOrInsertOpFloatSum(*gutils->newFunc->getParent(), - MPI_OP_Ptr_type, MPI_OP_type, CT, - root->getType(), Builder2), - /*int root*/ root, - /*comm*/ comm, - }; - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - - Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Reduce", FT), args, - BufferDefs); - } - - // 3. if root, set shadow(buffer) = intermediate [memcpy] - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *nonrootBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_nonroot", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - nonrootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - nonrootBlock); - - Builder2.SetInsertPoint(rootBlock); - - { - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *nargs[] = {shadow, buf, len_arg, volatile_arg}; - - Type *tys[] = {shadow->getType(), buf->getType(), len_arg->getType()}; - - auto memcpyF = getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memcpy, tys); - - auto mem = - cast(Builder2.CreateCall(memcpyF, nargs, BufferDefs)); - mem->setCallingConv(memcpyF->getCallingConv()); - - // Free up the memory of the buffer - if (shouldFree()) { - CreateDealloc(Builder2, buf); - } - } - - Builder2.CreateBr(mergeBlock); - - Builder2.SetInsertPoint(nonrootBlock); - - // 3-e. 
else, set shadow(buffer) = 0 [memset] - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *args[] = {shadow, val_arg, len_arg, volatile_arg}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - args, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - Builder2.CreateBr(mergeBlock); - - Builder2.SetInsertPoint(mergeBlock); - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Approximate algo (for sum): -> if statement yet to be - // 1. malloc intermediate buffer - // 1.5 if root, set intermediate = diff(recvbuffer) - // 2. MPI_Bcast intermediate to all - // 3. if root, Zero diff(recvbuffer) [memset to 0] - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - // 5. free intermediate buffer - - // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, - // MPI_Datatype datatype, - // MPI_Op op, int root, MPI_Comm comm) - - if (funcName == "MPI_Reduce" || funcName == "PMPI_Reduce") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - // TODO insert a check for sum - - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - // Get the operations from MPI_Receive - Value *orig_sendbuf = call.getOperand(0); - Value *orig_recvbuf = call.getOperand(1); - Value *orig_count = call.getOperand(2); - Value *orig_datatype = call.getOperand(3); - Value *orig_op = call.getOperand(4); - Value *orig_root = call.getOperand(5); - Value *orig_comm = call.getOperand(6); - - bool isSum = false; - if (Constant *C = dyn_cast(orig_op)) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_op_sum") { - isSum = true; - } - } - // MPICH - if (ConstantInt *CI = dyn_cast(C)) { - if (CI->getValue() == 1476395011) { - isSum = true; - } - } - } - if (!isSum) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << " call: " << call << "\n"; - ss << " unhandled mpi_reduce op: " << *orig_op << "\n"; - EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ); - return; - } - - Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2); - if (!forwardMode) - shadow_recvbuf = lookup(shadow_recvbuf, Builder2); - if (shadow_recvbuf->getType()->isIntegerTy()) - shadow_recvbuf = Builder2.CreateIntToPtr( - shadow_recvbuf, getInt8PtrTy(call.getContext())); - - Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2); - if (!forwardMode) - shadow_sendbuf = lookup(shadow_sendbuf, Builder2); - if (shadow_sendbuf->getType()->isIntegerTy()) - shadow_sendbuf = Builder2.CreateIntToPtr( - shadow_sendbuf, getInt8PtrTy(call.getContext())); - - // Need to preserve the shadow send/recv buffers. 
- auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal, ValueType::Primal, - ValueType::Primal}, - Builder2, /*lookup*/ !forwardMode); - - Value *count = gutils->getNewFromOriginal(orig_count); - if (!forwardMode) - count = lookup(count, Builder2); - - Value *datatype = gutils->getNewFromOriginal(orig_datatype); - if (!forwardMode) - datatype = lookup(datatype, Builder2); - - Value *op = gutils->getNewFromOriginal(orig_op); - if (!forwardMode) - op = lookup(op, Builder2); - - Value *root = gutils->getNewFromOriginal(orig_root); - if (!forwardMode) - root = lookup(root, Builder2); - - Value *comm = gutils->getNewFromOriginal(orig_comm); - if (!forwardMode) - comm = lookup(comm, Builder2); - - Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType()); - - if (forwardMode) { - Value *args[] = { - /*sendbuf*/ shadow_sendbuf, - /*recvbuf*/ shadow_recvbuf, - /*count*/ count, - /*datatype*/ datatype, - /*op*/ op, - /*root*/ root, - /*comm*/ comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal, ValueType::Primal, - ValueType::Primal}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - - Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType()); - - // Get the length for the allocation of the intermediate buffer - auto len_arg = Builder2.CreateZExtOrTrunc( - count, Type::getInt64Ty(call.getContext())); - len_arg = - Builder2.CreateMul(len_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - // 1. Alloc intermediate buffer - Value *buf = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - len_arg, "mpireduce_malloccache"); - - // 1.5 if root, set intermediate = diff(recvbuffer) - { - - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - mergeBlock); - - Builder2.SetInsertPoint(rootBlock); - - { - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *nargs[] = {buf, shadow_recvbuf, len_arg, volatile_arg}; - - Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(), - len_arg->getType()}; - - auto memcpyF = getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memcpy, tys); - - auto mem = - cast(Builder2.CreateCall(memcpyF, nargs, BufferDefs)); - mem->setCallingConv(memcpyF->getCallingConv()); - } - - Builder2.CreateBr(mergeBlock); - Builder2.SetInsertPoint(mergeBlock); - } - - // 2. MPI_Bcast intermediate to all - { - // int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int - // root, - // MPI_Comm comm ) - Value *args[] = { - /*buf*/ buf, - /*count*/ count, - /*datatype*/ datatype, - /*int root*/ root, - /*comm*/ comm, - }; - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Bcast", FT), args, - BufferDefs); - } - - // 3. 
if root, Zero diff(recvbuffer) [memset to 0] - { - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - mergeBlock); - - Builder2.SetInsertPoint(rootBlock); - - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - args, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - - Builder2.CreateBr(mergeBlock); - Builder2.SetInsertPoint(mergeBlock); - } - - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf, - len_arg, Builder2, BufferDefs); - - // Free up intermediate buffer - if (shouldFree()) { - CreateDealloc(Builder2, buf); - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Approximate algo (for sum): -> if statement yet to be - // 1. malloc intermediate buffers - // 2. MPI_Allreduce (sum) of diff(recvbuffer) to intermediate - // 3. Zero diff(recvbuffer) [memset to 0] - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - // 5. free intermediate buffer - - // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, - // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) - - if (funcName == "MPI_Allreduce") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - // TODO insert a check for sum - - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? 
IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - // Get the operations from MPI_Receive - Value *orig_sendbuf = call.getOperand(0); - Value *orig_recvbuf = call.getOperand(1); - Value *orig_count = call.getOperand(2); - Value *orig_datatype = call.getOperand(3); - Value *orig_op = call.getOperand(4); - Value *orig_comm = call.getOperand(5); - - bool isSum = false; - if (Constant *C = dyn_cast(orig_op)) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_op_sum") { - isSum = true; - } - } - // MPICH - if (ConstantInt *CI = dyn_cast(C)) { - if (CI->getValue() == 1476395011) { - isSum = true; - } - } - } - if (!isSum) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << " call: " << call << "\n"; - ss << " unhandled mpi_allreduce op: " << *orig_op << "\n"; - EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ); - return; - } - - Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2); - if (!forwardMode) - shadow_recvbuf = lookup(shadow_recvbuf, Builder2); - if (shadow_recvbuf->getType()->isIntegerTy()) - shadow_recvbuf = Builder2.CreateIntToPtr( - shadow_recvbuf, getInt8PtrTy(call.getContext())); - - Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2); - if (!forwardMode) - shadow_sendbuf = lookup(shadow_sendbuf, Builder2); - if (shadow_sendbuf->getType()->isIntegerTy()) - shadow_sendbuf = Builder2.CreateIntToPtr( - shadow_sendbuf, getInt8PtrTy(call.getContext())); - - // Need to preserve the shadow send/recv buffers. - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ !forwardMode); - - Value *count = gutils->getNewFromOriginal(orig_count); - if (!forwardMode) - count = lookup(count, Builder2); - - Value *datatype = gutils->getNewFromOriginal(orig_datatype); - if (!forwardMode) - datatype = lookup(datatype, Builder2); - - Value *comm = gutils->getNewFromOriginal(orig_comm); - if (!forwardMode) - comm = lookup(comm, Builder2); - - Value *op = gutils->getNewFromOriginal(orig_op); - if (!forwardMode) - op = lookup(op, Builder2); - - if (forwardMode) { - Value *args[] = { - /*sendbuf*/ shadow_sendbuf, - /*recvbuf*/ shadow_recvbuf, - /*count*/ count, - /*datatype*/ datatype, - /*op*/ op, - /*comm*/ comm, - }; - - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, BufferDefs); - - return; - } - - Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType()); - - // Get the length for the allocation of the intermediate buffer - auto len_arg = Builder2.CreateZExtOrTrunc( - count, Type::getInt64Ty(call.getContext())); - len_arg = - Builder2.CreateMul(len_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - // 1. Alloc intermediate buffer - Value *buf = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - len_arg, "mpireduce_malloccache"); - - // 2. 
MPI_Allreduce (sum) of diff(recvbuffer) to intermediate - { - // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, - // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) - Value *args[] = { - /*sendbuf*/ shadow_recvbuf, - /*recvbuf*/ buf, - /*count*/ count, - /*datatype*/ datatype, - /*op*/ op, - /*comm*/ comm, - }; - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Allreduce", FT), args, - BufferDefs); - } - - // 3. Zero diff(recvbuffer) [memset to 0] - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - args, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf, - len_arg, Builder2, BufferDefs); - - // Free up intermediate buffer - if (shouldFree()) { - CreateDealloc(Builder2, buf); - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Approximate algo (for sum): -> if statement yet to be - // 1. malloc intermediate buffer - // 2. Scatter diff(recvbuffer) to intermediate buffer - // 3. if root, Zero diff(recvbuffer) [memset to 0] - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - // 5. free intermediate buffer - - // int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, - // void *recvbuf, int recvcount, MPI_Datatype recvtype, - // int root, MPI_Comm comm) - - if (funcName == "MPI_Gather") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? 
IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - Value *orig_sendbuf = call.getOperand(0); - Value *orig_sendcount = call.getOperand(1); - Value *orig_sendtype = call.getOperand(2); - Value *orig_recvbuf = call.getOperand(3); - Value *orig_recvcount = call.getOperand(4); - Value *orig_recvtype = call.getOperand(5); - Value *orig_root = call.getOperand(6); - Value *orig_comm = call.getOperand(7); - - Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2); - if (!forwardMode) - shadow_recvbuf = lookup(shadow_recvbuf, Builder2); - if (shadow_recvbuf->getType()->isIntegerTy()) - shadow_recvbuf = Builder2.CreateIntToPtr( - shadow_recvbuf, getInt8PtrTy(call.getContext())); - - Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2); - if (!forwardMode) - shadow_sendbuf = lookup(shadow_sendbuf, Builder2); - if (shadow_sendbuf->getType()->isIntegerTy()) - shadow_sendbuf = Builder2.CreateIntToPtr( - shadow_sendbuf, getInt8PtrTy(call.getContext())); - - Value *recvcount = gutils->getNewFromOriginal(orig_recvcount); - if (!forwardMode) - recvcount = lookup(recvcount, Builder2); - - Value *recvtype = gutils->getNewFromOriginal(orig_recvtype); - if (!forwardMode) - recvtype = lookup(recvtype, Builder2); - - Value *sendcount = gutils->getNewFromOriginal(orig_sendcount); - if (!sendcount) - sendcount = lookup(sendcount, Builder2); - - Value *sendtype = gutils->getNewFromOriginal(orig_sendtype); - if (!forwardMode) - sendtype = lookup(sendtype, Builder2); - - Value *root = gutils->getNewFromOriginal(orig_root); - if (!forwardMode) - root = lookup(root, Builder2); - - Value *comm = gutils->getNewFromOriginal(orig_comm); - if (!forwardMode) - comm = lookup(comm, Builder2); - - Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType()); - Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType()); - - if (forwardMode) { - Value *args[] = { - /*sendbuf*/ shadow_sendbuf, - /*sendcount*/ sendcount, - /*sendtype*/ sendtype, - /*recvbuf*/ shadow_recvbuf, - /*recvcount*/ recvcount, - /*recvtype*/ recvtype, - /*root*/ root, - /*comm*/ comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - - // Get the length for the allocation of the intermediate buffer - auto sendlen_arg = Builder2.CreateZExtOrTrunc( - sendcount, Type::getInt64Ty(call.getContext())); - sendlen_arg = - Builder2.CreateMul(sendlen_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - // Need to preserve the shadow send/recv buffers. - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ true); - - // 1. Alloc intermediate buffer - Value *buf = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - sendlen_arg, "mpireduce_malloccache"); - - // 2. 
Scatter diff(recvbuffer) to intermediate buffer - { - // int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype - // sendtype, - // void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, - // MPI_Comm comm) - Value *args[] = { - /*sendbuf*/ shadow_recvbuf, - /*sendcount*/ recvcount, - /*sendtype*/ recvtype, - /*recvbuf*/ buf, - /*recvcount*/ sendcount, - /*recvtype*/ sendtype, - /*op*/ root, - /*comm*/ comm, - }; - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Scatter", FT), args, - BufferDefs); - } - - // 3. if root, Zero diff(recvbuffer) [memset to 0] - { - - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - mergeBlock); - - Builder2.SetInsertPoint(rootBlock); - auto recvlen_arg = Builder2.CreateZExtOrTrunc( - recvcount, Type::getInt64Ty(call.getContext())); - recvlen_arg = - Builder2.CreateMul(recvlen_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - recvlen_arg = Builder2.CreateMul( - recvlen_arg, - Builder2.CreateZExtOrTrunc( - MPI_COMM_SIZE(comm, Builder2, root->getType()), - Type::getInt64Ty(call.getContext())), - "", true, true); - - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - args, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - - Builder2.CreateBr(mergeBlock); - Builder2.SetInsertPoint(mergeBlock); - } - - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf, - sendlen_arg, Builder2, BufferDefs); - - // Free up intermediate buffer - if (shouldFree()) { - CreateDealloc(Builder2, buf); - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Approximate algo (for sum): -> if statement yet to be - // 1. if root, malloc intermediate buffer, else undef - // 2. Gather diff(recvbuffer) to intermediate buffer - // 3. Zero diff(recvbuffer) [memset to 0] - // 4. if root, diff(sendbuffer) += intermediate buffer (diffmemcopy) - // 5. if root, free intermediate buffer - - // int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype - // sendtype, - // void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, - // MPI_Comm comm) - if (funcName == "MPI_Scatter") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? 
IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - Value *orig_sendbuf = call.getOperand(0); - Value *orig_sendcount = call.getOperand(1); - Value *orig_sendtype = call.getOperand(2); - Value *orig_recvbuf = call.getOperand(3); - Value *orig_recvcount = call.getOperand(4); - Value *orig_recvtype = call.getOperand(5); - Value *orig_root = call.getOperand(6); - Value *orig_comm = call.getOperand(7); - - Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2); - if (!forwardMode) - shadow_recvbuf = lookup(shadow_recvbuf, Builder2); - if (shadow_recvbuf->getType()->isIntegerTy()) - shadow_recvbuf = Builder2.CreateIntToPtr( - shadow_recvbuf, getInt8PtrTy(call.getContext())); - - Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2); - if (!forwardMode) - shadow_sendbuf = lookup(shadow_sendbuf, Builder2); - if (shadow_sendbuf->getType()->isIntegerTy()) - shadow_sendbuf = Builder2.CreateIntToPtr( - shadow_sendbuf, getInt8PtrTy(call.getContext())); - - Value *recvcount = gutils->getNewFromOriginal(orig_recvcount); - if (!forwardMode) - recvcount = lookup(recvcount, Builder2); - - Value *recvtype = gutils->getNewFromOriginal(orig_recvtype); - if (!forwardMode) - recvtype = lookup(recvtype, Builder2); - - Value *sendcount = gutils->getNewFromOriginal(orig_sendcount); - if (!forwardMode) - sendcount = lookup(sendcount, Builder2); - - Value *sendtype = gutils->getNewFromOriginal(orig_sendtype); - if (!forwardMode) - sendtype = lookup(sendtype, Builder2); - - Value *root = gutils->getNewFromOriginal(orig_root); - if (!forwardMode) - root = lookup(root, Builder2); - - Value *comm = gutils->getNewFromOriginal(orig_comm); - if (!forwardMode) - comm = lookup(comm, Builder2); - - Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType()); - Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType()); - - if (forwardMode) { - Value *args[] = { - /*sendbuf*/ shadow_sendbuf, - /*sendcount*/ sendcount, - /*sendtype*/ sendtype, - /*recvbuf*/ shadow_recvbuf, - /*recvcount*/ recvcount, - /*recvtype*/ recvtype, - /*root*/ root, - /*comm*/ comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ false); - - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - // Get the length for the allocation of the intermediate buffer - auto recvlen_arg = Builder2.CreateZExtOrTrunc( - recvcount, Type::getInt64Ty(call.getContext())); - recvlen_arg = - Builder2.CreateMul(recvlen_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - // Need to preserve the shadow send/recv buffers. - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal, ValueType::Primal}, - Builder2, /*lookup*/ true); - - // 1. 
if root, malloc intermediate buffer, else undef - PHINode *buf; - PHINode *sendlen_phi; - - { - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - mergeBlock); - - Builder2.SetInsertPoint(rootBlock); - - auto sendlen_arg = Builder2.CreateZExtOrTrunc( - sendcount, Type::getInt64Ty(call.getContext())); - sendlen_arg = - Builder2.CreateMul(sendlen_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - sendlen_arg = Builder2.CreateMul( - sendlen_arg, - Builder2.CreateZExtOrTrunc( - MPI_COMM_SIZE(comm, Builder2, root->getType()), - Type::getInt64Ty(call.getContext())), - "", true, true); - - Value *rootbuf = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - sendlen_arg, "mpireduce_malloccache"); - - Builder2.CreateBr(mergeBlock); - - Builder2.SetInsertPoint(mergeBlock); - - buf = Builder2.CreatePHI(rootbuf->getType(), 2); - buf->addIncoming(rootbuf, rootBlock); - buf->addIncoming(UndefValue::get(buf->getType()), currentBlock); - - sendlen_phi = Builder2.CreatePHI(sendlen_arg->getType(), 2); - sendlen_phi->addIncoming(sendlen_arg, rootBlock); - sendlen_phi->addIncoming(UndefValue::get(sendlen_arg->getType()), - currentBlock); - } - - // 2. Gather diff(recvbuffer) to intermediate buffer - { - // int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype - // sendtype, - // void *recvbuf, int recvcount, MPI_Datatype recvtype, - // int root, MPI_Comm comm) - Value *args[] = { - /*sendbuf*/ shadow_recvbuf, - /*sendcount*/ recvcount, - /*sendtype*/ recvtype, - /*recvbuf*/ buf, - /*recvcount*/ sendcount, - /*recvtype*/ sendtype, - /*root*/ root, - /*comm*/ comm, - }; - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Gather", FT), args, - BufferDefs); - } - - // 3. Zero diff(recvbuffer) [memset to 0] - { - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - args, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - } - - // 4. if root, diff(sendbuffer) += intermediate buffer (diffmemcopy) - // 5. if root, free intermediate buffer - - { - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - BasicBlock *rootBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_root", gutils->newFunc); - BasicBlock *mergeBlock = gutils->addReverseBlock( - rootBlock, currentBlock->getName() + "_post", gutils->newFunc); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock, - mergeBlock); - - Builder2.SetInsertPoint(rootBlock); - - // 4. 
diff(sendbuffer) += intermediate buffer (diffmemcopy) - DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf, - sendlen_phi, Builder2, BufferDefs); - - // Free up intermediate buffer - if (shouldFree()) { - CreateDealloc(Builder2, buf); - } - - Builder2.CreateBr(mergeBlock); - Builder2.SetInsertPoint(mergeBlock); - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Approximate algo (for sum): -> if statement yet to be - // 1. malloc intermediate buffer - // 2. reduce diff(recvbuffer) then scatter to corresponding input node's - // intermediate buffer - // 3. Zero diff(recvbuffer) [memset to 0] - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - // 5. free intermediate buffer - - // int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype - // sendtype, - // void *recvbuf, int recvcount, MPI_Datatype recvtype, - // MPI_Comm comm) - - if (funcName == "MPI_Allgather") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - bool forwardMode = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError; - - IRBuilder<> Builder2 = - forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent()); - if (forwardMode) { - getForwardBuilder(Builder2); - } else { - getReverseBuilder(Builder2); - } - - Value *orig_sendbuf = call.getOperand(0); - Value *orig_sendcount = call.getOperand(1); - Value *orig_sendtype = call.getOperand(2); - Value *orig_recvbuf = call.getOperand(3); - Value *orig_recvcount = call.getOperand(4); - Value *orig_recvtype = call.getOperand(5); - Value *orig_comm = call.getOperand(6); - - Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2); - if (!forwardMode) - shadow_recvbuf = lookup(shadow_recvbuf, Builder2); - - if (shadow_recvbuf->getType()->isIntegerTy()) - shadow_recvbuf = Builder2.CreateIntToPtr( - shadow_recvbuf, getInt8PtrTy(call.getContext())); - - Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2); - if (!forwardMode) - shadow_sendbuf = lookup(shadow_sendbuf, Builder2); - - if (shadow_sendbuf->getType()->isIntegerTy()) - shadow_sendbuf = Builder2.CreateIntToPtr( - shadow_sendbuf, getInt8PtrTy(call.getContext())); - - Value *recvcount = gutils->getNewFromOriginal(orig_recvcount); - if (!forwardMode) - recvcount = lookup(recvcount, Builder2); - - Value *recvtype = gutils->getNewFromOriginal(orig_recvtype); - if (!forwardMode) - recvtype = lookup(recvtype, Builder2); - - Value *sendcount = gutils->getNewFromOriginal(orig_sendcount); - if (!forwardMode) - sendcount = lookup(sendcount, Builder2); - - Value *sendtype = gutils->getNewFromOriginal(orig_sendtype); - if (!forwardMode) - sendtype = lookup(sendtype, Builder2); - - Value *comm = gutils->getNewFromOriginal(orig_comm); - if (!forwardMode) - comm = lookup(comm, Builder2); - - Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType()); - - if (forwardMode) { - Value *args[] = { - /*sendbuf*/ shadow_sendbuf, - /*sendcount*/ sendcount, - /*sendtype*/ sendtype, - /*recvbuf*/ shadow_recvbuf, - /*recvcount*/ recvcount, - /*recvtype*/ recvtype, - /*comm*/ comm, - }; - - auto Defs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal}, - Builder2, /*lookup*/ false); - - auto callval = 
call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, args, Defs); - return; - } - // Get the length for the allocation of the intermediate buffer - auto sendlen_arg = Builder2.CreateZExtOrTrunc( - sendcount, Type::getInt64Ty(call.getContext())); - sendlen_arg = - Builder2.CreateMul(sendlen_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - - // Need to preserve the shadow send/recv buffers. - auto BufferDefs = gutils->getInvertedBundles( - &call, - {ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal}, - Builder2, /*lookup*/ true); - - // 1. Alloc intermediate buffer - Value *buf = - CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()), - sendlen_arg, "mpireduce_malloccache"); - - ConcreteType CT = TR.firstPointer(1, orig_sendbuf, &call); - auto MPI_OP_type = getInt8PtrTy(call.getContext()); - Type *MPI_OP_Ptr_type = PointerType::getUnqual(MPI_OP_type); - - // 2. reduce diff(recvbuffer) then scatter to corresponding input node's - // intermediate buffer - { - // int MPI_Reduce_scatter_block(const void* send_buffer, - // void* receive_buffer, - // int count, - // MPI_Datatype datatype, - // MPI_Op operation, - // MPI_Comm communicator); - Value *args[] = { - /*sendbuf*/ shadow_recvbuf, - /*recvbuf*/ buf, - /*recvcount*/ sendcount, - /*recvtype*/ sendtype, - /*op (MPI_SUM)*/ - getOrInsertOpFloatSum(*gutils->newFunc->getParent(), - MPI_OP_Ptr_type, MPI_OP_type, CT, - call.getType(), Builder2), - /*comm*/ comm, - }; - Type *types[sizeof(args) / sizeof(*args)]; - for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++) - types[i] = args[i]->getType(); - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - Builder2.CreateCall(called->getParent()->getOrInsertFunction( - "MPI_Reduce_scatter_block", FT), - args, BufferDefs); - } - - // 3. zero diff(recvbuffer) [memset to 0] - { - auto recvlen_arg = Builder2.CreateZExtOrTrunc( - recvcount, Type::getInt64Ty(call.getContext())); - recvlen_arg = - Builder2.CreateMul(recvlen_arg, - Builder2.CreateZExtOrTrunc( - tysize, Type::getInt64Ty(call.getContext())), - "", true, true); - recvlen_arg = Builder2.CreateMul( - recvlen_arg, - Builder2.CreateZExtOrTrunc( - MPI_COMM_SIZE(comm, Builder2, call.getType()), - Type::getInt64Ty(call.getContext())), - "", true, true); - auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto volatile_arg = ConstantInt::getFalse(call.getContext()); - Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memset = cast(Builder2.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), - Intrinsic::memset, tys), - args, BufferDefs)); - memset->addParamAttr(0, Attribute::NonNull); - } - - // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy) - DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf, - sendlen_arg, Builder2, BufferDefs); - - // Free up intermediate buffer - if (shouldFree()) { - CreateDealloc(Builder2, buf); - } - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Adjoint of barrier is to place a barrier at the corresponding - // location in the reverse. 
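Because a barrier moves no data, its adjoint is simply another barrier at the mirrored position of the reverse sweep, which is what the handler below emits. A source-level sketch of the reversal, with illustrative names:

    #include <mpi.h>

    // Primal order:  stage_a(); MPI_Barrier(comm); stage_b();
    // Reverse order: adjoint_of_stage_b(); MPI_Barrier(comm); adjoint_of_stage_a();
    void adjoint_of_barrier(MPI_Comm comm) {
      MPI_Barrier(comm); // synchronization point mirrored into the reverse pass
    }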
- if (funcName == "MPI_Barrier") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto callval = call.getCalledOperand(); - Value *args[] = { - lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2)}; - Builder2.CreateCall(call.getFunctionType(), callval, args); - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Remove free's in forward pass so the comm can be used in the reverse - // pass - if (funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect") { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - // Adjoint of MPI_Comm_split / MPI_Graph_create (which allocates a comm in a - // pointer) is to free the created comm at the corresponding place in the - // reverse pass - auto commFound = MPIInactiveCommAllocators.find(funcName); - if (commFound != MPIInactiveCommAllocators.end()) { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - Value *args[] = {lookup(call.getOperand(commFound->second), Builder2)}; - Type *types[] = {args[0]->getType()}; - - FunctionType *FT = FunctionType::get(call.getType(), types, false); - Builder2.CreateCall( - called->getParent()->getOrInsertFunction("MPI_Comm_free", FT), args); - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return; - } - - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << call << "\n"; - llvm::errs() << called << "\n"; - llvm_unreachable("Unhandled MPI FUNCTION"); -} - -bool AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - bool subsequent_calls_may_write, const std::vector &overwritten_args, - CallInst *const newCall) { - bool subretused = false; - bool shadowReturnUsed = false; - DIFFE_TYPE subretType = - gutils->getReturnDiffeType(&call, &subretused, &shadowReturnUsed); - - IRBuilder<> BuilderZ(newCall); - BuilderZ.setFastMathFlags(getFast()); - - if (Mode != DerivativeMode::ReverseModePrimal && called) { - if (funcName == "__kmpc_for_static_init_4" || - funcName == "__kmpc_for_static_init_4u" || - funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto fini = called->getParent()->getFunction("__kmpc_for_static_fini"); - assert(fini); - Value *args[] = { - lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2), - lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2)}; - auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args); - fcall->setCallingConv(fini->getCallingConv()); - return true; - } - } - - if ((startsWith(funcName, "MPI_") || startsWith(funcName, "PMPI_")) && - (!gutils->isConstantInstruction(&call) || funcName == "MPI_Barrier" || - funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect" || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end())) { - handleMPI(call, called, funcName); - return true; - } - - if (auto blas = extractBLAS(funcName)) { - if (handleBLAS(call, called, *blas, overwritten_args)) - return true; - } - - if (funcName == "printf" || funcName == "puts" || - startsWith(funcName, 
"_ZN3std2io5stdio6_print") || - startsWith(funcName, "_ZN4core3fmt")) { - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - return true; - } - if (called && (called->getName().contains("__enzyme_float") || - called->getName().contains("__enzyme_double") || - called->getName().contains("__enzyme_integer") || - called->getName().contains("__enzyme_pointer"))) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - - // Handle lgamma, safe to recompute so no store/change to forward - if (called) { - if (funcName == "__kmpc_for_static_init_4" || - funcName == "__kmpc_for_static_init_4u" || - funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") { - if (Mode != DerivativeMode::ReverseModePrimal) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto fini = called->getParent()->getFunction("__kmpc_for_static_fini"); - assert(fini); - Value *args[] = { - lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2), - lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), - Builder2)}; - auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args); - fcall->setCallingConv(fini->getCallingConv()); - } - return true; - } - if (funcName == "__kmpc_for_static_fini") { - if (Mode != DerivativeMode::ReverseModePrimal) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } - return true; - } - // TODO check - // Adjoint of barrier is to place a barrier at the corresponding - // location in the reverse. - if (funcName == "__kmpc_barrier") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto callval = call.getCalledOperand(); - Value *args[] = { - lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2), - lookup(gutils->getNewFromOriginal(call.getOperand(1)), Builder2)}; - Builder2.CreateCall(call.getFunctionType(), callval, args); - } - return true; - } - if (funcName == "__kmpc_critical") { - if (Mode != DerivativeMode::ReverseModePrimal) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto crit2 = called->getParent()->getFunction("__kmpc_end_critical"); - assert(crit2); - Value *args[] = { - lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2), - lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2), - lookup(gutils->getNewFromOriginal(call.getArgOperand(2)), - Builder2)}; - auto fcall = Builder2.CreateCall(crit2->getFunctionType(), crit2, args); - fcall->setCallingConv(crit2->getCallingConv()); - } - return true; - } - if (funcName == "__kmpc_end_critical") { - if (Mode != DerivativeMode::ReverseModePrimal) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto crit2 = called->getParent()->getFunction("__kmpc_critical"); - assert(crit2); - Value *args[] = { - lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2), - lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2), - lookup(gutils->getNewFromOriginal(call.getArgOperand(2)), - Builder2)}; - auto fcall = Builder2.CreateCall(crit2->getFunctionType(), crit2, args); - fcall->setCallingConv(crit2->getCallingConv()); - } - return true; - } - - if (startsWith(funcName, "__kmpc") && - funcName != "__kmpc_global_thread_num") { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << call << "\n"; - assert(0 && "unhandled openmp function"); - llvm_unreachable("unhandled openmp 
function"); - } - - auto mod = call.getParent()->getParent()->getParent(); -#include "CallDerivatives.inc" - - if (funcName == "llvm.julia.gc_preserve_end") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - - auto begin_call = cast(call.getOperand(0)); - - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - SmallVector args; - for (auto &arg : begin_call->args()) { - bool primalUsed = false; - bool shadowUsed = false; - gutils->getReturnDiffeType(arg, &primalUsed, &shadowUsed); - - if (primalUsed) - args.push_back( - gutils->lookupM(gutils->getNewFromOriginal(arg), Builder2)); - - if (!gutils->isConstantValue(arg) && shadowUsed) { - Value *ptrshadow = gutils->lookupM( - gutils->invertPointerM(arg, BuilderZ), Builder2); - if (gutils->getWidth() == 1) - args.push_back(ptrshadow); - else - for (size_t i = 0; i < gutils->getWidth(); ++i) - args.push_back(gutils->extractMeta(Builder2, ptrshadow, i)); - } - } - - auto newp = Builder2.CreateCall( - called->getParent()->getOrInsertFunction( - "llvm.julia.gc_preserve_begin", - FunctionType::get(Type::getTokenTy(call.getContext()), - ArrayRef(), true)), - args); - auto ifound = gutils->invertedPointers.find(begin_call); - assert(ifound != gutils->invertedPointers.end()); - auto placeholder = cast(&*ifound->second); - gutils->invertedPointers.erase(ifound); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)begin_call, InvertedPointerVH(gutils, newp))); - - gutils->replaceAWithB(placeholder, newp); - gutils->erase(placeholder); - } - return true; - } - if (funcName == "llvm.julia.gc_preserve_begin") { - SmallVector args; - for (auto &arg : call.args()) { - bool primalUsed = false; - bool shadowUsed = false; - gutils->getReturnDiffeType(arg, &primalUsed, &shadowUsed); - - if (primalUsed) - args.push_back(gutils->getNewFromOriginal(arg)); - - if (!gutils->isConstantValue(arg) && shadowUsed) { - Value *ptrshadow = gutils->invertPointerM(arg, BuilderZ); - if (gutils->getWidth() == 1) - args.push_back(ptrshadow); - else - for (size_t i = 0; i < gutils->getWidth(); ++i) - args.push_back(gutils->extractMeta(BuilderZ, ptrshadow, i)); - } - } - - auto newp = BuilderZ.CreateCall(called, args); - auto oldp = gutils->getNewFromOriginal(&call); - gutils->replaceAWithB(oldp, newp); - gutils->erase(oldp); - - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - auto ifound = gutils->invertedPointers.find(&call); - assert(ifound != gutils->invertedPointers.end()); - auto placeholder = cast(&*ifound->second); - Builder2.CreateCall( - called->getParent()->getOrInsertFunction( - "llvm.julia.gc_preserve_end", - FunctionType::get(Builder2.getVoidTy(), call.getType(), false)), - placeholder); - } - return true; - } - - /* - * int gsl_sf_legendre_array_e(const gsl_sf_legendre_t norm, - const size_t lmax, - const double x, - const double csphase, - double result_array[]); - */ - // d L(n, x) / dx = L(n,x) * x * (n-1) + 1 - if (funcName == "gsl_sf_legendre_array_e") { - if (gutils->isConstantValue(call.getArgOperand(4))) { - eraseIfUnused(call); - return true; - } - if (Mode == DerivativeMode::ReverseModePrimal) { - eraseIfUnused(call); - return true; - } - if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - ValueType BundleTypes[5] = {ValueType::None, ValueType::None, - 
ValueType::None, ValueType::None, - ValueType::Shadow}; - auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2, - /*lookup*/ true); - - Type *types[6] = { - call.getOperand(0)->getType(), call.getOperand(1)->getType(), - call.getOperand(2)->getType(), call.getOperand(3)->getType(), - call.getOperand(4)->getType(), call.getOperand(4)->getType(), - }; - FunctionType *FT = FunctionType::get(call.getType(), types, false); - auto F = called->getParent()->getOrInsertFunction( - "gsl_sf_legendre_deriv_array_e", FT); - - llvm::Value *args[6] = { - gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(0)), - Builder2), - gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(1)), - Builder2), - gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(2)), - Builder2), - gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(3)), - Builder2), - nullptr, - nullptr}; - - Type *typesS[] = {args[1]->getType()}; - FunctionType *FTS = - FunctionType::get(args[1]->getType(), typesS, false); - auto FS = called->getParent()->getOrInsertFunction( - "gsl_sf_legendre_array_n", FTS); - Value *alSize = Builder2.CreateCall(FS, args[1]); - Value *tmp = CreateAllocation(Builder2, types[2], alSize); - Value *dtmp = CreateAllocation(Builder2, types[2], alSize); - Builder2.CreateLifetimeStart(tmp); - Builder2.CreateLifetimeStart(dtmp); - - args[4] = Builder2.CreateBitCast(tmp, types[4]); - args[5] = Builder2.CreateBitCast(dtmp, types[5]); - - Builder2.CreateCall(F, args, Defs); - Builder2.CreateLifetimeEnd(tmp); - CreateDealloc(Builder2, tmp); - - BasicBlock *currentBlock = Builder2.GetInsertBlock(); - - BasicBlock *loopBlock = gutils->addReverseBlock( - currentBlock, currentBlock->getName() + "_loop"); - BasicBlock *endBlock = - gutils->addReverseBlock(loopBlock, currentBlock->getName() + "_end", - /*fork*/ true, /*push*/ false); - - Builder2.CreateCondBr( - Builder2.CreateICmpEQ(args[1], Constant::getNullValue(types[1])), - endBlock, loopBlock); - Builder2.SetInsertPoint(loopBlock); - - auto idx = Builder2.CreatePHI(types[1], 2); - idx->addIncoming(ConstantInt::get(types[1], 0, false), currentBlock); - - auto acc_idx = Builder2.CreatePHI(types[2], 2); - - Value *inc = Builder2.CreateAdd( - idx, ConstantInt::get(types[1], 1, false), "", true, true); - idx->addIncoming(inc, loopBlock); - acc_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock); - - Value *idxs[] = {idx}; - Value *dtmp_idx = Builder2.CreateInBoundsGEP(types[2], dtmp, idxs); - Value *d_req = Builder2.CreateInBoundsGEP( - types[2], - Builder2.CreatePointerCast( - gutils->invertPointerM(call.getOperand(4), Builder2), - PointerType::getUnqual(types[2])), - idxs); - - auto l0 = Builder2.CreateLoad(types[2], dtmp_idx); - auto l1 = Builder2.CreateLoad(types[2], d_req); - auto acc = Builder2.CreateFAdd(acc_idx, Builder2.CreateFMul(l0, l1)); - Builder2.CreateStore(Constant::getNullValue(types[2]), d_req); - - acc_idx->addIncoming(acc, loopBlock); - - Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, args[1]), endBlock, - loopBlock); - - Builder2.SetInsertPoint(endBlock); - { - auto found = gutils->reverseBlockToPrimal.find(endBlock); - assert(found != gutils->reverseBlockToPrimal.end()); - SmallVector &vec = - gutils->reverseBlocks[found->second]; - assert(vec.size()); - vec.push_back(endBlock); - } - - auto fin_idx = Builder2.CreatePHI(types[2], 2); - fin_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock); - fin_idx->addIncoming(acc, loopBlock); - - Builder2.CreateLifetimeEnd(dtmp); - 
CreateDealloc(Builder2, dtmp); - - ((DiffeGradientUtils *)gutils) - ->addToDiffe(call.getOperand(2), fin_idx, Builder2, types[2]); - - return true; - } - } - - // Functions that only modify pointers and don't allocate memory, - // needs to be run on shadow in primal - if (funcName == "_ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_" - "node_baseS0_RS_") { - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (gutils->isConstantValue(call.getArgOperand(3))) - return true; - SmallVector args; - for (auto &arg : call.args()) { - if (gutils->isConstantValue(arg)) - args.push_back(gutils->getNewFromOriginal(arg)); - else - args.push_back(gutils->invertPointerM(arg, BuilderZ)); - } - BuilderZ.CreateCall(called, args); - return true; - } - - // Functions that initialize a shadow data structure (with no - // other arguments) needs to be run on shadow in primal. - if (funcName == "_ZNSt8ios_baseC2Ev" || funcName == "_ZNSt8ios_baseD2Ev" || - funcName == "_ZNSt6localeC1Ev" || funcName == "_ZNSt6localeD1Ev" || - funcName == "_ZNKSt5ctypeIcE13_M_widen_initEv") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (gutils->isConstantValue(call.getArgOperand(0))) - return true; - Value *args[] = {gutils->invertPointerM(call.getArgOperand(0), BuilderZ)}; - BuilderZ.CreateCall(called, args); - return true; - } - - if (funcName == "_ZNSt9basic_iosIcSt11char_traitsIcEE4initEPSt15basic_" - "streambufIcS1_E") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (gutils->isConstantValue(call.getArgOperand(0))) - return true; - Value *args[] = {gutils->invertPointerM(call.getArgOperand(0), BuilderZ), - gutils->invertPointerM(call.getArgOperand(1), BuilderZ)}; - BuilderZ.CreateCall(called, args); - return true; - } - - // if constant instruction and readonly (thus must be pointer return) - // and shadow return recomputable from shadow arguments. 
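// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the "replay on the shadow" rule
// used above for pointer-only routines such as the std::ios_base constructor
// or _Rb_tree_insert_and_rebalance. The augmented forward pass simply issues a
// second call with every active pointer argument replaced by its shadow, so
// the shadow data structure is laid out exactly like the primal one. `Node`,
// `link_node`, and the d_-prefixed names are hypothetical.
// ---------------------------------------------------------------------------
struct Node {
  Node *next = nullptr;
  double value = 0.0;
};

// A routine that only rewires pointers / initializes memory: it performs no
// floating-point arithmetic, so it contributes no derivative of its own.
void link_node(Node *list, Node *node) {
  node->next = list->next;
  list->next = node;
}

// Shape of the code emitted in the augmented forward pass: run the primal
// call, then replay it on the shadow arguments so both structures stay
// structurally identical for the reverse pass.
void link_node_augmented(Node *list, Node *d_list, Node *node, Node *d_node) {
  link_node(list, node);     // primal
  link_node(d_list, d_node); // same call on the shadows
}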
- if (funcName == "__dynamic_cast" || - funcName == "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base" || - funcName == "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base" || - funcName == "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base" || - funcName == "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base" || - funcName == "jl_ptr_to_array" || funcName == "jl_ptr_to_array_1d") { - bool shouldCache = false; - if (gutils->knownRecomputeHeuristic.find(&call) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&call]) { - shouldCache = true; - } - } - ValueToValueMapTy empty; - bool lrc = gutils->legalRecompute(&call, empty, nullptr); - - if (!gutils->isConstantValue(&call)) { - auto ifound = gutils->invertedPointers.find(&call); - assert(ifound != gutils->invertedPointers.end()); - auto placeholder = cast(&*ifound->second); - - if (subretType == DIFFE_TYPE::DUP_ARG) { - Value *shadow = placeholder; - if (lrc || Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - if (gutils->isConstantValue(call.getArgOperand(0))) - shadow = gutils->getNewFromOriginal(&call); - else { - SmallVector args; - size_t i = 0; - for (auto &arg : call.args()) { - if (gutils->isConstantValue(arg) || - (funcName == "__dynamic_cast" && i > 0) || - (funcName == "jl_ptr_to_array_1d" && i != 1) || - (funcName == "jl_ptr_to_array" && i != 1)) - args.push_back(gutils->getNewFromOriginal(arg)); - else - args.push_back(gutils->invertPointerM(arg, BuilderZ)); - i++; - } - shadow = BuilderZ.CreateCall(called, args); - } - } - - bool needsReplacement = true; - if (!lrc && (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeGradient)) { - shadow = gutils->cacheForReverse( - BuilderZ, shadow, getIndex(&call, CacheType::Shadow, BuilderZ)); - if (Mode == DerivativeMode::ReverseModeGradient) - needsReplacement = false; - } - gutils->invertedPointers.erase((const Value *)&call); - gutils->invertedPointers.insert(std::make_pair( - (const Value *)&call, InvertedPointerVH(gutils, shadow))); - if (needsReplacement) { - assert(shadow != placeholder); - gutils->replaceAWithB(placeholder, shadow); - gutils->erase(placeholder); - } - } else { - gutils->invertedPointers.erase((const Value *)&call); - gutils->erase(placeholder); - } - } - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - eraseIfUnused(call); - assert(gutils->isConstantInstruction(&call)); - return true; - } - - if (!shouldCache && !lrc) { - std::map Seen; - for (auto pair : gutils->knownRecomputeHeuristic) - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - bool primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, Mode, Seen, oldUnreachable); - shouldCache = primalNeededInReverse; - } - - if (shouldCache) { - BuilderZ.SetInsertPoint(newCall->getNextNode()); - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - } - eraseIfUnused(call); - assert(gutils->isConstantInstruction(&call)); - return true; - } - - if (called) { - if (funcName == "julia.write_barrier" || - funcName == "julia.write_barrier_binding") { - - std::map Seen; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - - bool backwardsShadow = false; - bool forwardsShadow = true; - for (auto pair : 
gutils->backwardsOnlyShadows) { - if (pair.second.stores.count(&call)) { - backwardsShadow = true; - forwardsShadow = pair.second.primalInitialize; - if (auto inst = dyn_cast(pair.first)) - if (!forwardsShadow && pair.second.LI && - pair.second.LI->contains(inst->getParent())) - backwardsShadow = false; - break; - } - } - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError || - (Mode == DerivativeMode::ReverseModeCombined && - (forwardsShadow || backwardsShadow)) || - (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow)) { - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call)); - for (int i = 0; i < gutils->getWidth(); i++) { - SmallVector iargs; - bool first = true; - for (auto &arg : call.args()) { - if (!gutils->isConstantValue(arg)) { - Value *ptrshadow = gutils->invertPointerM(arg, BuilderZ); - if (gutils->getWidth() > 1) { - ptrshadow = gutils->extractMeta(BuilderZ, ptrshadow, i); - } - iargs.push_back(ptrshadow); - } else { - if (first) - break; - } - first = false; - } - if (iargs.size()) { - BuilderZ.CreateCall(called, iargs); - } - } - } - - bool forceErase = false; - if (Mode == DerivativeMode::ReverseModeGradient) { - - // Since we won't redo the store in the reverse pass, do not - // force the write barrier. - forceErase = true; - for (const auto &pair : gutils->rematerializableAllocations) { - if (!pair.second.stores.count(&call)) - continue; - bool primalNeededInReverse = - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError - ? false - : DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, pair.first, Mode, Seen, - oldUnreachable); - - bool cacheWholeAllocation = - gutils->needsCacheWholeAllocation(pair.first); - if (cacheWholeAllocation) { - primalNeededInReverse = true; - } - - if (primalNeededInReverse && !cacheWholeAllocation) - // However, if we are rematerailizing the allocation and not - // inside the loop level rematerialization, we do still need the - // reverse passes ``fake primal'' store and therefore write - // barrier - if (!pair.second.LI || !pair.second.LI->contains(&call)) { - forceErase = false; - } - } - } - if (forceErase) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - else - eraseIfUnused(call); - - return true; - } - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (isMemFreeLibMFunction(funcName, &ID)) { - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(&call)) { - - if (gutils->knownRecomputeHeuristic.find(&call) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&call]) { - gutils->cacheForReverse( - BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - } - } - eraseIfUnused(call); - return true; - } - - if (ID != Intrinsic::not_intrinsic) { - SmallVector orig_ops(call.getNumOperands()); - for (unsigned i = 0; i < call.getNumOperands(); ++i) { - orig_ops[i] = call.getOperand(i); - } - bool cached = handleAdjointForIntrinsic(ID, call, orig_ops); - if (!cached) { - if (gutils->knownRecomputeHeuristic.find(&call) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&call]) { - gutils->cacheForReverse( - BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - } - } - } - eraseIfUnused(call); - return true; - } - } - } - } - if (auto assembly = dyn_cast(call.getCalledOperand())) { - if (assembly->getAsmString() == "maxpd $1, $0") { - if (Mode == 
DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(&call)) { - - if (gutils->knownRecomputeHeuristic.find(&call) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&call]) { - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - } - } - eraseIfUnused(call); - return true; - } - - SmallVector orig_ops(call.getNumOperands()); - for (unsigned i = 0; i < call.getNumOperands(); ++i) { - orig_ops[i] = call.getOperand(i); - } - handleAdjointForIntrinsic(Intrinsic::maxnum, call, orig_ops); - if (gutils->knownRecomputeHeuristic.find(&call) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[&call]) { - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - } - } - eraseIfUnused(call); - return true; - } - } - - if (funcName == "realloc") { - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - if (!gutils->isConstantValue(&call)) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - - auto dbgLoc = gutils->getNewFromOriginal(&call)->getDebugLoc(); - - auto rule = [&](Value *ip) { - ValueType BundleTypes[2] = {ValueType::Shadow, ValueType::Primal}; - - auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2, - /*lookup*/ false); - - llvm::Value *args[2] = { - ip, gutils->getNewFromOriginal(call.getOperand(1))}; - CallInst *CI = Builder2.CreateCall( - call.getFunctionType(), call.getCalledFunction(), args, Defs); - CI->setAttributes(call.getAttributes()); - CI->setCallingConv(call.getCallingConv()); - CI->setTailCallKind(call.getTailCallKind()); - CI->setDebugLoc(dbgLoc); - return CI; - }; - - Value *CI = applyChainRule( - call.getType(), Builder2, rule, - gutils->invertPointerM(call.getOperand(0), Builder2)); - - auto found = gutils->invertedPointers.find(&call); - PHINode *placeholder = cast(&*found->second); - - gutils->invertedPointers.erase(found); - gutils->replaceAWithB(placeholder, CI); - gutils->erase(placeholder); - gutils->invertedPointers.insert( - std::make_pair(&call, InvertedPointerVH(gutils, CI))); - } - eraseIfUnused(call); - return true; - } - } - - if (isAllocationFunction(funcName, gutils->TLI)) { - - bool constval = gutils->isConstantValue(&call); - - if (!constval) { - auto dbgLoc = gutils->getNewFromOriginal(&call)->getDebugLoc(); - auto found = gutils->invertedPointers.find(&call); - PHINode *placeholder = cast(&*found->second); - IRBuilder<> bb(placeholder); - - SmallVector args; - for (auto &arg : call.args()) { - args.push_back(gutils->getNewFromOriginal(arg)); - } - - if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ForwardModeSplit) { - - Value *anti = placeholder; - // If rematerializable allocations and split mode, we can - // simply elect to build the entire piece in the reverse - // since it should be possible to perform any shadow stores - // of pointers (from rematerializable property) and it does - // not escape the function scope (lest it not be - // rematerializable) so all input derivatives remain zero. 
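// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the shadow-allocation pattern
// that the allocation handling above generates for an active `malloc`. In the
// augmented forward pass every differentiable allocation gets a shadow of the
// same size, zero-initialized so that derivative accumulation (+=) is well
// defined; the reverse pass frees it after the adjoints have been propagated.
// Names (ShadowPair, forward_with_shadow, reverse_epilogue) are hypothetical.
// ---------------------------------------------------------------------------
#include <cstdlib>
#include <cstring>

struct ShadowPair {
  double *primal;
  double *shadow;
};

ShadowPair forward_with_shadow(size_t n) {
  ShadowPair p;
  p.primal = static_cast<double *>(malloc(n * sizeof(double)));
  // Shadow allocation mirrors the primal one and starts at zero.
  p.shadow = static_cast<double *>(malloc(n * sizeof(double)));
  memset(p.shadow, 0, n * sizeof(double));
  return p;
}

void reverse_epilogue(ShadowPair p) {
  // The reverse pass frees the shadow (and, when the tool owns it, the
  // primal) once all adjoint reads and writes to it are done.
  free(p.shadow);
  free(p.primal);
}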
- bool backwardsShadow = false; - bool forwardsShadow = true; - bool inLoop = false; - bool isAlloca = isa(&call); - { - auto found = gutils->backwardsOnlyShadows.find(&call); - if (found != gutils->backwardsOnlyShadows.end()) { - backwardsShadow = true; - forwardsShadow = found->second.primalInitialize; - // If in a loop context, maintain the same free behavior. - if (found->second.LI && - found->second.LI->contains(call.getParent())) - inLoop = true; - } - } - { - - if (!forwardsShadow) { - if (Mode == DerivativeMode::ReverseModePrimal) { - // Needs a stronger replacement check/assertion. - Value *replacement = getUndefinedValueForType( - *gutils->oldFunc->getParent(), placeholder->getType()); - gutils->replaceAWithB(placeholder, replacement); - gutils->invertedPointers.erase(found); - gutils->invertedPointers.insert(std::make_pair( - &call, InvertedPointerVH(gutils, replacement))); - gutils->erase(placeholder); - anti = nullptr; - goto endAnti; - } else if (inLoop) { - gutils->rematerializedPrimalOrShadowAllocations.push_back( - placeholder); - if (hasMetadata(&call, "enzyme_fromstack")) - isAlloca = true; - goto endAnti; - } - } - placeholder->setName(""); - if (shadowHandlers.find(funcName) != shadowHandlers.end()) { - bb.SetInsertPoint(placeholder); - - if (Mode == DerivativeMode::ReverseModeCombined || - (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (Mode == DerivativeMode::ReverseModeGradient && - backwardsShadow)) { - anti = applyChainRule(call.getType(), bb, [&]() { - return shadowHandlers[funcName](bb, &call, args, gutils); - }); - if (anti->getType() != placeholder->getType()) { - llvm::errs() << "orig: " << call << "\n"; - llvm::errs() << "placeholder: " << *placeholder << "\n"; - llvm::errs() << "anti: " << *anti << "\n"; - } - gutils->invertedPointers.erase(found); - bb.SetInsertPoint(placeholder); - - gutils->replaceAWithB(placeholder, anti); - gutils->erase(placeholder); - } - - if (auto inst = dyn_cast(anti)) - bb.SetInsertPoint(inst); - - if (!backwardsShadow) - anti = gutils->cacheForReverse( - bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ)); - } else { - bool zeroed = false; - uint64_t idx = 0; - Value *prev = nullptr; - ; - auto rule = [&]() { - Value *anti = - bb.CreateCall(call.getFunctionType(), call.getCalledOperand(), - args, call.getName() + "'mi"); - cast(anti)->setAttributes(call.getAttributes()); - cast(anti)->setCallingConv(call.getCallingConv()); - cast(anti)->setTailCallKind(call.getTailCallKind()); - cast(anti)->setDebugLoc(dbgLoc); - - if (anti->getType()->isPointerTy()) { - cast(anti)->addAttributeAtIndex( - AttributeList::ReturnIndex, Attribute::NoAlias); - cast(anti)->addAttributeAtIndex( - AttributeList::ReturnIndex, Attribute::NonNull); - - if (funcName == "malloc" || funcName == "_Znwm" || - funcName == "??2@YAPAXI@Z" || - funcName == "??2@YAPEAX_K@Z") { - if (auto ci = dyn_cast(args[0])) { - unsigned derefBytes = ci->getLimitedValue(); - CallInst *cal = - cast(gutils->getNewFromOriginal(&call)); - cast(anti)->addDereferenceableRetAttr(derefBytes); - cal->addDereferenceableRetAttr(derefBytes); -#if !defined(FLANG) && !defined(ROCM) - AttrBuilder B(ci->getContext()); -#else - AttrBuilder B; -#endif - B.addDereferenceableOrNullAttr(derefBytes); - cast(anti)->setAttributes( - cast(anti)->getAttributes().addRetAttributes( - call.getContext(), B)); - cal->setAttributes(cal->getAttributes().addRetAttributes( - call.getContext(), B)); - cal->addAttributeAtIndex(AttributeList::ReturnIndex, - Attribute::NoAlias); - 
cal->addAttributeAtIndex(AttributeList::ReturnIndex, - Attribute::NonNull); - } - } - if (funcName == "julia.gc_alloc_obj" || - funcName == "jl_gc_alloc_typed" || - funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) { - bool used = unnecessaryInstructions.find(&call) == - unnecessaryInstructions.end(); - EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call), - idx, wrap(prev), used); - } - } - } - if (Mode == DerivativeMode::ReverseModeCombined || - (Mode == DerivativeMode::ReverseModePrimal && - forwardsShadow) || - (Mode == DerivativeMode::ReverseModeGradient && - backwardsShadow) || - (Mode == DerivativeMode::ForwardModeSplit && - backwardsShadow)) { - if (!inLoop) { - zeroKnownAllocation(bb, anti, args, funcName, gutils->TLI, - &call); - zeroed = true; - } - } - idx++; - prev = anti; - return anti; - }; - - anti = applyChainRule(call.getType(), bb, rule); - - gutils->invertedPointers.erase(found); - if (&*bb.GetInsertPoint() == placeholder) - bb.SetInsertPoint(placeholder->getNextNode()); - gutils->replaceAWithB(placeholder, anti); - gutils->erase(placeholder); - - if (!backwardsShadow) - anti = gutils->cacheForReverse( - bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ)); - else { - if (auto MD = hasMetadata(&call, "enzyme_fromstack")) { - isAlloca = true; - bb.SetInsertPoint(cast(anti)); - Value *Size; - if (funcName == "malloc") - Size = args[0]; - else if (funcName == "julia.gc_alloc_obj" || - funcName == "jl_gc_alloc_typed" || - funcName == "ijl_gc_alloc_typed") - Size = args[1]; - else - llvm_unreachable("Unknown allocation to upgrade"); - - Type *elTy = Type::getInt8Ty(call.getContext()); - std::string name = ""; -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - for (auto U : call.users()) { - if (hasMetadata(cast(U), "enzyme_caststack")) { - elTy = U->getType()->getPointerElementType(); - Value *tsize = ConstantInt::get( - Size->getType(), (gutils->newFunc->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(elTy) + - 7) / - 8); - Size = bb.CreateUDiv(Size, tsize, "", /*exact*/ true); - name = (U->getName() + "'ai").str(); - break; - } - } - } -#endif - auto rule = [&](Value *anti) { - bb.SetInsertPoint(cast(anti)); - Value *replacement = bb.CreateAlloca(elTy, Size, name); - if (name.size() == 0) - replacement->takeName(anti); - else - anti->setName(""); - auto Alignment = cast(cast( - MD->getOperand(0)) - ->getValue()) - ->getLimitedValue(); - if (Alignment) { - cast(replacement) - ->setAlignment(Align(Alignment)); - } -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - if (anti->getType()->getPointerElementType() != elTy) - replacement = bb.CreatePointerCast( - replacement, - PointerType::getUnqual( - anti->getType()->getPointerElementType())); - } -#endif - if (int AS = cast(anti->getType()) - ->getAddressSpace()) { - llvm::PointerType *PT; -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - PT = PointerType::get( - anti->getType()->getPointerElementType(), AS); -#endif -#if LLVM_VERSION_MAJOR < 17 - } else { -#endif - PT = PointerType::get(anti->getContext(), AS); -#if LLVM_VERSION_MAJOR < 17 - } -#endif - replacement = bb.CreateAddrSpaceCast(replacement, PT); - cast(replacement) - ->setMetadata( - "enzyme_backstack", - MDNode::get(replacement->getContext(), {})); - } - gutils->replaceAWithB(cast(anti), replacement); - bb.SetInsertPoint(cast(anti)->getNextNode()); - gutils->erase(cast(anti)); - return replacement; - }; - - auto replacement = - 
applyChainRule(call.getType(), bb, rule, anti); - anti = replacement; - } - } - - if (Mode == DerivativeMode::ReverseModeCombined || - (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (Mode == DerivativeMode::ReverseModeGradient && - backwardsShadow) || - (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow)) { - if (!inLoop) { - assert(zeroed); - } - } - } - gutils->invertedPointers.insert( - std::make_pair(&call, InvertedPointerVH(gutils, anti))); - } - endAnti:; - if (((Mode == DerivativeMode::ReverseModeCombined && shouldFree()) || - (Mode == DerivativeMode::ReverseModeGradient && shouldFree()) || - (Mode == DerivativeMode::ForwardModeSplit && shouldFree())) && - !isAlloca) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - assert(anti); - Value *tofree = lookup(anti, Builder2); - assert(tofree); - assert(tofree->getType()); - auto rule = [&](Value *tofree) { - auto CI = freeKnownAllocation(Builder2, tofree, funcName, dbgLoc, - gutils->TLI, &call, gutils); - if (CI) - CI->addAttributeAtIndex(AttributeList::FirstArgIndex, - Attribute::NonNull); - }; - applyChainRule(Builder2, rule, tofree); - } - } else if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - - SmallVector args; - for (unsigned i = 0; i < call.arg_size(); ++i) { - auto arg = call.getArgOperand(i); - args.push_back(gutils->getNewFromOriginal(arg)); - } - - uint64_t idx = 0; - Value *prev = gutils->getNewFromOriginal(&call); - auto rule = [&]() { - SmallVector BundleTypes(args.size(), ValueType::Primal); - - auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2, - /*lookup*/ false); - - CallInst *CI = Builder2.CreateCall( - call.getFunctionType(), call.getCalledFunction(), args, Defs); - CI->setAttributes(call.getAttributes()); - CI->setCallingConv(call.getCallingConv()); - CI->setTailCallKind(call.getTailCallKind()); - CI->setDebugLoc(dbgLoc); - - if (funcName == "julia.gc_alloc_obj" || - funcName == "jl_gc_alloc_typed" || - funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) { - bool used = unnecessaryInstructions.find(&call) == - unnecessaryInstructions.end(); - EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx, - wrap(prev), used); - } - } - idx++; - prev = CI; - return CI; - }; - - Value *CI = applyChainRule(call.getType(), Builder2, rule); - - auto found = gutils->invertedPointers.find(&call); - PHINode *placeholder = cast(&*found->second); - - gutils->invertedPointers.erase(found); - gutils->replaceAWithB(placeholder, CI); - gutils->erase(placeholder); - gutils->invertedPointers.insert( - std::make_pair(&call, InvertedPointerVH(gutils, CI))); - } - } - - // Cache and rematerialization irrelevant for forward mode. - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - eraseIfUnused(call); - return true; - } - - std::map Seen; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - bool primalNeededInReverse = - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError - ? 
false - : DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, Mode, Seen, oldUnreachable); - - bool cacheWholeAllocation = gutils->needsCacheWholeAllocation(&call); - if (cacheWholeAllocation) { - primalNeededInReverse = true; - } - - auto restoreFromStack = [&](MDNode *MD) { - IRBuilder<> B(newCall); - Value *Size; - if (funcName == "malloc") - Size = call.getArgOperand(0); - else if (funcName == "julia.gc_alloc_obj" || - funcName == "jl_gc_alloc_typed" || - funcName == "ijl_gc_alloc_typed") - Size = call.getArgOperand(1); - else - llvm_unreachable("Unknown allocation to upgrade"); - Size = gutils->getNewFromOriginal(Size); - - if (isa(Size)) { - B.SetInsertPoint(gutils->inversionAllocs); - } - Type *elTy = Type::getInt8Ty(call.getContext()); - Instruction *I = nullptr; -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - for (auto U : call.users()) { - if (hasMetadata(cast(U), "enzyme_caststack")) { - elTy = U->getType()->getPointerElementType(); - Value *tsize = ConstantInt::get(Size->getType(), - (gutils->newFunc->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(elTy) + - 7) / - 8); - Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true); - I = gutils->getNewFromOriginal(cast(U)); - break; - } - } - } -#endif - Value *replacement = B.CreateAlloca(elTy, Size); - for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", - "enzymejl_allocart"}) - if (auto M = call.getMetadata(MD)) - cast(replacement)->setMetadata(MD, M); - if (I) - replacement->takeName(I); - else - replacement->takeName(newCall); - auto Alignment = - cast( - cast(MD->getOperand(0))->getValue()) - ->getLimitedValue(); - // Don't set zero alignment - if (Alignment) { - cast(replacement)->setAlignment(Align(Alignment)); - } -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - if (call.getType()->getPointerElementType() != elTy) - replacement = B.CreatePointerCast( - replacement, - PointerType::getUnqual(call.getType()->getPointerElementType())); - } -#endif - if (int AS = cast(call.getType())->getAddressSpace()) { - llvm::PointerType *PT; -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - PT = PointerType::get(call.getType()->getPointerElementType(), AS); -#endif -#if LLVM_VERSION_MAJOR < 17 - } else { -#endif - PT = PointerType::get(call.getContext(), AS); -#if LLVM_VERSION_MAJOR < 17 - } -#endif - replacement = B.CreateAddrSpaceCast(replacement, PT); - cast(replacement) - ->setMetadata("enzyme_backstack", - MDNode::get(replacement->getContext(), {})); - } - gutils->replaceAWithB(newCall, replacement); - gutils->erase(newCall); - }; - - // Don't erase any allocation that is being rematerialized. - { - auto found = gutils->rematerializableAllocations.find(&call); - if (found != gutils->rematerializableAllocations.end()) { - // If rematerializing (e.g. needed in reverse, but not needing - // the whole allocation): - if (primalNeededInReverse && !cacheWholeAllocation) { - assert(!unnecessaryValues.count(&call)); - // if rematerialize, don't ever cache and downgrade to stack - // allocation where possible. Note that for allocations which are - // within a loop, we will create the rematerialized allocation in the - // rematerialied loop. Note that what matters here is whether the - // actual call itself here is inside the loop, not whether the - // rematerialization is loop level. This is because one can have a - // loop level cache, but a function level allocation (e.g. 
for stack - // allocas). If we deleted it here, we would have no allocation! - auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent()); - // An allocation within a loop, must definitionally be a loop level - // allocation (but not always the other way around. - if (AllocationLoop) - assert(found->second.LI); - if (auto MD = hasMetadata(&call, "enzyme_fromstack")) { - if (Mode == DerivativeMode::ReverseModeGradient && AllocationLoop) { - gutils->rematerializedPrimalOrShadowAllocations.push_back( - newCall); - } else { - restoreFromStack(MD); - } - return true; - } - - // No need to free GC. - if (EnzymeJuliaAddrLoad && isa(call.getType()) && - cast(call.getType())->getAddressSpace() == 10) { - if (Mode == DerivativeMode::ReverseModeGradient && AllocationLoop) - gutils->rematerializedPrimalOrShadowAllocations.push_back( - newCall); - return true; - } - - // Otherwise if in reverse pass, free the newly created allocation. - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardModeSplit) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto dbgLoc = gutils->getNewFromOriginal(call.getDebugLoc()); - freeKnownAllocation(Builder2, lookup(newCall, Builder2), funcName, - dbgLoc, gutils->TLI, &call, gutils); - if (Mode == DerivativeMode::ReverseModeGradient && AllocationLoop) - gutils->rematerializedPrimalOrShadowAllocations.push_back( - newCall); - return true; - } - // If in primal, do nothing (keeping the original caching behavior) - if (Mode == DerivativeMode::ReverseModePrimal) - return true; - } else if (!cacheWholeAllocation) { - if (unnecessaryValues.count(&call)) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - // If not caching allocation and not needed in the reverse, we can - // use the original freeing behavior for the function. If in the - // reverse pass we should not recreate this allocation. - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - else if (auto MD = hasMetadata(&call, "enzyme_fromstack")) { - restoreFromStack(MD); - } - return true; - } - } - } - - // If an allocation is not needed in the reverse, maintain the original - // free behavior and do not rematerialize this for the reverse. However, - // this is only safe to perform for allocations with a guaranteed free - // as can we can only guarantee that we don't erase those frees. - bool hasPDFree = gutils->allocationsWithGuaranteedFree.count(&call); - if (!primalNeededInReverse && hasPDFree) { - if (unnecessaryValues.count(&call)) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } else { - if (auto MD = hasMetadata(&call, "enzyme_fromstack")) { - restoreFromStack(MD); - } - } - return true; - } - - // If an object is managed by the GC do not preserve it for later free, - // Thus it only needs caching if there is a need for it in the reverse. 
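// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): what the `enzyme_fromstack`
// restore-to-stack path above amounts to at the source level. When an
// allocation is known not to escape and is only kept for the reverse pass, the
// generated code downgrades the heap allocation to a stack buffer of the same
// size, so no cache entry and no free are needed. alloca() is used here as the
// (non-standard but widely available) libc counterpart of LLVM's alloca
// instruction; `scale_sum_*`, `x`, `n`, and `a` are hypothetical names.
// ---------------------------------------------------------------------------
#include <alloca.h>
#include <cstdlib>
#include <numeric>

// Before: a temporary heap buffer that never escapes the function.
double scale_sum_heap(const double *x, size_t n, double a) {
  double *tmp = static_cast<double *>(malloc(n * sizeof(double)));
  for (size_t i = 0; i < n; ++i)
    tmp[i] = a * x[i];
  double s = std::accumulate(tmp, tmp + n, 0.0);
  free(tmp);
  return s;
}

// After the fromstack downgrade: the same buffer lives on the stack (an alloca
// in IR), so the reverse pass neither caches nor frees it.
double scale_sum_stack(const double *x, size_t n, double a) {
  double *tmp = static_cast<double *>(alloca(n * sizeof(double)));
  for (size_t i = 0; i < n; ++i)
    tmp[i] = a * x[i];
  return std::accumulate(tmp, tmp + n, 0.0);
}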
- if (EnzymeJuliaAddrLoad && isa(call.getType()) && - cast(call.getType())->getAddressSpace() == 10) { - if (!subretused) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (!primalNeededInReverse) { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit) { - auto pn = BuilderZ.CreatePHI(call.getType(), 1, - call.getName() + "_replacementJ"); - gutils->fictiousPHIs[pn] = &call; - gutils->replaceAWithB(newCall, pn); - gutils->erase(newCall); - } - } else if (Mode != DerivativeMode::ReverseModeCombined) { - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(&call, CacheType::Self, BuilderZ)); - } - return true; - } - - if (EnzymeFreeInternalAllocations) - hasPDFree = true; - - // TODO enable this if we need to free the memory - // NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUTE - // TO FREE'ing - if ((primalNeededInReverse && - !gutils->unnecessaryIntermediates.count(&call)) || - hasPDFree) { - Value *nop = gutils->cacheForReverse( - BuilderZ, newCall, getIndex(&call, CacheType::Self, BuilderZ)); - if (hasPDFree && - ((Mode == DerivativeMode::ReverseModeGradient && shouldFree()) || - Mode == DerivativeMode::ReverseModeCombined || - (Mode == DerivativeMode::ForwardModeSplit && shouldFree()))) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - auto dbgLoc = gutils->getNewFromOriginal(call.getDebugLoc()); - freeKnownAllocation(Builder2, lookup(nop, Builder2), funcName, dbgLoc, - gutils->TLI, &call, gutils); - } - } else if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardModeSplit) { - // Note that here we cannot simply replace with null as users who - // try to find the shadow pointer will use the shadow of null rather - // than the true shadow of this - auto pn = BuilderZ.CreatePHI(call.getType(), 1, - call.getName() + "_replacementB"); - gutils->fictiousPHIs[pn] = &call; - gutils->replaceAWithB(newCall, pn); - gutils->erase(newCall); - } - - return true; - } - - if (funcName == "julia.gc_loaded") { - if (gutils->isConstantValue(&call)) { - eraseIfUnused(call); - return true; - } - auto ifound = gutils->invertedPointers.find(&call); - assert(ifound != gutils->invertedPointers.end()); - - auto placeholder = cast(&*ifound->second); - - bool needShadow = - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, &call, Mode, oldUnreachable); - if (!needShadow) { - gutils->invertedPointers.erase(ifound); - gutils->erase(placeholder); - eraseIfUnused(call); - return true; - } - - Value *ptr0shadow = gutils->invertPointerM(call.getArgOperand(0), BuilderZ); - Value *ptr1shadow = gutils->invertPointerM(call.getArgOperand(1), BuilderZ); - - Value *val = applyChainRule( - call.getType(), BuilderZ, - [&](Value *v1, Value *v2) -> Value * { - Value *args[2] = {v1, v2}; - return BuilderZ.CreateCall(called, args); - }, - ptr0shadow, ptr1shadow); - - gutils->replaceAWithB(placeholder, val); - gutils->erase(placeholder); - eraseIfUnused(call); - return true; - } - - if (funcName == "julia.pointer_from_objref") { - if (gutils->isConstantValue(&call)) { - eraseIfUnused(call); - return true; - } - - auto ifound = gutils->invertedPointers.find(&call); - assert(ifound != gutils->invertedPointers.end()); - - auto placeholder = cast(&*ifound->second); - - bool needShadow = - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, &call, Mode, oldUnreachable); - if (!needShadow) { - 
gutils->invertedPointers.erase(ifound); - gutils->erase(placeholder); - eraseIfUnused(call); - return true; - } - - Value *ptrshadow = gutils->invertPointerM(call.getArgOperand(0), BuilderZ); - - Value *val = applyChainRule( - call.getType(), BuilderZ, - [&](Value *v) -> Value * { return BuilderZ.CreateCall(called, {v}); }, - ptrshadow); - - gutils->replaceAWithB(placeholder, val); - gutils->erase(placeholder); - eraseIfUnused(call); - return true; - } - if (funcName.contains("__enzyme_todense")) { - if (gutils->isConstantValue(&call)) { - eraseIfUnused(call); - return true; - } - - auto ifound = gutils->invertedPointers.find(&call); - assert(ifound != gutils->invertedPointers.end()); - - auto placeholder = cast(&*ifound->second); - - bool needShadow = - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, &call, Mode, oldUnreachable); - if (!needShadow) { - gutils->invertedPointers.erase(ifound); - gutils->erase(placeholder); - eraseIfUnused(call); - return true; - } - - SmallVector args; - for (size_t i = 0; i < 2; i++) - args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i))); - for (size_t i = 2; i < call.arg_size(); ++i) - args.push_back(gutils->invertPointerM(call.getArgOperand(0), BuilderZ)); - - Value *res = UndefValue::get(gutils->getShadowType(call.getType())); - if (gutils->getWidth() == 1) { - res = BuilderZ.CreateCall(called, args); - } else { - for (size_t w = 0; w < gutils->getWidth(); ++w) { - SmallVector targs = {args[0], args[1]}; - for (size_t i = 2; i < call.arg_size(); ++i) - targs.push_back(GradientUtils::extractMeta(BuilderZ, args[i], w)); - - auto tres = BuilderZ.CreateCall(called, targs); - res = BuilderZ.CreateInsertValue(res, tres, w); - } - } - - gutils->replaceAWithB(placeholder, res); - gutils->erase(placeholder); - eraseIfUnused(call); - return true; - } - - if (funcName == "memcpy" || funcName == "memmove") { - auto ID = (funcName == "memcpy") ? 
Intrinsic::memcpy : Intrinsic::memmove; - visitMemTransferCommon(ID, /*srcAlign*/ MaybeAlign(1), - /*dstAlign*/ MaybeAlign(1), call, - call.getArgOperand(0), call.getArgOperand(1), - gutils->getNewFromOriginal(call.getArgOperand(2)), - ConstantInt::getFalse(call.getContext())); - return true; - } - if (funcName == "memset" || funcName == "memset_pattern16" || - funcName == "__memset_chk") { - visitMemSetCommon(call); - return true; - } - if (funcName == "enzyme_zerotype") { - IRBuilder<> BuilderZ(&call); - getForwardBuilder(BuilderZ); - - bool forceErase = Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ForwardModeSplit; - - if (forceErase) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - else - eraseIfUnused(call); - - Value *orig_op0 = call.getArgOperand(0); - - // If constant destination then no operation needs doing - if (gutils->isConstantValue(orig_op0)) { - return true; - } - - if (!forceErase) { - Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ); - Value *op1 = gutils->getNewFromOriginal(call.getArgOperand(1)); - Value *op2 = gutils->getNewFromOriginal(call.getArgOperand(2)); - auto Defs = gutils->getInvertedBundles( - &call, {ValueType::Shadow, ValueType::Primal, ValueType::Primal}, - BuilderZ, /*lookup*/ false); - - applyChainRule( - BuilderZ, - [&](Value *op0) { - SmallVector args = {op0, op1, op2}; - auto cal = - BuilderZ.CreateCall(call.getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - cal->copyMetadata(call, ToCopy2); - cal->setAttributes(call.getAttributes()); - if (auto m = hasMetadata(&call, "enzyme_zerostack")) - cal->setMetadata("enzyme_zerostack", m); - cal->setCallingConv(call.getCallingConv()); - cal->setTailCallKind(call.getTailCallKind()); - cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - }, - op0); - } - return true; - } - if (funcName == "cuStreamCreate") { - Value *val = nullptr; - llvm::Type *PT = getInt8PtrTy(call.getContext()); -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - if (isa(call.getArgOperand(0)->getType())) - PT = call.getArgOperand(0)->getType()->getPointerElementType(); - } -#endif - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { - val = gutils->getNewFromOriginal(call.getOperand(0)); - if (!isa(val->getType())) - val = BuilderZ.CreateIntToPtr(val, PointerType::getUnqual(PT)); - val = BuilderZ.CreateLoad(PT, val); - val = gutils->cacheForReverse(BuilderZ, val, - getIndex(&call, CacheType::Tape, BuilderZ)); - - } else if (Mode == DerivativeMode::ReverseModeGradient) { - PHINode *toReplace = - BuilderZ.CreatePHI(PT, 1, call.getName() + "_psxtmp"); - val = gutils->cacheForReverse(BuilderZ, toReplace, - getIndex(&call, CacheType::Tape, BuilderZ)); - } - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - val = gutils->lookupM(val, Builder2); - auto FreeFunc = gutils->newFunc->getParent()->getOrInsertFunction( - "cuStreamDestroy", call.getType(), PT); - Value *nargs[] = {val}; - Builder2.CreateCall(FreeFunc, nargs); - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (funcName == "cuStreamDestroy") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) - eraseIfUnused(call, /*erase*/ true, /*check*/ 
false); - return true; - } - if (funcName == "cuStreamSynchronize") { - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - Value *nargs[] = {gutils->lookupM( - gutils->getNewFromOriginal(call.getOperand(0)), Builder2)}; - auto callval = call.getCalledOperand(); - Builder2.CreateCall(call.getFunctionType(), callval, nargs); - } - if (Mode == DerivativeMode::ReverseModeGradient) - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - if (funcName == "posix_memalign" || funcName == "cuMemAllocAsync" || - funcName == "cuMemAlloc" || funcName == "cuMemAlloc_v2" || - funcName == "cudaMalloc" || funcName == "cudaMallocAsync" || - funcName == "cudaMallocHost" || funcName == "cudaMallocFromPoolAsync") { - bool constval = gutils->isConstantInstruction(&call); - - Value *val; - llvm::Type *PT = getInt8PtrTy(call.getContext()); -#if LLVM_VERSION_MAJOR < 17 - if (call.getContext().supportsTypedPointers()) { - if (isa(call.getArgOperand(0)->getType())) - PT = call.getArgOperand(0)->getType()->getPointerElementType(); - } -#endif - if (!constval) { - Value *stream = nullptr; - if (funcName == "cuMemAllocAsync") - stream = gutils->getNewFromOriginal(call.getArgOperand(2)); - else if (funcName == "cudaMallocAsync") - stream = gutils->getNewFromOriginal(call.getArgOperand(2)); - else if (funcName == "cudaMallocFromPoolAsync") - stream = gutils->getNewFromOriginal(call.getArgOperand(3)); - - auto M = gutils->newFunc->getParent(); - - if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - Value *ptrshadow = - gutils->invertPointerM(call.getArgOperand(0), BuilderZ); - SmallVector args; - SmallVector valtys; - args.push_back(ptrshadow); - valtys.push_back(ValueType::Shadow); - for (size_t i = 1; i < call.arg_size(); ++i) { - args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i))); - valtys.push_back(ValueType::Primal); - } - - auto Defs = gutils->getInvertedBundles(&call, valtys, BuilderZ, - /*lookup*/ false); - - val = applyChainRule( - PT, BuilderZ, - [&](Value *ptrshadow) { - args[0] = ptrshadow; - - BuilderZ.CreateCall(called, args, Defs); - if (!isa(ptrshadow->getType())) - ptrshadow = BuilderZ.CreateIntToPtr(ptrshadow, - PointerType::getUnqual(PT)); - Value *val = BuilderZ.CreateLoad(PT, ptrshadow); - - auto dst_arg = - BuilderZ.CreateBitCast(val, getInt8PtrTy(call.getContext())); - - auto val_arg = - ConstantInt::get(Type::getInt8Ty(call.getContext()), 0); - auto len_arg = gutils->getNewFromOriginal( - call.getArgOperand((funcName == "posix_memalign") ? 
2 : 1)); - - if (funcName == "posix_memalign" || - funcName == "cudaMallocHost") { - BuilderZ.CreateMemSet(dst_arg, val_arg, len_arg, MaybeAlign()); - } else if (funcName == "cudaMalloc") { - Type *tys[] = {PT, val_arg->getType(), len_arg->getType()}; - auto F = M->getOrInsertFunction( - "cudaMemset", - FunctionType::get(call.getType(), tys, false)); - Value *nargs[] = {dst_arg, val_arg, len_arg}; - auto memset = cast(BuilderZ.CreateCall(F, nargs)); - memset->addParamAttr(0, Attribute::NonNull); - } else if (funcName == "cudaMallocAsync" || - funcName == "cudaMallocFromPoolAsync") { - Type *tys[] = {PT, val_arg->getType(), len_arg->getType(), - stream->getType()}; - auto F = M->getOrInsertFunction( - "cudaMemsetAsync", - FunctionType::get(call.getType(), tys, false)); - Value *nargs[] = {dst_arg, val_arg, len_arg, stream}; - auto memset = cast(BuilderZ.CreateCall(F, nargs)); - memset->addParamAttr(0, Attribute::NonNull); - } else if (funcName == "cuMemAllocAsync") { - Type *tys[] = {PT, val_arg->getType(), len_arg->getType(), - stream->getType()}; - auto F = M->getOrInsertFunction( - "cuMemsetD8Async", - FunctionType::get(call.getType(), tys, false)); - Value *nargs[] = {dst_arg, val_arg, len_arg, stream}; - auto memset = cast(BuilderZ.CreateCall(F, nargs)); - memset->addParamAttr(0, Attribute::NonNull); - } else if (funcName == "cuMemAlloc" || - funcName == "cuMemAlloc_v2") { - Type *tys[] = {PT, val_arg->getType(), len_arg->getType()}; - auto F = M->getOrInsertFunction( - "cuMemsetD8", - FunctionType::get(call.getType(), tys, false)); - Value *nargs[] = {dst_arg, val_arg, len_arg}; - auto memset = cast(BuilderZ.CreateCall(F, nargs)); - memset->addParamAttr(0, Attribute::NonNull); - } else { - llvm_unreachable("unhandled allocation"); - } - return val; - }, - ptrshadow); - - if (Mode != DerivativeMode::ForwardMode && - Mode != DerivativeMode::ForwardModeError) - val = gutils->cacheForReverse( - BuilderZ, val, getIndex(&call, CacheType::Tape, BuilderZ)); - } else if (Mode == DerivativeMode::ReverseModeGradient) { - PHINode *toReplace = BuilderZ.CreatePHI(gutils->getShadowType(PT), 1, - call.getName() + "_psxtmp"); - val = gutils->cacheForReverse( - BuilderZ, toReplace, getIndex(&call, CacheType::Tape, BuilderZ)); - } - - if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient) { - if (shouldFree()) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - Value *tofree = gutils->lookupM(val, Builder2, ValueToValueMapTy(), - /*tryLegalRecompute*/ false); - - Type *VoidTy = Type::getVoidTy(M->getContext()); - Type *IntPtrTy = getInt8PtrTy(M->getContext()); - - Value *streamL = nullptr; - if (stream) - streamL = gutils->lookupM(stream, Builder2); - - applyChainRule( - BuilderZ, - [&](Value *tofree) { - if (funcName == "posix_memalign") { - auto FreeFunc = - M->getOrInsertFunction("free", VoidTy, IntPtrTy); - Builder2.CreateCall(FreeFunc, tofree); - } else if (funcName == "cuMemAllocAsync") { - auto FreeFunc = M->getOrInsertFunction( - "cuMemFreeAsync", VoidTy, IntPtrTy, streamL->getType()); - Value *nargs[] = {tofree, streamL}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cuMemAlloc" || - funcName == "cuMemAlloc_v2") { - auto FreeFunc = - M->getOrInsertFunction("cuMemFree", VoidTy, IntPtrTy); - Value *nargs[] = {tofree}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cudaMalloc") { - auto FreeFunc = - M->getOrInsertFunction("cudaFree", VoidTy, IntPtrTy); - Value *nargs[] = {tofree}; - 
Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cudaMallocAsync" || - funcName == "cudaMallocFromPoolAsync") { - auto FreeFunc = M->getOrInsertFunction( - "cudaFreeAsync", VoidTy, IntPtrTy, streamL->getType()); - Value *nargs[] = {tofree, streamL}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cudaMallocHost") { - auto FreeFunc = - M->getOrInsertFunction("cudaFreeHost", VoidTy, IntPtrTy); - Value *nargs[] = {tofree}; - Builder2.CreateCall(FreeFunc, nargs); - } else - llvm_unreachable("unknown function to free"); - }, - tofree); - } - } - } - - // TODO enable this if we need to free the memory - // NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUTE - // TO FREE'ing - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } else if (Mode == DerivativeMode::ReverseModePrimal) { - // if (is_value_needed_in_reverse( - // TR, gutils, orig, /*topLevel*/ Mode == - // DerivativeMode::Both)) - // { - - // gutils->cacheForReverse(BuilderZ, newCall, - // getIndex(orig, CacheType::Self, BuilderZ)); - //} else if (Mode != DerivativeMode::Forward) { - // Note that here we cannot simply replace with null as users who try - // to find the shadow pointer will use the shadow of null rather than - // the true shadow of this - //} - } else if (Mode == DerivativeMode::ReverseModeCombined && shouldFree()) { - IRBuilder<> Builder2(newCall->getNextNode()); - auto ptrv = gutils->getNewFromOriginal(call.getOperand(0)); - if (!isa(ptrv->getType())) - ptrv = BuilderZ.CreateIntToPtr(ptrv, PointerType::getUnqual(PT)); - auto load = Builder2.CreateLoad(PT, ptrv, "posix_preread"); - Builder2.SetInsertPoint(&call); - getReverseBuilder(Builder2); - auto tofree = gutils->lookupM(load, Builder2, ValueToValueMapTy(), - /*tryLegal*/ false); - Value *streamL = nullptr; - if (funcName == "cuMemAllocAsync") - streamL = gutils->getNewFromOriginal(call.getArgOperand(2)); - else if (funcName == "cudaMallocAsync") - streamL = gutils->getNewFromOriginal(call.getArgOperand(2)); - else if (funcName == "cudaMallocFromPoolAsync") - streamL = gutils->getNewFromOriginal(call.getArgOperand(3)); - if (streamL) - streamL = gutils->lookupM(streamL, Builder2); - - auto M = gutils->newFunc->getParent(); - Type *VoidTy = Type::getVoidTy(M->getContext()); - Type *IntPtrTy = getInt8PtrTy(M->getContext()); - - if (funcName == "posix_memalign") { - auto FreeFunc = M->getOrInsertFunction("free", VoidTy, IntPtrTy); - Builder2.CreateCall(FreeFunc, tofree); - } else if (funcName == "cuMemAllocAsync") { - auto FreeFunc = M->getOrInsertFunction("cuMemFreeAsync", VoidTy, - IntPtrTy, streamL->getType()); - Value *nargs[] = {tofree, streamL}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cuMemAlloc" || funcName == "cuMemAlloc_v2") { - auto FreeFunc = M->getOrInsertFunction("cuMemFree", VoidTy, IntPtrTy); - Value *nargs[] = {tofree}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cudaMalloc") { - auto FreeFunc = M->getOrInsertFunction("cudaFree", VoidTy, IntPtrTy); - Value *nargs[] = {tofree}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cudaMallocAsync" || - funcName == "cudaMallocFromPoolAsync") { - auto FreeFunc = M->getOrInsertFunction("cudaFreeAsync", VoidTy, - IntPtrTy, streamL->getType()); - Value *nargs[] = {tofree, streamL}; - Builder2.CreateCall(FreeFunc, nargs); - } else if (funcName == "cudaMallocHost") { - auto FreeFunc = - M->getOrInsertFunction("cudaFreeHost", VoidTy, IntPtrTy); - 
Value *nargs[] = {tofree}; - Builder2.CreateCall(FreeFunc, nargs); - } else - llvm_unreachable("unknown function to free"); - } - - return true; - } - - // Remove free's in forward pass so the memory can be used in the reverse - // pass - if (isDeallocationFunction(funcName, gutils->TLI)) { - assert(gutils->invertedPointers.find(&call) == - gutils->invertedPointers.end()); - - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - if (!gutils->isConstantValue(call.getArgOperand(0))) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - auto origfree = call.getArgOperand(0); - auto newfree = gutils->getNewFromOriginal(call.getArgOperand(0)); - auto tofree = gutils->invertPointerM(origfree, Builder2); - - Function *free = getOrInsertCheckedFree( - *call.getModule(), &call, newfree->getType(), gutils->getWidth()); - - bool used = true; - if (auto instArg = dyn_cast(call.getArgOperand(0))) - used = unnecessaryInstructions.find(instArg) == - unnecessaryInstructions.end(); - - SmallVector args; - if (used) - args.push_back(newfree); - else - args.push_back( - Constant::getNullValue(call.getArgOperand(0)->getType())); - - auto rule = [&args](Value *tofree) { args.push_back(tofree); }; - applyChainRule(Builder2, rule, tofree); - - for (size_t i = 1; i < call.arg_size(); i++) { - args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i))); - } - - auto frees = Builder2.CreateCall(free->getFunctionType(), free, args); - frees->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc())); - - eraseIfUnused(call); - return true; - } - eraseIfUnused(call); - } - auto callval = call.getCalledOperand(); - - for (auto rmat : gutils->backwardsOnlyShadows) { - if (rmat.second.frees.count(&call)) { - bool shouldFree = false; - if (rmat.second.primalInitialize) { - if (Mode == DerivativeMode::ReverseModePrimal) - shouldFree = true; - } - - if (shouldFree) { - IRBuilder<> Builder2(&call); - getForwardBuilder(Builder2); - auto origfree = call.getArgOperand(0); - auto tofree = gutils->invertPointerM(origfree, Builder2); - if (tofree != origfree) { - SmallVector args = {tofree}; - CallInst *CI = - Builder2.CreateCall(call.getFunctionType(), callval, args); - CI->setAttributes(call.getAttributes()); - } - } - break; - } - } - - // If a rematerializable allocation. - for (auto rmat : gutils->rematerializableAllocations) { - if (rmat.second.frees.count(&call)) { - // Leave the original free behavior since this won't be used - // in the reverse pass in split mode - if (Mode == DerivativeMode::ReverseModePrimal) { - eraseIfUnused(call); - return true; - } else if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } else { - assert(Mode == DerivativeMode::ReverseModeCombined); - std::map Seen; - for (auto pair : gutils->knownRecomputeHeuristic) - if (!pair.second) - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - bool primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, rmat.first, Mode, Seen, - oldUnreachable); - bool cacheWholeAllocation = - gutils->needsCacheWholeAllocation(rmat.first); - if (cacheWholeAllocation) { - primalNeededInReverse = true; - } - // If in a loop context, maintain the same free behavior, unless - // caching whole allocation. 
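// Illustrative sketch of the allocation handling above, assuming an active
// cudaMalloc of `len` bytes: the pass emits IR that is conceptually equivalent
// to the plain C++ below -- mirror the primal allocation with a shadow buffer,
// zero-initialize the shadow so adjoint accumulation starts from 0, and release
// the shadow in the reverse sweep when shouldFree() permits. The names
// ShadowAlloc / forward_cudaMalloc / reverse_cudaMalloc are hypothetical and
// only stand in for the generated code; this is not the literal IR.
#include <cuda_runtime.h>
#include <cstddef>

struct ShadowAlloc {
  void *primal = nullptr;
  void *shadow = nullptr;
};

// Augmented forward pass: allocate the shadow alongside the primal and zero it.
inline cudaError_t forward_cudaMalloc(ShadowAlloc &a, size_t len) {
  cudaError_t err = cudaMalloc(&a.primal, len);
  if (err != cudaSuccess)
    return err;
  err = cudaMalloc(&a.shadow, len); // shadow of the returned pointer
  if (err != cudaSuccess)
    return err;
  return cudaMemset(a.shadow, 0, len); // derivatives accumulate from zero
}

// Reverse pass: free the shadow once the adjoint computation no longer needs it
// (the primal allocation is handled separately by the combined-mode branch).
inline void reverse_cudaMalloc(ShadowAlloc &a) {
  cudaFree(a.shadow);
}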
- if (!cacheWholeAllocation) { - eraseIfUnused(call); - return true; - } - assert(!unnecessaryValues.count(rmat.first)); - (void)primalNeededInReverse; - assert(primalNeededInReverse); - } - } - } - - if (gutils->forwardDeallocations.count(&call)) { - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - } else - eraseIfUnused(call); - return true; - } - - if (gutils->postDominatingFrees.count(&call)) { - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - - llvm::Value *val = getBaseObject(call.getArgOperand(0)); - if (isa(val)) { - llvm::errs() << "removing free of null pointer\n"; - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - - // TODO HANDLE FREE - llvm::errs() << "freeing without malloc " << *val << "\n"; - eraseIfUnused(call, /*erase*/ true, /*check*/ false); - return true; - } - - if (call.hasFnAttr("enzyme_sample")) { - if (Mode != DerivativeMode::ReverseModeCombined && - Mode != DerivativeMode::ReverseModeGradient) - return true; - - bool constval = gutils->isConstantInstruction(&call); - - if (constval) - return true; - - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - auto trace = call.getArgOperand(call.arg_size() - 1); - auto address = call.getArgOperand(0); - - auto dtrace = lookup(gutils->getNewFromOriginal(trace), Builder2); - auto daddress = lookup(gutils->getNewFromOriginal(address), Builder2); - - Value *dchoice; - if (TR.query(&call)[{-1}].isPossiblePointer()) { - dchoice = gutils->invertPointerM(&call, Builder2); - } else { - dchoice = diffe(&call, Builder2); - } - - if (call.hasMetadata("enzyme_gradient_setter")) { - auto gradient_setter = cast( - cast( - call.getMetadata("enzyme_gradient_setter")->getOperand(0).get()) - ->getValue()); - - TraceUtils::InsertChoiceGradient( - Builder2, gradient_setter->getFunctionType(), gradient_setter, - daddress, dchoice, dtrace); - } - - return true; - } - - if (call.hasFnAttr("enzyme_insert_argument")) { - IRBuilder<> Builder2(&call); - getReverseBuilder(Builder2); - - auto name = call.getArgOperand(0); - auto arg = call.getArgOperand(1); - auto trace = call.getArgOperand(2); - - auto gradient_setter = cast( - cast( - call.getMetadata("enzyme_gradient_setter")->getOperand(0).get()) - ->getValue()); - - auto dtrace = lookup(gutils->getNewFromOriginal(trace), Builder2); - auto dname = lookup(gutils->getNewFromOriginal(name), Builder2); - Value *darg; - - if (TR.query(arg)[{-1}].isPossiblePointer()) { - darg = gutils->invertPointerM(arg, Builder2); - } else { - darg = diffe(arg, Builder2); - } - - TraceUtils::InsertArgumentGradient(Builder2, - gradient_setter->getFunctionType(), - gradient_setter, dname, darg, dtrace); - return true; - } - - return false; -} diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index 77fecd109c82..f32eca62e03d 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -53,11 +53,18 @@ class EnzymeAction final : public clang::PluginASTAction { std::unique_ptr CreateASTConsumer(clang::CompilerInstance &CI, llvm::StringRef InFile) override { +llvm::errs() << " create consumer\n"; +llvm::errs() << " is device: " << CI.getLangOpts().CUDAIsDevice << "\n"; +llvm::errs() << " out file: " << CI.getFrontendOpts().OutputFile << "\n"; return std::unique_ptr(new ConsumerType(CI)); } bool ParseArgs(const clang::CompilerInstance &CI, const std::vector &args) override { + llvm::errs() << " parse args action\n"; + llvm::errs() 
<< " pa: " << CI.getFrontendOpts().ProgramAction << "\n"; + llvm::errs() << " args:\n"; +for (auto a : args) llvm::errs() << "+ arg: " << a<<"\n"; return true; } @@ -80,9 +87,9 @@ struct Visitor : public RecursiveASTVisitor { } }; -#if LLVM_VERSION_MAJOR >= 18 -extern "C" void registerEnzyme(llvm::PassBuilder &PB); -#endif +extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector gpubins); + +extern "C" void registerExporter(llvm::PassBuilder &PB, std::string file); class EnzymePlugin final : public clang::ASTConsumer { clang::CompilerInstance &CI; @@ -90,33 +97,38 @@ class EnzymePlugin final : public clang::ASTConsumer { public: EnzymePlugin(clang::CompilerInstance &CI) : CI(CI) { - FrontendOptions &Opts = CI.getFrontendOpts(); + llvm::errs() << " enzyme plugin constructor\n"; + //FrontendOptions &Opts = CI.getFrontendOpts(); CodeGenOptions &CGOpts = CI.getCodeGenOpts(); auto PluginName = "ClangEnzyme-" + std::to_string(LLVM_VERSION_MAJOR); - bool contains = false; -#if LLVM_VERSION_MAJOR < 18 - std::string pluginPath; -#endif - for (auto P : Opts.Plugins) - if (endsWith(llvm::sys::path::stem(P), PluginName)) { -#if LLVM_VERSION_MAJOR < 18 - pluginPath = P; -#endif - for (auto passPlugin : CGOpts.PassPlugins) { - if (endsWith(llvm::sys::path::stem(passPlugin), PluginName)) { - contains = true; - break; - } - } - } + //bool contains = false; - if (!contains) { -#if LLVM_VERSION_MAJOR >= 18 - CGOpts.PassBuilderCallbacks.push_back(registerEnzyme); -#else - CGOpts.PassPlugins.push_back(pluginPath); -#endif + + std::string inFile; + for (auto in : CI.getFrontendOpts().Inputs) { + if (in.isFile()) { + inFile = in.getFile().str(); + llvm::errs() << " in: " << in.getFile() << "\n"; + } } + if (CI.getLangOpts().CUDAIsDevice) { + std::string file = CI.getFrontendOpts().OutputFile; + file = inFile; + CGOpts.PassBuilderCallbacks.push_back([=](llvm::PassBuilder & PB) { + registerExporter(PB, file); + }); + } else { + std::vector gpubins; + if (CGOpts.CudaGpuBinaryFileName.size()) { + if (inFile.size()) + gpubins.push_back(inFile); + //gpubins.push_back(CGOpts.CudaGpuBinaryFileName); + } + CGOpts.PassBuilderCallbacks.push_back([=](llvm::PassBuilder &PB) { + registerReactant(PB, gpubins); + }); + } + CI.getPreprocessorOpts().Includes.push_back("/enzyme/enzyme/version"); std::string PredefineBuffer; @@ -166,50 +178,6 @@ class EnzymePlugin final : public clang::ASTConsumer { /*isAngled=*/true); } ~EnzymePlugin() {} - void HandleTranslationUnit(ASTContext &context) override {} - bool HandleTopLevelDecl(clang::DeclGroupRef dg) override { - using namespace clang; - DeclGroupRef::iterator it; - - // Visitor v(CI); - // Forcibly require emission of all libdevice - for (it = dg.begin(); it != dg.end(); ++it) { - // v.TraverseDecl(*it); - if (auto FD = dyn_cast(*it)) { - if (!FD->hasAttr()) - continue; - - if (!FD->getIdentifier()) - continue; - if (!StringRef(FD->getLocation().printToString(CI.getSourceManager())) - .contains("/__clang_cuda_math.h")) - continue; - - FD->addAttr(UsedAttr::CreateImplicit(CI.getASTContext())); - } - if (auto FD = dyn_cast(*it)) { - HandleCXXStaticMemberVarInstantiation(FD); - } - } - return true; - } - void HandleCXXStaticMemberVarInstantiation(clang::VarDecl *V) override { - if (!V->getIdentifier()) - return; - auto name = V->getName(); - if (!(name.contains("__enzyme_inactive_global") || - name.contains("__enzyme_inactivefn") || - name.contains("__enzyme_shouldrecompute") || - name.contains("__enzyme_function_like") || - name.contains("__enzyme_allocation_like") || - 
name.contains("__enzyme_register_gradient") || - name.contains("__enzyme_register_derivative") || - name.contains("__enzyme_register_splitderivative"))) - return; - - V->addAttr(clang::UsedAttr::CreateImplicit(CI.getASTContext())); - return; - } }; // register the PluginASTAction in the registry. diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp deleted file mode 100644 index e88e762e7f03..000000000000 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ /dev/null @@ -1,1213 +0,0 @@ -//===- DiffeGradientUtils.cpp - Helper class and utilities for AD ---------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares two helper classes GradientUtils and subclass -// DiffeGradientUtils. These classes contain utilities for managing the cache, -// recomputing statements, and in the case of DiffeGradientUtils, managing -// adjoint values and shadow pointers. -// -//===----------------------------------------------------------------------===// - -#include - -#include "DiffeGradientUtils.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/Transforms/Utils/BasicBlockUtils.h" - -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" - -#include "LibraryFuncs.h" -#include "Utils.h" - -using namespace llvm; - -DiffeGradientUtils::DiffeGradientUtils( - EnzymeLogic &Logic, Function *newFunc_, Function *oldFunc_, - TargetLibraryInfo &TLI, TypeAnalysis &TA, TypeResults TR, - ValueToValueMapTy &invertedPointers_, - const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &returnvals_, DIFFE_TYPE ActiveReturn, - bool shadowReturnUsed, ArrayRef constant_values, - llvm::ValueMap &origToNew_, - DerivativeMode mode, bool runtimeActivity, unsigned width, bool omp) - : GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_, - constantvalues_, returnvals_, ActiveReturn, - shadowReturnUsed, constant_values, origToNew_, mode, - runtimeActivity, width, omp) { - if (oldFunc_->empty()) - return; - assert(reverseBlocks.size() == 0); - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError || - mode == DerivativeMode::ForwardModeSplit) { - return; - } - for (BasicBlock *BB : originalBlocks) { - if (BB == inversionAllocs) - continue; - BasicBlock *RBB = - BasicBlock::Create(BB->getContext(), "invert" + BB->getName(), newFunc); - reverseBlocks[BB].push_back(RBB); - reverseBlockToPrimal[RBB] = BB; - } - assert(reverseBlocks.size() != 0); -} - -DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( - EnzymeLogic &Logic, DerivativeMode 
mode, bool runtimeActivity, - unsigned width, Function *todiff, TargetLibraryInfo &TLI, TypeAnalysis &TA, - FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturn, - bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, Type *additionalArg, bool omp) { - Function *oldFunc = todiff; - assert(mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError); - ValueToValueMapTy invertedPointers; - SmallPtrSet constants; - SmallPtrSet nonconstant; - SmallPtrSet returnvals; - llvm::ValueMap originalToNew; - - SmallPtrSet constant_values; - SmallPtrSet nonconstant_values; - - std::string prefix; - - switch (mode) { - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeSplit: - prefix = "fwddiffe"; - break; - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: - prefix = "diffe"; - break; - case DerivativeMode::ReverseModePrimal: - llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n"); - } - - if (width > 1) - prefix += std::to_string(width); - - auto newFunc = Logic.PPC.CloneFunctionWithReturns( - mode, width, oldFunc, invertedPointers, constant_args, constant_values, - nonconstant_values, returnvals, returnValue, retType, - prefix + oldFunc->getName(), &originalToNew, - /*diffeReturnArg*/ diffeReturnArg, additionalArg); - - // Convert overwritten args from the input function to the preprocessed - // function - - FnTypeInfo typeInfo(oldFunc); - { - auto toarg = todiff->arg_begin(); - auto olarg = oldFunc->arg_begin(); - for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { - - { - auto fd = oldTypeInfo.Arguments.find(toarg); - assert(fd != oldTypeInfo.Arguments.end()); - typeInfo.Arguments.insert( - std::pair(olarg, fd->second)); - } - - { - auto cfd = oldTypeInfo.KnownValues.find(toarg); - assert(cfd != oldTypeInfo.KnownValues.end()); - typeInfo.KnownValues.insert( - std::pair>(olarg, cfd->second)); - } - } - typeInfo.Return = oldTypeInfo.Return; - } - - TypeResults TR = TA.analyzeFunction(typeInfo); - if (!oldFunc->empty()) - assert(TR.getFunction() == oldFunc); - - auto res = new DiffeGradientUtils( - Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, - nonconstant_values, retType, shadowReturn, constant_args, originalToNew, - mode, runtimeActivity, width, omp); - - return res; -} - -AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { - assert(mode != DerivativeMode::ForwardMode); - assert(mode != DerivativeMode::ForwardModeSplit); - assert(mode != DerivativeMode::ForwardModeError); - assert(val); -#ifndef NDEBUG - if (auto arg = dyn_cast(val)) - assert(arg->getParent() == oldFunc); - if (auto inst = dyn_cast(val)) - assert(inst->getParent()->getParent() == oldFunc); -#endif - assert(inversionAllocs); - - Type *type = getShadowType(val->getType()); - if (differentials.find(val) == differentials.end()) { - IRBuilder<> entryBuilder(inversionAllocs); - entryBuilder.setFastMathFlags(getFast()); - differentials[val] = - entryBuilder.CreateAlloca(type, nullptr, val->getName() + "'de"); - auto Alignment = - oldFunc->getParent()->getDataLayout().getPrefTypeAlign(type); - differentials[val]->setAlignment(Alignment); - ZeroMemory(entryBuilder, type, differentials[val], - /*isTape*/ false); - } -#if LLVM_VERSION_MAJOR < 17 - if (val->getContext().supportsTypedPointers()) { - 
assert(differentials[val]->getType()->getPointerElementType() == type); - } -#endif - return differentials[val]; -} - -Value *DiffeGradientUtils::diffe(Value *val, IRBuilder<> &BuilderM) { -#ifndef NDEBUG - if (auto arg = dyn_cast(val)) - assert(arg->getParent() == oldFunc); - if (auto inst = dyn_cast(val)) - assert(inst->getParent()->getParent() == oldFunc); -#endif - - if (isConstantValue(val)) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *val << "\n"; - assert(0 && "getting diffe of constant value"); - } - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError) - return invertPointerM(val, BuilderM); - if (val->getType()->isPointerTy()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *val << "\n"; - } - assert(!val->getType()->isPointerTy()); - assert(!val->getType()->isVoidTy()); - Type *ty = getShadowType(val->getType()); - return BuilderM.CreateLoad(ty, getDifferential(val)); -} - -SmallVector DiffeGradientUtils::addToDiffe( - Value *val, Value *dif, IRBuilder<> &BuilderM, Type *addingType, - unsigned start, unsigned size, llvm::ArrayRef idxs, - llvm::Value *mask, size_t ignoreFirstSlicesOfDif) { - assert(addingType); - auto &DL = oldFunc->getParent()->getDataLayout(); - Type *VT = val->getType(); - for (auto cv : idxs) { - auto i = dyn_cast(cv)->getSExtValue(); - if (auto ST = dyn_cast(VT)) { - VT = ST->getElementType(i); - continue; - } - if (auto AT = dyn_cast(VT)) { - assert((size_t)i < AT->getNumElements()); - VT = AT->getElementType(); - continue; - } - assert(0 && "illegal indexing type"); - } - auto storeSize = (DL.getTypeSizeInBits(VT) + 7) / 8; - - assert(start < storeSize); - assert(start + size <= storeSize); - - // If VT is a struct type the addToDiffe algorithm will lose type information - // so we do the recurrence here, with full type information. - if (start == 0 && size == storeSize && !isa(VT)) { - if (getWidth() == 1) { - SmallVector eidxs; - for (auto idx : idxs.slice(ignoreFirstSlicesOfDif)) { - eidxs.push_back((unsigned)cast(idx)->getZExtValue()); - } - return addToDiffe(val, extractMeta(BuilderM, dif, eidxs), BuilderM, - addingType, idxs, mask); - } else { - SmallVector res; - for (unsigned j = 0; j < getWidth(); j++) { - SmallVector lidxs; - SmallVector eidxs = {(unsigned)j}; - lidxs.push_back( - ConstantInt::get(Type::getInt32Ty(val->getContext()), j)); - for (auto idx : idxs.slice(ignoreFirstSlicesOfDif)) { - eidxs.push_back((unsigned)cast(idx)->getZExtValue()); - } - for (auto idx : idxs) { - lidxs.push_back(idx); - } - for (auto v : addToDiffe(val, extractMeta(BuilderM, dif, eidxs), - BuilderM, addingType, lidxs, mask)) - res.push_back(v); - } - return res; - } - } - if (auto ST = dyn_cast(VT)) { - auto SL = DL.getStructLayout(ST); - auto left_idx = SL->getElementContainingOffset(start); - auto right_idx = ST->getNumElements(); - if (storeSize != start + size) { - right_idx = SL->getElementContainingOffset(start + size); - // If this doesn't cleanly end the window, make sure we do a partial - // accumulate for the remaining part in right_idx. - if (SL->getElementOffset(right_idx) != start + size) - right_idx++; - } - SmallVector res; - for (auto i = left_idx; i < right_idx; i++) { - auto subType = ST->getElementType(i); - SmallVector lidxs(idxs.begin(), idxs.end()); - lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i)); - auto sub_start = - (i == left_idx) ? 
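// Illustrative sketch of the adjoint-slot discipline implemented above,
// restated in plain C++: every active value gets one zero-initialized slot
// ("<name>'de"), diffe() reads it, addToDiffe() accumulates into it, and
// setDiffe() overwrites it. The map-based AdjointSlots type is a hypothetical
// simplification of the alloca-per-value scheme, not the actual data structure.
#include <string>
#include <unordered_map>

struct AdjointSlots {
  std::unordered_map<std::string, double> slot; // value name -> adjoint slot

  double &getDifferential(const std::string &v) {
    // Created zero-initialized on first use, like the "'de" alloca.
    return slot.try_emplace(v, 0.0).first->second;
  }
  double diffe(const std::string &v) { return getDifferential(v); }
  void addToDiffe(const std::string &v, double dif) {
    getDifferential(v) += dif; // the load / fadd / store sequence
  }
  void setDiffe(const std::string &v, double toset) {
    getDifferential(v) = toset;
  }
};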
(start - (unsigned)SL->getElementOffset(i)) : 0; - auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8; - auto sub_end = (i == right_idx - 1) - ? min(start + size - (unsigned)SL->getElementOffset(i), - (unsigned)subTypeSize) - : subTypeSize; - for (auto v : - addToDiffe(val, dif, BuilderM, addingType, sub_start, - sub_end - sub_start, lidxs, mask, ignoreFirstSlicesOfDif)) - res.push_back(v); - } - return res; - } - - if (auto AT = dyn_cast(VT)) { - auto subType = AT->getElementType(); - auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8; - auto left_idx = start / subTypeSize; - auto right_idx = AT->getNumElements(); - if (storeSize != start + size) { - right_idx = (start + size) / subTypeSize; - // If this doesn't cleanly end the window, make sure we do a partial - // accumulate for the remaining part in right_idx. - if (right_idx * subTypeSize != start + size) - right_idx++; - } - SmallVector res; - for (auto i = left_idx; i < right_idx; i++) { - SmallVector lidxs(idxs.begin(), idxs.end()); - lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i)); - auto sub_start = (i == left_idx) ? (start - (i * subTypeSize)) : 0; - auto sub_end = (i == right_idx - 1) - ? min(start + size - (unsigned)(i * subTypeSize), - (unsigned)subTypeSize) - : subTypeSize; - for (auto v : - addToDiffe(val, dif, BuilderM, addingType, sub_start, - sub_end - sub_start, lidxs, mask, ignoreFirstSlicesOfDif)) - res.push_back(v); - } - return res; - } - - llvm::errs() << " VT: " << *VT << " idxs:{"; - for (auto idx : idxs) - llvm::errs() << *idx << ","; - llvm::errs() << "} start=" << start << " size=" << size - << " storeSize=" << storeSize << " val=" << *val << "\n"; - assert(0 && "unhandled accumulate with partial sizes"); - return {}; -} - -SmallVector -DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, - Type *addingType, ArrayRef idxs, - Value *mask) { - assert(mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined); - -#ifndef NDEBUG - if (auto arg = dyn_cast(val)) - assert(arg->getParent() == oldFunc); - if (auto inst = dyn_cast(val)) - assert(inst->getParent()->getParent() == oldFunc); -#endif - - SmallVector addedSelects; - - auto faddForNeg = [&](Value *old, Value *inc, bool san) { - if (auto bi = dyn_cast(inc)) { - if (auto ci = dyn_cast(bi->getOperand(0))) { - if (bi->getOpcode() == BinaryOperator::FSub && ci->isZero()) { - Value *res = BuilderM.CreateFSub(old, bi->getOperand(1)); - if (san) - res = SanitizeDerivatives(val, res, BuilderM, mask); - return res; - } - } - } - Value *res = BuilderM.CreateFAdd(old, inc); - if (san) - res = SanitizeDerivatives(val, res, BuilderM, mask); - return res; - }; - - auto faddForSelect = [&](Value *old, Value *dif) -> Value * { - //! optimize fadd of select to select of fadd - if (SelectInst *select = dyn_cast(dif)) { - if (Constant *ci = dyn_cast(select->getTrueValue())) { - if (ci->isZeroValue()) { - SelectInst *res = cast(BuilderM.CreateSelect( - select->getCondition(), old, - faddForNeg(old, select->getFalseValue(), false))); - addedSelects.push_back(res); - return SanitizeDerivatives(val, res, BuilderM, mask); - } - } - if (Constant *ci = dyn_cast(select->getFalseValue())) { - if (ci->isZeroValue()) { - SelectInst *res = cast(BuilderM.CreateSelect( - select->getCondition(), - faddForNeg(old, select->getTrueValue(), false), old)); - addedSelects.push_back(res); - return SanitizeDerivatives(val, res, BuilderM, mask); - } - } - } - - //! 
optimize fadd of bitcast select to select of bitcast fadd - if (BitCastInst *bc = dyn_cast(dif)) { - if (SelectInst *select = dyn_cast(bc->getOperand(0))) { - if (Constant *ci = dyn_cast(select->getTrueValue())) { - if (ci->isZeroValue()) { - SelectInst *res = cast(BuilderM.CreateSelect( - select->getCondition(), old, - faddForNeg(old, - BuilderM.CreateCast(bc->getOpcode(), - select->getFalseValue(), - bc->getDestTy()), - false))); - addedSelects.push_back(res); - return SanitizeDerivatives(val, res, BuilderM, mask); - } - } - if (Constant *ci = dyn_cast(select->getFalseValue())) { - if (ci->isZeroValue()) { - SelectInst *res = cast(BuilderM.CreateSelect( - select->getCondition(), - faddForNeg(old, - BuilderM.CreateCast(bc->getOpcode(), - select->getTrueValue(), - bc->getDestTy()), - false), - old)); - addedSelects.push_back(res); - return SanitizeDerivatives(val, res, BuilderM, mask); - } - } - } - } - - // fallback - return faddForNeg(old, dif, true); - }; - - if (val->getType()->isPointerTy()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *val << "\n"; - } - if (isConstantValue(val)) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *val << "\n"; - } - assert(!val->getType()->isPointerTy()); - assert(!isConstantValue(val)); - - Value *ptr = getDifferential(val); - - Value *old; - if (idxs.size() != 0) { - SmallVector sv = { - ConstantInt::get(Type::getInt32Ty(val->getContext()), 0)}; - for (auto i : idxs) - sv.push_back(i); - ptr = BuilderM.CreateGEP(getShadowType(val->getType()), ptr, sv); - cast(ptr)->setIsInBounds(true); - old = BuilderM.CreateLoad( - GetElementPtrInst::getIndexedType(getShadowType(val->getType()), sv), - ptr); - } else { - old = BuilderM.CreateLoad(getShadowType(val->getType()), ptr); - } - if (dif->getType() != old->getType()) { - if (auto inst = dyn_cast(val)) { - EmitFailure("IllegalAddingType", inst->getDebugLoc(), inst, "val ", *val, - " dif ", *dif, " old ", *old); - return addedSelects; - } - llvm::errs() << " IllegalAddingType val: " << *val << " dif: " << *dif - << " old: " << *old << "\n"; - llvm_unreachable("IllegalAddingType"); - } - - assert(dif->getType() == old->getType()); - Value *res = nullptr; - if (old->getType()->isIntOrIntVectorTy() || old->getType()->isPointerTy()) { - if (!addingType) { - if (looseTypeAnalysis) { - if (old->getType()->isIntegerTy(64)) - addingType = Type::getDoubleTy(old->getContext()); - else if (old->getType()->isIntegerTy(32)) - addingType = Type::getFloatTy(old->getContext()); - } - } - if (!addingType) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "oldFunc: " << *oldFunc << "\n"; - ss << "Cannot deduce adding type of: " << *val << "\n"; - ss << " + idxs {"; - for (auto idx : idxs) - ss << *idx << ","; - ss << "}\n"; - if (auto inst = dyn_cast(val)) { - EmitNoTypeError(ss.str(), *inst, this, BuilderM); - return addedSelects; - } else if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType, - TR.analyzer, nullptr, wrap(&BuilderM)); - return addedSelects; - } else { - TR.dump(ss); - llvm::errs() << ss.str() << "\n"; - llvm_unreachable("Cannot deduce adding type"); - return addedSelects; - } - } - assert(addingType); - assert(addingType->isFPOrFPVectorTy()); - - auto oldBitSize = - oldFunc->getParent()->getDataLayout().getTypeSizeInBits(old->getType()); - auto newBitSize = - oldFunc->getParent()->getDataLayout().getTypeSizeInBits(addingType); - - if (oldBitSize == newBitSize) { - } else if (oldBitSize > newBitSize && oldBitSize % newBitSize == 0) { - if 
(!addingType->isVectorTy()) - addingType = - VectorType::get(addingType, oldBitSize / newBitSize, false); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "oldFunc: " << *oldFunc << "\n"; - ss << "Illegal intermediate when adding to: " << *val - << " with addingType: " << *addingType << "\n" - << " old: " << *old << " dif: " << *dif << "\n" - << " oldBitSize: " << oldBitSize << " newBitSize: " << newBitSize - << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType, - TR.analyzer, nullptr, wrap(&BuilderM)); - return addedSelects; - } else { - DebugLoc loc; - if (auto inst = dyn_cast(val)) - EmitFailure("CannotDeduceType", inst->getDebugLoc(), inst, ss.str()); - else { - llvm::errs() << ss.str() << "\n"; - llvm_unreachable("Cannot deduce adding type"); - } - return addedSelects; - } - } - - Value *bcold = old; - Value *bcdif = dif; - Type *intTy = nullptr; - if (old->getType()->isPointerTy()) { - auto &DL = oldFunc->getParent()->getDataLayout(); - intTy = Type::getIntNTy(old->getContext(), DL.getPointerSizeInBits()); - bcold = BuilderM.CreatePtrToInt(bcold, intTy); - bcdif = BuilderM.CreatePtrToInt(bcdif, intTy); - } else { - intTy = old->getType(); - } - - bcold = BuilderM.CreateBitCast(bcold, addingType); - bcdif = BuilderM.CreateBitCast(bcdif, addingType); - - res = faddForSelect(bcold, bcdif); - if (SelectInst *select = dyn_cast(res)) { - assert(addedSelects.back() == select); - addedSelects.erase(addedSelects.end() - 1); - - Value *tval = BuilderM.CreateBitCast(select->getTrueValue(), intTy); - Value *fval = BuilderM.CreateBitCast(select->getFalseValue(), intTy); - if (old->getType()->isPointerTy()) { - tval = BuilderM.CreateIntToPtr(tval, old->getType()); - fval = BuilderM.CreateIntToPtr(fval, old->getType()); - } - res = BuilderM.CreateSelect(select->getCondition(), tval, fval); - assert(select->getNumUses() == 0); - } else { - res = BuilderM.CreateBitCast(res, intTy); - if (old->getType()->isPointerTy()) - res = BuilderM.CreateIntToPtr(res, old->getType()); - } - if (!mask) { - BuilderM.CreateStore(res, ptr); - // store->setAlignment(align); - } else { - Type *tys[] = {res->getType(), ptr->getType()}; - auto F = getIntrinsicDeclaration(oldFunc->getParent(), - Intrinsic::masked_store, tys); - auto align = cast(ptr)->getAlign().value(); - assert(align); - Value *alignv = - ConstantInt::get(Type::getInt32Ty(mask->getContext()), align); - Value *args[] = {res, ptr, alignv, mask}; - BuilderM.CreateCall(F, args); - } - return addedSelects; - } else if (old->getType()->isFPOrFPVectorTy()) { - // TODO consider adding type - res = faddForSelect(old, dif); - - if (!mask) { - BuilderM.CreateStore(res, ptr); - // store->setAlignment(align); - } else { - Type *tys[] = {res->getType(), ptr->getType()}; - auto F = getIntrinsicDeclaration(oldFunc->getParent(), - Intrinsic::masked_store, tys); - auto align = cast(ptr)->getAlign().value(); - assert(align); - Value *alignv = - ConstantInt::get(Type::getInt32Ty(mask->getContext()), align); - Value *args[] = {res, ptr, alignv, mask}; - BuilderM.CreateCall(F, args); - } - return addedSelects; - } else if (auto st = dyn_cast(old->getType())) { - assert(!mask); - if (mask) - llvm_unreachable("cannot handle recursive addToDiffe with mask"); - for (unsigned i = 0; i < st->getNumElements(); ++i) { - // TODO pass in full type tree here and recurse into tree. 
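// Illustrative sketch of the integer path above: when the slot has an integer
// (or pointer) type, the bits are reinterpreted as the deduced floating-point
// addingType, the increment is added, and the result is stored back in the
// slot's original type. In C++ terms, assuming addingType is double and the
// slot is 64 bits wide (the function name is hypothetical):
#include <cstdint>
#include <cstring>

inline uint64_t addToDiffeThroughBits(uint64_t oldBits, uint64_t difBits) {
  double oldVal, difVal;
  std::memcpy(&oldVal, &oldBits, sizeof(double)); // "bitcast" i64 -> double
  std::memcpy(&difVal, &difBits, sizeof(double));
  double sum = oldVal + difVal;                   // the fadd
  uint64_t out;
  std::memcpy(&out, &sum, sizeof(double));        // back to the slot's type
  return out;
}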
- if (st->getElementType(i)->isPointerTy()) - continue; - if (st->getElementType(i)->isIntegerTy(8) || - st->getElementType(i)->isIntegerTy(1)) - continue; - Value *v = ConstantInt::get(Type::getInt32Ty(st->getContext()), i); - SmallVector idx2(idxs.begin(), idxs.end()); - idx2.push_back(v); - // FIXME: reconsider if passing a nullptr is correct here. - auto selects = addToDiffe(val, extractMeta(BuilderM, dif, i), BuilderM, - nullptr, idx2); - for (auto select : selects) { - addedSelects.push_back(select); - } - } - return addedSelects; - } else if (auto at = dyn_cast(old->getType())) { - assert(!mask); - if (mask) - llvm_unreachable("cannot handle recursive addToDiffe with mask"); - if (at->getElementType()->isPointerTy()) - return addedSelects; - for (unsigned i = 0; i < at->getNumElements(); ++i) { - // TODO pass in full type tree here and recurse into tree. - Value *v = ConstantInt::get(Type::getInt32Ty(at->getContext()), i); - SmallVector idx2(idxs.begin(), idxs.end()); - idx2.push_back(v); - auto selects = addToDiffe(val, extractMeta(BuilderM, dif, i), BuilderM, - addingType, idx2); - for (auto select : selects) { - addedSelects.push_back(select); - } - } - return addedSelects; - } else { - llvm::errs() << " idx: {"; - for (auto i : idxs) - llvm::errs() << *i << ", "; - llvm::errs() << "}\n"; - if (addingType) - llvm::errs() << " addingType: " << *addingType << "\n"; - else - llvm::errs() << " addingType: null\n"; - llvm::errs() << " oldType:" << *old->getType() << " old:" << *old << "\n"; - llvm_unreachable("unknown type to add to diffe"); - exit(1); - } -} - -void DiffeGradientUtils::setDiffe(Value *val, Value *toset, - IRBuilder<> &BuilderM) { -#ifndef NDEBUG - if (auto arg = dyn_cast(val)) - assert(arg->getParent() == oldFunc); - if (auto inst = dyn_cast(val)) - assert(inst->getParent()->getParent() == oldFunc); - if (isConstantValue(val)) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *val << "\n"; - } - assert(!isConstantValue(val)); -#endif - toset = SanitizeDerivatives(val, toset, BuilderM); - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError) { - assert(getShadowType(val->getType()) == toset->getType()); - auto found = invertedPointers.find(val); - assert(found != invertedPointers.end()); - auto placeholder0 = &*found->second; - auto placeholder = cast(placeholder0); - invertedPointers.erase(found); - replaceAWithB(placeholder, toset); - placeholder->replaceAllUsesWith(toset); - erase(placeholder); - invertedPointers.insert( - std::make_pair((const Value *)val, InvertedPointerVH(this, toset))); - return; - } - Value *tostore = getDifferential(val); -#if LLVM_VERSION_MAJOR < 17 - if (toset->getContext().supportsTypedPointers()) { - if (toset->getType() != tostore->getType()->getPointerElementType()) { - llvm::errs() << "toset:" << *toset << "\n"; - llvm::errs() << "tostore:" << *tostore << "\n"; - } - assert(toset->getType() == tostore->getType()->getPointerElementType()); - } -#endif - BuilderM.CreateStore(toset, tostore); -} - -CallInst *DiffeGradientUtils::freeCache(BasicBlock *forwardPreheader, - const SubLimitType &sublimits, int i, - AllocaInst *alloc, - ConstantInt *byteSizeOfType, - Value *storeInto, MDNode *InvariantMD) { - if (!FreeMemory) - return nullptr; - assert(reverseBlocks.find(forwardPreheader) != reverseBlocks.end()); - assert(reverseBlocks[forwardPreheader].size()); - IRBuilder<> tbuild(reverseBlocks[forwardPreheader].back()); - tbuild.setFastMathFlags(getFast()); - 
- // ensure we are before the terminator if it exists - if (tbuild.GetInsertBlock()->size() && - tbuild.GetInsertBlock()->getTerminator()) { - tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getTerminator()); - } - - ValueToValueMapTy antimap; - for (int j = sublimits.size() - 1; j >= i; j--) { - auto &innercontainedloops = sublimits[j].second; - for (auto riter = innercontainedloops.rbegin(), - rend = innercontainedloops.rend(); - riter != rend; ++riter) { - const auto &idx = riter->first; - if (idx.var) { - antimap[idx.var] = - tbuild.CreateLoad(idx.var->getType(), idx.antivaralloc); - } - } - } - - Value *metaforfree = unwrapM(storeInto, tbuild, antimap, - UnwrapMode::AttemptFullUnwrapWithLookup); - Type *T; -#if LLVM_VERSION_MAJOR < 17 - if (metaforfree->getContext().supportsTypedPointers()) { - T = metaforfree->getType()->getPointerElementType(); - } else { - T = PointerType::getUnqual(metaforfree->getContext()); - } -#else - T = PointerType::getUnqual(metaforfree->getContext()); -#endif - LoadInst *forfree = cast(tbuild.CreateLoad(T, metaforfree)); - forfree->setMetadata(LLVMContext::MD_invariant_group, InvariantMD); - forfree->setMetadata(LLVMContext::MD_dereferenceable, - MDNode::get(forfree->getContext(), - ArrayRef(ConstantAsMetadata::get( - byteSizeOfType)))); - forfree->setName("forfree"); - unsigned align = getCacheAlignment( - (unsigned)newFunc->getParent()->getDataLayout().getPointerSize()); - forfree->setAlignment(Align(align)); - - CallInst *ci = CreateDealloc(tbuild, forfree); - if (ci) { - if (newFunc->getSubprogram()) - ci->setDebugLoc(DILocation::get(newFunc->getContext(), 0, 0, - newFunc->getSubprogram(), 0)); - scopeFrees[alloc].insert(ci); - } - return ci; -} - -void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, - Value *origVal, Type *addingType, - unsigned start, unsigned size, - Value *origptr, Value *dif, - IRBuilder<> &BuilderM, - MaybeAlign align, Value *mask) { - auto &DL = oldFunc->getParent()->getDataLayout(); - - auto addingSize = (DL.getTypeSizeInBits(addingType) + 1) / 8; - if (addingSize != size) { - assert(size > addingSize); - addingType = - VectorType::get(addingType, size / addingSize, /*isScalable*/ false); - size = (size / addingSize) * addingSize; - } - - Value *ptr; - - switch (mode) { - case DerivativeMode::ForwardModeSplit: - case DerivativeMode::ForwardMode: - case DerivativeMode::ForwardModeError: - ptr = invertPointerM(origptr, BuilderM); - break; - case DerivativeMode::ReverseModePrimal: - assert(false && "Invalid derivative mode (ReverseModePrimal)"); - break; - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModeCombined: - ptr = lookupM(invertPointerM(origptr, BuilderM), BuilderM); - break; - } - - bool needsCast = false; -#if LLVM_VERSION_MAJOR < 17 - if (origptr->getContext().supportsTypedPointers()) { - needsCast = origptr->getType()->getPointerElementType() != addingType; - } -#endif - - assert(ptr); - if (start != 0 || needsCast) { - auto rule = [&](Value *ptr) { - if (start != 0) { - auto i8 = Type::getInt8Ty(ptr->getContext()); - ptr = BuilderM.CreatePointerCast( - ptr, PointerType::get( - i8, cast(ptr->getType())->getAddressSpace())); - auto off = ConstantInt::get(Type::getInt64Ty(ptr->getContext()), start); - ptr = BuilderM.CreateInBoundsGEP(i8, ptr, off); - } - if (needsCast) { - ptr = BuilderM.CreatePointerCast( - ptr, PointerType::get( - addingType, - cast(ptr->getType())->getAddressSpace())); - } - return ptr; - }; - ptr = applyChainRule( - PointerType::get( - addingType, - 
cast(origptr->getType())->getAddressSpace()), - BuilderM, rule, ptr); - } - - if (getWidth() == 1) - needsCast = dif->getType() != addingType; - else if (auto AT = cast(dif->getType())) - needsCast = AT->getElementType() != addingType; - else - needsCast = - cast(dif->getType())->getElementType() != addingType; - - if (start != 0 || needsCast) { - auto rule = [&](Value *dif) { - if (start != 0) { - IRBuilder<> A(inversionAllocs); - auto i8 = Type::getInt8Ty(ptr->getContext()); - auto prevSize = (DL.getTypeSizeInBits(dif->getType()) + 1) / 8; - Type *tys[] = {ArrayType::get(i8, start), addingType, - ArrayType::get(i8, prevSize - start - size)}; - auto ST = StructType::get(i8->getContext(), tys, /*isPacked*/ true); - auto Al = A.CreateAlloca(ST); - BuilderM.CreateStore(dif, - BuilderM.CreatePointerCast( - Al, PointerType::getUnqual(dif->getType()))); - Value *idxs[] = { - ConstantInt::get(Type::getInt64Ty(ptr->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(ptr->getContext()), 1)}; - - auto difp = BuilderM.CreateInBoundsGEP(ST, Al, idxs); - dif = BuilderM.CreateLoad(addingType, difp); - } - if (dif->getType() != addingType) { - auto difSize = (DL.getTypeSizeInBits(dif->getType()) + 1) / 8; - if (difSize < size) { - llvm::errs() << " ds: " << difSize << " as: " << size << "\n"; - llvm::errs() << " dif: " << *dif << " adding: " << *addingType - << "\n"; - } - assert(difSize >= size); - if (CastInst::castIsValid(Instruction::CastOps::BitCast, dif, - addingType)) - dif = BuilderM.CreateBitCast(dif, addingType); - else { - IRBuilder<> A(inversionAllocs); - auto Al = A.CreateAlloca(addingType); - BuilderM.CreateStore(dif, - BuilderM.CreatePointerCast( - Al, PointerType::getUnqual(dif->getType()))); - dif = BuilderM.CreateLoad(addingType, Al); - } - } - return dif; - }; - dif = applyChainRule(addingType, BuilderM, rule, dif); - } - - auto TmpOrig = getBaseObject(origptr); - - // atomics - bool Atomic = AtomicAdd; - auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); - - // No need to do atomic on local memory for CUDA since it can't be raced - // upon - if (isa(TmpOrig) && - (Arch == Triple::nvptx || Arch == Triple::nvptx64 || - Arch == Triple::amdgcn)) { - Atomic = false; - } - // Moreover no need to do atomic on local shadows regardless since they are - // not captured/escaping and created in this function. This assumes that - // all additional parallelism in this function is outlined. 
- if (backwardsOnlyShadows.find(TmpOrig) != backwardsOnlyShadows.end()) - Atomic = false; - - if (Atomic) { - // For amdgcn constant AS is 4 and if the primal is in it we need to cast - // the derivative value to AS 1 - if (Arch == Triple::amdgcn && - cast(origptr->getType())->getAddressSpace() == 4) { - auto rule = [&](Value *ptr) { - return BuilderM.CreateAddrSpaceCast(ptr, - PointerType::get(addingType, 1)); - }; - ptr = - applyChainRule(PointerType::get(addingType, 1), BuilderM, rule, ptr); - } - - if (mask) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Unimplemented masked atomic fadd for ptr:" << *ptr - << " dif:" << *dif << " mask: " << *mask << " orig: " << *orig << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(orig), - ErrorType::NoDerivative, this, nullptr, - wrap(&BuilderM)); - return; - } else { - EmitFailure("NoDerivative", orig->getDebugLoc(), orig, ss.str()); - return; - } - } - - /* - while (auto ASC = dyn_cast(ptr)) { - ptr = ASC->getOperand(0); - } - while (auto ASC = dyn_cast(ptr)) { - if (!ASC->isCast()) break; - if (ASC->getOpcode() != Instruction::AddrSpaceCast) break; - ptr = ASC->getOperand(0); - } - */ - AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd; - if (auto vt = dyn_cast(addingType)) { - assert(!vt->getElementCount().isScalable()); - size_t numElems = vt->getElementCount().getKnownMinValue(); - auto rule = [&](Value *dif, Value *ptr) { - for (size_t i = 0; i < numElems; ++i) { - auto vdif = BuilderM.CreateExtractElement(dif, i); - vdif = SanitizeDerivatives(orig, vdif, BuilderM); - Value *Idxs[] = { - ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)}; - auto vptr = BuilderM.CreateGEP(addingType, ptr, Idxs); - MaybeAlign alignv = align; - if (alignv) { - if (start != 0) { - // todo make better alignment calculation - assert((*alignv).value() != 0); - if (start % (*alignv).value() != 0) { - alignv = Align(1); - } - } - } - BuilderM.CreateAtomicRMW(op, vptr, vdif, alignv, - AtomicOrdering::Monotonic, - SyncScope::System); - } - }; - applyChainRule(BuilderM, rule, dif, ptr); - } else { - auto rule = [&](Value *dif, Value *ptr) { - dif = SanitizeDerivatives(orig, dif, BuilderM); - MaybeAlign alignv = align; - if (alignv) { - if (start != 0) { - // todo make better alignment calculation - assert((*alignv).value() != 0); - if (start % (*alignv).value() != 0) { - alignv = Align(1); - } - } - } - BuilderM.CreateAtomicRMW(op, ptr, dif, alignv, - AtomicOrdering::Monotonic, SyncScope::System); - }; - applyChainRule(BuilderM, rule, dif, ptr); - } - return; - } - - if (!mask) { - - size_t idx = 0; - auto rule = [&](Value *ptr, Value *dif) { - auto LI = BuilderM.CreateLoad(addingType, ptr); - - Value *res = BuilderM.CreateFAdd(LI, dif); - res = SanitizeDerivatives(orig, res, BuilderM); - StoreInst *st = BuilderM.CreateStore(res, ptr); - - SmallVector scopeMD = { - getDerivativeAliasScope(origptr, idx)}; - if (auto origValI = dyn_cast_or_null(origVal)) - if (auto MD = origValI->getMetadata(LLVMContext::MD_alias_scope)) { - auto MDN = cast(MD); - for (auto &o : MDN->operands()) - scopeMD.push_back(o); - } - auto scope = MDNode::get(LI->getContext(), scopeMD); - LI->setMetadata(LLVMContext::MD_alias_scope, scope); - st->setMetadata(LLVMContext::MD_alias_scope, scope); - - SmallVector MDs; - for (ssize_t j = -1; j < getWidth(); j++) { - if (j != (ssize_t)idx) - MDs.push_back(getDerivativeAliasScope(origptr, j)); - } - if (auto origValI = dyn_cast_or_null(origVal)) - 
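// Illustrative sketch of the atomic path above: when AtomicAdd is required
// (shadow memory that other threads, e.g. GPU threads, may also update), the
// accumulation is emitted as an atomicrmw fadd with monotonic ordering rather
// than a plain load/fadd/store. A portable C++ equivalent, under the assumption
// of a double-typed shadow slot, is a relaxed compare-exchange loop
// (atomicAddToShadow is a hypothetical name):
#include <atomic>

inline void atomicAddToShadow(std::atomic<double> &shadowSlot, double dif) {
  double expected = shadowSlot.load(std::memory_order_relaxed);
  // Retry until no other thread races with our read-modify-write; on failure
  // compare_exchange_weak refreshes `expected` with the current value.
  while (!shadowSlot.compare_exchange_weak(expected, expected + dif,
                                           std::memory_order_relaxed))
    ;
}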
if (auto MD = origValI->getMetadata(LLVMContext::MD_noalias)) { - auto MDN = cast(MD); - for (auto &o : MDN->operands()) - MDs.push_back(o); - } - idx++; - auto noscope = MDNode::get(ptr->getContext(), MDs); - LI->setMetadata(LLVMContext::MD_noalias, noscope); - st->setMetadata(LLVMContext::MD_noalias, noscope); - - if (origVal && isa(origVal) && start == 0 && - size == (DL.getTypeSizeInBits(origVal->getType()) + 7) / 8) { - auto origValI = cast(origVal); - LI->copyMetadata(*origValI, MD_ToCopy); - unsigned int StoreData[] = {LLVMContext::MD_tbaa, - LLVMContext::MD_tbaa_struct}; - for (auto MD : StoreData) - st->setMetadata(MD, origValI->getMetadata(MD)); - } - - LI->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); - st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); - - if (align) { - auto alignv = align ? (*align).value() : 0; - if (alignv != 0) { - if (start != 0) { - // todo make better alignment calculation - if (start % alignv != 0) { - alignv = 1; - } - } - - LI->setAlignment(Align(alignv)); - st->setAlignment(Align(alignv)); - } - } - }; - applyChainRule(BuilderM, rule, ptr, dif); - } else { - Type *tys[] = {addingType, origptr->getType()}; - auto LF = getIntrinsicDeclaration(oldFunc->getParent(), - Intrinsic::masked_load, tys); - auto SF = getIntrinsicDeclaration(oldFunc->getParent(), - Intrinsic::masked_store, tys); - unsigned aligni = align ? align->value() : 0; - - if (aligni != 0) - if (start != 0) { - // todo make better alignment calculation - if (start % aligni != 0) { - aligni = 1; - } - } - Value *alignv = - ConstantInt::get(Type::getInt32Ty(mask->getContext()), aligni); - auto rule = [&](Value *ptr, Value *dif) { - Value *largs[] = {ptr, alignv, mask, - Constant::getNullValue(dif->getType())}; - Value *LI = BuilderM.CreateCall(LF, largs); - Value *res = BuilderM.CreateFAdd(LI, dif); - res = SanitizeDerivatives(orig, res, BuilderM, mask); - Value *sargs[] = {res, ptr, alignv, mask}; - BuilderM.CreateCall(SF, sargs); - }; - applyChainRule(BuilderM, rule, ptr, dif); - } -} - -void DiffeGradientUtils::addToInvertedPtrDiffe( - llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd, - unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff, - llvm::IRBuilder<> &Builder2, MaybeAlign alignment, llvm::Value *premask) - -{ - - unsigned start = 0; - unsigned size = LoadSize; - - assert(prediff); - - BasicBlock *merge = nullptr; - - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - if (!dt.isKnown()) { - TR.dump(); - llvm::errs() << " vd:" << vd.str() << " start:" << start - << " size: " << size << " dt:" << dt.str() << "\n"; - } - assert(dt.isKnown()); - - if (Type *isfloat = dt.isFloat()) { - - if (origVal) { - if (start == 0 && nextStart == LoadSize) { - setDiffe(origVal, - Constant::getNullValue(getShadowType(origVal->getType())), - Builder2); - } else { - Value *tostore = getDifferential(origVal); - - auto i8 = Type::getInt8Ty(tostore->getContext()); - if (start != 0) { - tostore = Builder2.CreatePointerCast( - tostore, - PointerType::get( - i8, - cast(tostore->getType())->getAddressSpace())); - auto off = ConstantInt::get(Type::getInt64Ty(tostore->getContext()), - start); - tostore = Builder2.CreateInBoundsGEP(i8, tostore, off); - } - auto AT = ArrayType::get(i8, nextStart - start); - tostore = Builder2.CreatePointerCast( - tostore, - PointerType::get( - AT, - 
cast(tostore->getType())->getAddressSpace())); - Builder2.CreateStore(Constant::getNullValue(AT), tostore); - } - } - - if (!isConstantValue(origptr)) { - auto basePtr = getBaseObject(origptr); - assert(!isConstantValue(basePtr)); - // If runtime activity, first see if we can prove that the shadow/primal - // are distinct statically as they are allocas/mallocs, if not compare - // the pointers and conditionally execute. - if ((!isa(basePtr) && !isAllocationCall(basePtr, TLI)) && - runtimeActivity && !merge) { - Value *primal_val = lookupM(getNewFromOriginal(origptr), Builder2); - Value *shadow_val = - lookupM(invertPointerM(origptr, Builder2), Builder2); - if (getWidth() != 1) { - shadow_val = extractMeta(Builder2, shadow_val, 0); - } - Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val); - - BasicBlock *current = Builder2.GetInsertBlock(); - BasicBlock *conditional = - addReverseBlock(current, current->getName() + "_active"); - merge = addReverseBlock(conditional, current->getName() + "_amerge"); - Builder2.CreateCondBr(shadow, conditional, merge); - Builder2.SetInsertPoint(conditional); - } - // Masked partial type is unhanled. - if (premask) - assert(start == 0 && nextStart == LoadSize); - addToInvertedPtrDiffe(orig, origVal, isfloat, start, nextStart - start, - origptr, prediff, Builder2, alignment, premask); - } - } - - if (nextStart == size) - break; - start = nextStart; - } - if (merge) { - Builder2.CreateBr(merge); - Builder2.SetInsertPoint(merge); - } -} diff --git a/enzyme/Enzyme/DiffeGradientUtils.h b/enzyme/Enzyme/DiffeGradientUtils.h deleted file mode 100644 index d4fac79827af..000000000000 --- a/enzyme/Enzyme/DiffeGradientUtils.h +++ /dev/null @@ -1,129 +0,0 @@ -//===- DiffeGradientUtils.h - Helper class and utilities for AD ---------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares two helper classes GradientUtils and subclass -// DiffeGradientUtils. These classes contain utilities for managing the cache, -// recomputing statements, and in the case of DiffeGradientUtils, managing -// adjoint values and shadow pointers. 
-// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_DIFFEGRADIENTUTILS_H_ -#define ENZYME_DIFFEGRADIENTUTILS_H_ - -#include "GradientUtils.h" - -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Metadata.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/Analysis/ValueTracking.h" - -#include "ActivityAnalysis.h" -#include "EnzymeLogic.h" -#include "Utils.h" - -#include "llvm-c/Core.h" - -#if LLVM_VERSION_MAJOR <= 16 -#include "llvm/ADT/Triple.h" -#endif - -class DiffeGradientUtils final : public GradientUtils { - DiffeGradientUtils( - EnzymeLogic &Logic, llvm::Function *newFunc_, llvm::Function *oldFunc_, - llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, TypeResults TR, - llvm::ValueToValueMapTy &invertedPointers_, - const llvm::SmallPtrSetImpl &constantvalues_, - const llvm::SmallPtrSetImpl &returnvals_, - DIFFE_TYPE ActiveReturn, bool shadowReturnUsed, - llvm::ArrayRef constant_values, - llvm::ValueMap &origToNew_, - DerivativeMode mode, bool runtimeActivity, unsigned width, bool omp); - -public: - /// Whether to free memory in reverse pass or split forward. - bool FreeMemory; - llvm::ValueMap> - differentials; - static DiffeGradientUtils * - CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity, - unsigned width, llvm::Function *todiff, - llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, - FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, - bool shadowReturnArg, bool diffeReturnArg, - llvm::ArrayRef constant_args, - ReturnType returnValue, llvm::Type *additionalArg, bool omp); - - llvm::AllocaInst *getDifferential(llvm::Value *val); - -public: - llvm::Value *diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM); - - /// Returns created select instructions, if any - llvm::SmallVector - addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, - llvm::Type *addingType, llvm::ArrayRef idxs = {}, - llvm::Value *mask = nullptr); - - /// Returns created select instructions, if any - llvm::SmallVector - addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, - llvm::Type *addingType, unsigned start, unsigned size, - llvm::ArrayRef idxs = {}, - llvm::Value *mask = nullptr, size_t ignoreFirstSlicesToDiff = 0); - - void setDiffe(llvm::Value *val, llvm::Value *toset, - llvm::IRBuilder<> &BuilderM); - - llvm::CallInst * - freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &sublimits, - int i, llvm::AllocaInst *alloc, llvm::ConstantInt *byteSizeOfType, - llvm::Value *storeInto, llvm::MDNode *InvariantMD) override; - - /// align is the alignment that should be specified for load/store to pointer - void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, - llvm::Type *addingType, unsigned start, - unsigned size, llvm::Value *origptr, - llvm::Value *dif, llvm::IRBuilder<> &BuilderM, - llvm::MaybeAlign align = llvm::MaybeAlign(), - llvm::Value *mask = nullptr); - - void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, - TypeTree vd, unsigned size, llvm::Value *origptr, - llvm::Value *prediff, llvm::IRBuilder<> &Builder2, - llvm::MaybeAlign align = llvm::MaybeAlign(), - llvm::Value *premask = 
nullptr); -}; - -#endif diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp deleted file mode 100644 index 3d3722cfe634..000000000000 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ /dev/null @@ -1,1160 +0,0 @@ -//===- DifferentialUseAnalysis.cpp - Determine values needed in reverse -// pass-===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the declaration of Differential USe Analysis -- an -// AD-specific analysis that deduces if a given value is needed in the reverse -// pass. -// -//===----------------------------------------------------------------------===// - -#include -#include -#include - -#include "DifferentialUseAnalysis.h" -#include "Utils.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/IntrinsicsX86.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" - -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" - -#include "DiffeGradientUtils.h" -#include "GradientUtils.h" -#include "LibraryFuncs.h" - -using namespace llvm; - -StringMap> - customDiffUseHandlers; - -bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( - const GradientUtils *gutils, const Value *val, DerivativeMode mode, - const Instruction *user, - const SmallPtrSetImpl &oldUnreachable, QueryType qtype, - bool *recursiveUse) { - TypeResults const &TR = gutils->TR; -#ifndef NDEBUG - if (auto ainst = dyn_cast(val)) { - assert(ainst->getParent()->getParent() == gutils->oldFunc); - } -#endif - - bool shadow = - qtype == QueryType::Shadow || qtype == QueryType::ShadowByConstPrimal; - - /// Recursive use is only usable in shadow mode. - if (!shadow) - assert(recursiveUse == nullptr); - else - assert(recursiveUse != nullptr); - - if (!shadow && isPointerArithmeticInst(user, /*includephi*/ true, - /*includebin*/ false)) { - return false; - } - - // Floating point numbers cannot be used as a shadow pointer/etc - if (qtype == QueryType::ShadowByConstPrimal) - if (TR.query(const_cast(val))[{-1}].isFloat()) - return false; - - if (!user) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: of " << *val << " in reverse as nullptr user\n"; - return true; - } - - assert(user->getParent()->getParent() == gutils->oldFunc); - - if (oldUnreachable.count(user->getParent())) - return false; - - if (auto SI = dyn_cast(user)) { - if (!shadow) { - - // We don't need any of the input operands to compute the adjoint of a - // store instance The one exception to this is stores to the loop bounds. 
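A minimal scalar sketch of why the store case below can drop its primal operands: reversing a store only reads and clears the shadow of the pointer, so neither the stored value nor the primal pointer is revisited. The function names here are illustrative, not Enzyme-generated code.

#include <cstdio>

// Forward: *p = x.
void forward(double *p, double x) { *p = x; }

// Hand-written adjoint of the store: the incoming adjoint of *p flows into
// d_x, and the shadow slot is cleared because the store overwrote *p.
void adjoint(double *d_p, double &d_x) {
  d_x += *d_p;
  *d_p = 0.0;
}

int main() {
  double p = 0.0, d_p = 3.0, d_x = 0.0;
  forward(&p, 2.0);
  adjoint(&d_p, d_x);
  std::printf("d_x = %f\n", d_x); // prints 3.000000
}

Only the shadow pointer d_p appears in the adjoint, which is the distinction the shadow/primal branches below encode.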
- if (SI->getValueOperand() == val) { - for (auto U : SI->getPointerOperand()->users()) { - if (auto CI = dyn_cast(U)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "__kmpc_for_static_init_4" || - F->getName() == "__kmpc_for_static_init_4u" || - F->getName() == "__kmpc_for_static_init_8" || - F->getName() == "__kmpc_for_static_init_8u") { - if (CI->getArgOperand(4) == val || - CI->getArgOperand(5) == val || CI->getArgOperand(6)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from omp " << *user << "\n"; - return true; - } - } - } - } - } - } - } else { - bool backwardsShadow = false; - bool forwardsShadow = true; - for (auto pair : gutils->backwardsOnlyShadows) { - if (pair.second.stores.count(SI) && - !gutils->isConstantValue(pair.first)) { - backwardsShadow = true; - forwardsShadow = pair.second.primalInitialize; - } - } - - // Preserve any non-floating point values that are stored in an active - // backwards creation shadow. - - if (SI->getValueOperand() == val) { - // storing an active pointer into a location - // doesn't require the shadow pointer for the - // reverse pass - // Unless the store is into a backwards store, which would - // would then be performed in the reverse if the stored value was - // a possible pointer. - - if (!((mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (mode == DerivativeMode::ReverseModeGradient && - backwardsShadow) || - (mode == DerivativeMode::ForwardModeSplit && backwardsShadow) || - (mode == DerivativeMode::ReverseModeCombined && - (forwardsShadow || backwardsShadow)) || - mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError)) - return false; - } else { - // Likewise, if not rematerializing in reverse pass, you - // don't need to keep the pointer operand for known pointers - - auto ct = TR.query(const_cast(SI->getValueOperand()))[{-1}]; - if (ct == BaseType::Pointer || ct == BaseType::Integer) { - - if (!((mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || - (mode == DerivativeMode::ReverseModeGradient && - backwardsShadow) || - (mode == DerivativeMode::ForwardModeSplit && backwardsShadow) || - (mode == DerivativeMode::ReverseModeCombined && - (forwardsShadow || backwardsShadow)) || - mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError)) - return false; - } - } - - if (!gutils->isConstantValue( - const_cast(SI->getPointerOperand()))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow store " << *SI << "\n"; - return true; - } else - return false; - } - return false; - } - - if (!shadow) - if (auto LI = dyn_cast(user)) { - if (gutils->runtimeActivity) { - auto vd = TR.query(const_cast(user)); - if (!vd.isKnown()) { - auto ET = LI->getType(); - // It verbatim needs to replicate the same behavior as - // adjointgenerator. 
From reverse mode type analysis - // (https://github.com/EnzymeAD/Enzyme/blob/194875cbccd73d63cacfefbfa85c1f583c2fa1fe/enzyme/Enzyme/AdjointGenerator.h#L556) - if (looseTypeAnalysis || true) { - vd = defaultTypeTreeForLLVM(ET, const_cast(LI)); - } - } - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto LoadSize = (DL.getTypeSizeInBits(LI->getType()) + 1) / 8; - bool hasFloat = true; - for (ssize_t i = -1; i < (ssize_t)LoadSize; ++i) { - if (vd[{(int)i}].isFloat()) { - hasFloat = true; - break; - } - } - if (hasFloat && !gutils->isConstantInstruction( - const_cast(user))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from runtime active load " << *user - << "\n"; - return true; - } - } - return false; - } - - if (auto MTI = dyn_cast(user)) { - // If memtransfer, only the primal of the size is needed reverse pass - if (!shadow) { - // Unless we're storing into a backwards only shadow store - if (MTI->getArgOperand(1) == val || MTI->getArgOperand(2) == val) { - for (auto pair : gutils->backwardsOnlyShadows) - if (pair.second.stores.count(MTI)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from remat memtransfer " << *user - << "\n"; - return true; - } - } - if (MTI->getArgOperand(2) != val) - return false; - bool res = !gutils->isConstantValue(MTI->getArgOperand(0)); - if (res) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from memtransfer " << *user << "\n"; - } - return res; - } else { - - if (MTI->getArgOperand(0) != val && MTI->getArgOperand(1) != val) - return false; - - if (!gutils->isConstantValue( - const_cast(MTI->getArgOperand(0)))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow MTI " << *MTI << "\n"; - return true; - } else - return false; - } - } - - if (auto MS = dyn_cast(user)) { - if (!shadow) { - // Preserve the primal of length of memsets of backward creation shadows, - // or if float-like and non constant value. 
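The memcpy and memset cases keep only the primal of the length because the usual reverse rule for copying floating-point data works entirely on the shadow buffers. A hedged scalar sketch with illustrative names; the loop plays the role of the memcpy:

#include <cstddef>
#include <cstdio>

// Forward: dst[i] = src[i], i.e. a memcpy over doubles.
void forward(double *dst, const double *src, size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = src[i];
}

// Reverse: push the destination shadow into the source shadow and clear it.
// Only n and the shadow buffers are touched; the primal pointers are not.
void adjoint(double *d_dst, double *d_src, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    d_src[i] += d_dst[i];
    d_dst[i] = 0.0;
  }
}

int main() {
  double src[2] = {1.0, 2.0}, dst[2] = {0.0, 0.0};
  double d_src[2] = {0.0, 0.0}, d_dst[2] = {5.0, 7.0};
  forward(dst, src, 2);
  adjoint(d_dst, d_src, 2);
  std::printf("%f %f\n", d_src[0], d_src[1]); // 5.000000 7.000000
}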
- if (MS->getArgOperand(1) == val || MS->getArgOperand(2) == val) { - for (auto pair : gutils->backwardsOnlyShadows) - if (pair.second.stores.count(MS)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from remat memset " << *user << "\n"; - return true; - } - bool res = !gutils->isConstantValue(MS->getArgOperand(0)); - if (res) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from memset " << *user << "\n"; - } - return res; - } - } else { - - if (MS->getArgOperand(0) != val) - return false; - - if (!gutils->isConstantValue(const_cast(MS->getArgOperand(0)))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow MS " << *MS << "\n"; - return true; - } else - return false; - } - } - - if (!shadow) - if (isa(user) || isa(user) || isa(user) || - isa(user) || isa(user) - // isa(use) || - // isa(use) || isa(use) || - // isa(use) || isa(use) - // || isa(use) - ) { - return false; - } - - if (!shadow) - if (auto IEI = dyn_cast(user)) { - // Only need the index in the reverse, so if the value is not - // the index, short circuit and say we don't need - if (IEI->getOperand(2) != val) { - return false; - } - // The index is only needed in the reverse if the value being inserted - // is a possible active floating point value - if (gutils->isConstantValue(const_cast(IEI)) || - TR.query(const_cast(IEI))[{-1}] == - BaseType::Pointer) - return false; - // Otherwise, we need the value. - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from non-pointer insertelem " << *user - << " " - << TR.query(const_cast(IEI)).str() - << "\n"; - return true; - } - - if (!shadow) - if (auto EEI = dyn_cast(user)) { - // Only need the index in the reverse, so if the value is not - // the index, short circuit and say we don't need - if (EEI->getIndexOperand() != val) { - return false; - } - // The index is only needed in the reverse if the value being inserted - // is a possible active floating point value - if (gutils->isConstantValue(const_cast(EEI)) || - TR.query(const_cast(EEI))[{-1}] == - BaseType::Pointer) - return false; - // Otherwise, we need the value. - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from non-pointer extractelem " << *user - << " " - << TR.query(const_cast(EEI)).str() - << "\n"; - return true; - } - - if (!shadow) - if (auto IVI = dyn_cast(user)) { - // Only need the index in the reverse, so if the value is not - // the index, short circuit and say we don't need - bool valueIsIndex = false; - for (unsigned i = 2; i < IVI->getNumOperands(); ++i) { - if (IVI->getOperand(i) == val) { - valueIsIndex = true; - } - } - - if (!valueIsIndex) - return false; - - // The index is only needed in the reverse if the value being inserted - // is a possible active floating point value - if (gutils->isConstantValue(const_cast(IVI)) || - TR.query(const_cast(IVI))[{-1}] == - BaseType::Pointer) - return false; - // Otherwise, we need the value. 
- if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from non-pointer insertval " << *user - << " " - << TR.query(const_cast(IVI)).str() - << "\n"; - return true; - } - - if (!shadow) - if (auto EVI = dyn_cast(user)) { - // Only need the index in the reverse, so if the value is not - // the index, short circuit and say we don't need - bool valueIsIndex = false; - for (unsigned i = 2; i < EVI->getNumOperands(); ++i) { - if (EVI->getOperand(i) == val) { - valueIsIndex = true; - } - } - - if (!valueIsIndex) - return false; - - // The index is only needed in the reverse if the value being inserted - // is a possible active floating point value - if (gutils->isConstantValue(const_cast(EVI)) || - TR.query(const_cast(EVI))[{-1}] == - BaseType::Pointer) - return false; - // Otherwise, we need the value. - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from non-pointer extractval " << *user - << " " - << TR.query(const_cast(EVI)).str() - << "\n"; - return true; - } - - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (auto II = dyn_cast(user)) { - ID = II->getIntrinsicID(); - } else if (auto CI = dyn_cast(user)) { - StringRef funcName = getFuncNameFromCall(const_cast(CI)); - isMemFreeLibMFunction(funcName, &ID); - } - - if (ID != Intrinsic::not_intrinsic) { - if (ID == Intrinsic::lifetime_start || ID == Intrinsic::lifetime_end || - ID == Intrinsic::stacksave || ID == Intrinsic::stackrestore) { - return false; - } - } - - if (!shadow) - if (auto si = dyn_cast(user)) { - // Only would potentially need the condition - if (si->getCondition() != val) { - return false; - } - - // only need the condition if select is active - bool needed = !gutils->isConstantValue(const_cast(si)); - if (needed) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from select " << *user << "\n"; - } - return needed; - } - -#include "BlasDiffUse.inc" - - if (auto CI = dyn_cast(user)) { - - { - SmallVector OrigDefs; - CI->getOperandBundlesAsDefs(OrigDefs); - SmallVector Defs; - for (auto bund : OrigDefs) { - for (auto inp : bund.inputs()) { - if (inp == val) - return true; - } - } - } - - auto funcName = getFuncNameFromCall(CI); - - { - auto found = customDiffUseHandlers.find(funcName); - if (found != customDiffUseHandlers.end()) { - bool useDefault = false; - bool result = found->second(CI, gutils, val, shadow, mode, useDefault); - if (!useDefault) { - if (result) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(qtype) << " of " << *val - << " from custom diff use handler of " << *CI - << "\n"; - } - return result; - } - } - } - - // Don't need shadow inputs for alloc function - if (shadow && isAllocationFunction(funcName, gutils->TLI)) - return false; - - // Even though inactive, keep the shadow pointer around in forward mode - // to perform the same memory free behavior on the shadow. 
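In forward mode every duplicated pointer carries a shadow allocation, so a free of the primal has to be paired with a free of the shadow; that pairing is why the shadow operand of a deallocation stays live even though the call itself is inactive. A standalone sketch; the Duplicated struct and helper names are illustrative, not an Enzyme API:

#include <cstddef>
#include <cstdlib>

// A primal/shadow pair as forward mode conceptually sees it.
struct Duplicated {
  double *primal;
  double *shadow;
};

Duplicated make(std::size_t n) {
  return { static_cast<double *>(std::calloc(n, sizeof(double))),
           static_cast<double *>(std::calloc(n, sizeof(double))) };
}

void release(Duplicated d) {
  std::free(d.primal);
  std::free(d.shadow); // the shadow pointer must still be available here
}

int main() {
  Duplicated d = make(4);
  d.shadow[0] = 1.0; // tangent computation would happen here
  release(d);
}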
- if (shadow && - (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError) && - isDeallocationFunction(funcName, gutils->TLI)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow free " << *CI << "\n"; - return true; - } - - // Only need primal (and shadow) request for reverse, or shadow buffer - if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" || - funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") { - if (gutils->isConstantInstruction(const_cast(user))) - return false; - - if (val == CI->getArgOperand(6)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(qtype) << " request " << *val - << " in reverse for MPI " << *CI << "\n"; - return true; - } - if (shadow && val == CI->getArgOperand(0)) { - if ((funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") && - mode != DerivativeMode::ReverseModeGradient) { - // Need shadow buffer for forward pass of irecieve - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow(" << to_string(qtype) << ") of " - << *val << " in reverse as shadow MPI " << *CI << "\n"; - return true; - } - if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") { - // Need shadow buffer for forward or reverse pass of isend - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow(" << to_string(qtype) << ") of " - << *val << " in reverse as shadow MPI " << *CI << "\n"; - return true; - } - } - - return false; - } - - if (!shadow) { - - // Need the primal request in reverse. - if (funcName == "cuStreamSynchronize") - if (val == CI->getArgOperand(0)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: primal(" << to_string(qtype) << ") of " - << *val << " in reverse for cuda sync " << *CI << "\n"; - return true; - } - - // Only need the primal request. - if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") - if (val != CI->getArgOperand(0)) - return false; - - // Only need element count for reverse of waitall - if (funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") - if (val != CI->getArgOperand(0) || val != CI->getOperand(1)) - return false; - - } else { - // Don't need shadow of anything (all via cache for reverse), - // but need shadow of request for primal. - if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") { - if (gutils->isConstantInstruction(const_cast(user))) - return false; - // Need shadow request in forward pass only - if (mode != DerivativeMode::ReverseModeGradient) - if (val == CI->getArgOperand(0)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow MPI " << *CI << "\n"; - return true; - } - return false; - } - } - - // Since adjoint of barrier is another barrier in reverse - // we still need even if instruction is inactive - if (!shadow) - if (funcName == "__kmpc_barrier" || funcName == "MPI_Barrier") { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from barrier " << *user << "\n"; - return true; - } - - // Since adjoint of GC preserve is another preserve in reverse - // we still need even if instruction is inactive - if (!shadow) - if (funcName == "llvm.julia.gc_preserve_begin") { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from gc " << *CI << "\n"; - return true; - } - - if (funcName == "julia.write_barrier" || - funcName == "julia.write_barrier_binding") { - // Use in a write barrier requires the shadow in the forward, even - // though the instruction is active. 
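The barrier handling above reflects that the adjoint of a synchronization point is the same synchronization point, replayed at the mirrored position of the reverse sweep, which is why the call survives even though it is not differentiable. A conceptual sketch, assuming an MPI toolchain; the function names are illustrative:

#include <mpi.h>

void forward_step(double *x, MPI_Comm comm) {
  // ... forward work on x ...
  MPI_Barrier(comm);
}

void reverse_step(double *dx, MPI_Comm comm) {
  MPI_Barrier(comm); // mirrored barrier runs before the local adjoint work
  // ... adjoint work on dx ...
}

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  double x = 0.0, dx = 1.0;
  forward_step(&x, MPI_COMM_WORLD);
  reverse_step(&dx, MPI_COMM_WORLD);
  MPI_Finalize();
}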
- if (shadow && (mode != DerivativeMode::ReverseModeGradient && - mode != DerivativeMode::ForwardModeSplit)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in forward as shadow write_barrier " << *CI << "\n"; - return true; - } - if (shadow) { - auto sz = CI->arg_size(); - bool isStored = false; - // First pointer is the destination - for (size_t i = 1; i < sz; i++) - isStored |= val == CI->getArgOperand(i); - bool rematerialized = false; - if (isStored) - for (auto pair : gutils->backwardsOnlyShadows) - if (pair.second.stores.count(CI) && - !gutils->isConstantValue(pair.first)) { - rematerialized = true; - break; - } - - if (rematerialized) { - if (EnzymePrintDiffUse) - llvm::errs() - << " Need: shadow of " << *val - << " in rematerialized reverse as shadow write_barrier " << *CI - << "\n"; - return true; - } - } - } - - bool writeOnlyNoCapture = true; - - if (shouldDisableNoWrite(CI)) { - writeOnlyNoCapture = false; - } - for (size_t i = 0; i < CI->arg_size(); i++) { - if (val == CI->getArgOperand(i)) { - if (!isNoCapture(CI, i)) { - writeOnlyNoCapture = false; - break; - } - if (!isWriteOnly(CI, i)) { - writeOnlyNoCapture = false; - break; - } - } - } - - // Don't need the primal argument if it is write only and not captured - if (!shadow) - if (writeOnlyNoCapture) - return false; - - if (shadow) { - // Don't need the shadow argument if it is a pointer to pointers, which - // is only written since the shadow pointer store will have been - // completed in the forward pass. - if (writeOnlyNoCapture && - TR.query(const_cast(val))[{-1, -1}] == BaseType::Pointer && - mode == DerivativeMode::ReverseModeGradient) - return false; - - const Value *FV = CI->getCalledOperand(); - if (FV == val) { - if (!gutils->isConstantInstruction(const_cast(user)) || - !gutils->isConstantValue(const_cast((Value *)user))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow call " << *CI << "\n"; - return true; - } - } - } - } - - if (shadow) { - if (isa(user)) { - bool notrev = mode != DerivativeMode::ReverseModeGradient; - if (gutils->shadowReturnUsed && notrev) { - - bool inst_cv = gutils->isConstantValue(const_cast(val)); - - if ((qtype == QueryType::ShadowByConstPrimal && inst_cv) || - (qtype == QueryType::Shadow && !inst_cv)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow(qtype=" << (int)qtype - << ",cv=" << inst_cv << ") of " << *val - << " in reverse as shadow return " << *user << "\n"; - return true; - } - } - return false; - } - - // With certain exceptions, assume active instructions require the - // shadow of the operand. - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError || - (!isa(user) && !isa(user) && - !isa(user) && !isa(user) && - !isPointerArithmeticInst(user, /*includephi*/ false, - /*includebin*/ false))) { - - bool inst_cv = gutils->isConstantValue(const_cast(val)); - - if (!inst_cv && - !gutils->isConstantInstruction(const_cast(user))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: shadow of " << *val - << " in reverse as shadow inst " << *user << "\n"; - return true; - } - } - - // Now the remaining instructions are inactive, however note that - // a constant instruction may still require the use of the shadow - // in the forward pass, for example double* x = load double** y - // is a constant instruction, but needed in the forward. 
However, - // if the value [and from above also the instruction] is constant - // we don't need it. - if (gutils->isConstantValue( - const_cast((const llvm::Value *)user))) { - return false; - } - - // Now we don't need this value directly, but we may need it recursively - // in one the active value users - assert(recursiveUse); - *recursiveUse = true; - return false; - } - - bool neededFB = false; - if (auto CB = dyn_cast(const_cast(user))) { - neededFB = !callShouldNotUseDerivative(gutils, *CB); - } else { - neededFB = !gutils->isConstantInstruction(user) || - !gutils->isConstantValue(const_cast(user)); - } - if (neededFB) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need direct primal of " << *val - << " in reverse from fallback " << *user << "\n"; - } - return neededFB; -} - -void DifferentialUseAnalysis::dump(Graph &G) { - for (auto &pair : G) { - llvm::errs() << "[" << *pair.first.V << ", " << (int)pair.first.outgoing - << "]\n"; - for (auto N : pair.second) { - llvm::errs() << "\t[" << *N.V << ", " << (int)N.outgoing << "]\n"; - } - } -} - -/* Returns true if there is a path from source 's' to sink 't' in - residual graph. Also fills parent[] to store the path */ -void DifferentialUseAnalysis::bfs(const Graph &G, - const SetVector &Recompute, - std::map &parent) { - std::deque q; - for (auto V : Recompute) { - Node N(V, false); - parent.emplace(N, Node(nullptr, true)); - q.push_back(N); - } - - // Standard BFS Loop - while (!q.empty()) { - auto u = q.front(); - q.pop_front(); - auto found = G.find(u); - if (found == G.end()) - continue; - for (auto v : found->second) { - if (parent.find(v) == parent.end()) { - q.push_back(v); - parent.emplace(v, u); - } - } - } -} - -// Return 1 if next is better -// 0 if equal -// -1 if prev is better, or unknown -int DifferentialUseAnalysis::cmpLoopNest(Loop *prev, Loop *next) { - if (next == prev) - return 0; - if (next == nullptr) - return 1; - else if (prev == nullptr) - return -1; - for (Loop *L = prev; L != nullptr; L = L->getParentLoop()) { - if (L == next) - return 1; - } - return -1; -} - -void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, - const SetVector &Recomputes, - const SetVector &Intermediates, - SetVector &Required, - SetVector &MinReq, - const GradientUtils *gutils, - llvm::TargetLibraryInfo &TLI) { - Graph G; - for (auto V : Intermediates) { - G[Node(V, false)].insert(Node(V, true)); - forEachDifferentialUser( - [&](Value *U) { - if (Intermediates.count(U)) { - if (V != U) - G[Node(V, true)].insert(Node(U, false)); - } - }, - gutils, V); - } - for (auto pair : gutils->rematerializableAllocations) { - if (Intermediates.count(pair.first)) { - for (LoadInst *L : pair.second.loads) { - if (Intermediates.count(L)) { - if (L != pair.first) - G[Node(pair.first, true)].insert(Node(L, false)); - } - } - for (auto L : pair.second.loadLikeCalls) { - if (Intermediates.count(L.loadCall)) { - if (L.loadCall != pair.first) - G[Node(pair.first, true)].insert(Node(L.loadCall, false)); - } - } - } - } -#ifndef NDEBUG - for (auto R : Required) { - assert(Intermediates.count(R)); - } - for (auto R : Recomputes) { - assert(Intermediates.count(R)); - } -#endif - - Graph Orig = G; - - // Augment the flow while there is a path from source to sink - while (1) { - std::map parent; - bfs(G, Recomputes, parent); - Node end(nullptr, false); - for (auto req : Required) { - if (parent.find(Node(req, true)) != parent.end()) { - end = Node(req, true); - break; - } - } - if (end.V == nullptr) - break; - // update residual capacities 
of the edges and reverse edges - // along the path - Node v = end; - while (1) { - assert(parent.find(v) != parent.end()); - Node u = parent.find(v)->second; - assert(u.V != nullptr); - assert(G[u].count(v) == 1); - assert(G[v].count(u) == 0); - G[u].erase(v); - G[v].insert(u); - if (Recomputes.count(u.V) && u.outgoing == false) - break; - v = u; - } - } - - // Flow is maximum now, find vertices reachable from s - - std::map parent; - bfs(G, Recomputes, parent); - - SetVector todo; - - // Print all edges that are from a reachable vertex to - // non-reachable vertex in the original graph - for (auto &pair : Orig) { - if (parent.find(pair.first) != parent.end()) - for (auto N : pair.second) { - if (parent.find(N) == parent.end()) { - assert(pair.first.outgoing == 0 && N.outgoing == 1); - assert(pair.first.V == N.V); - MinReq.insert(N.V); - if (Orig.find(Node(N.V, true)) != Orig.end()) { - todo.insert(N.V); - } - } - } - } - - while (todo.size()) { - auto V = todo.front(); - todo.remove(V); - auto found = Orig.find(Node(V, true)); - assert(found != Orig.end()); - const auto &mp = found->second; - - assert(MinReq.count(V)); - - // Fix up non-cacheable calls to use their operand(s) instead - if (hasNoCache(V)) { - assert(!Required.count(V)); - MinReq.remove(V); - for (auto &pair : Orig) { - if (pair.second.count(Node(V, false))) { - MinReq.insert(pair.first.V); - todo.insert(pair.first.V); - } - } - continue; - } - - // When ambiguous, push to cache the last value in a computation chain - // This should be considered in a cost for the max flow - if (mp.size() == 1 && !Required.count(V)) { - bool potentiallyRecursive = - isa((*mp.begin()).V) && - OrigLI.isLoopHeader(cast((*mp.begin()).V)->getParent()); - int moreOuterLoop = cmpLoopNest( - OrigLI.getLoopFor(cast(V)->getParent()), - OrigLI.getLoopFor(cast(((*mp.begin()).V))->getParent())); - if (potentiallyRecursive) - continue; - if (moreOuterLoop == -1) - continue; - if (auto ASC = dyn_cast((*mp.begin()).V)) { - if (ASC->getDestAddressSpace() == 11 || - ASC->getDestAddressSpace() == 13) - continue; - if (ASC->getSrcAddressSpace() == 10 && ASC->getDestAddressSpace() == 0) - continue; - } - if (auto CI = dyn_cast((*mp.begin()).V)) { - if (CI->getType()->isPointerTy() && - CI->getType()->getPointerAddressSpace() == 13) - continue; - } - if (auto G = dyn_cast((*mp.begin()).V)) { - if (G->getType()->getPointerAddressSpace() == 13) - continue; - } - if (hasNoCache((*mp.begin()).V)) { - continue; - } - // If an allocation call, we cannot cache any "capturing" users - if (isAllocationCall(V, TLI) || isa(V)) { - auto next = (*mp.begin()).V; - bool noncapture = false; - if (isa(next) || isNVLoad(next)) { - noncapture = true; - } else if (auto CI = dyn_cast(next)) { - bool captures = false; - for (size_t i = 0; i < CI->arg_size(); i++) { - if (CI->getArgOperand(i) == V && !isNoCapture(CI, i)) { - captures = true; - break; - } - } - noncapture = !captures; - } - - if (!noncapture) - continue; - } - - if (moreOuterLoop == 1 || - (moreOuterLoop == 0 && - DL.getTypeSizeInBits(V->getType()) >= - DL.getTypeSizeInBits((*mp.begin()).V->getType()))) { - MinReq.remove(V); - auto nnode = (*mp.begin()).V; - MinReq.insert(nnode); - if (Orig.find(Node(nnode, true)) != Orig.end()) - todo.insert(nnode); - } - } - } - - // Fix up non-repeatable writing calls that chain within rematerialized - // allocations. We could iterate from the keys of the valuemap, but that would - // be a non-determinstic ordering. 
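The augmenting-path loop above is a unit-capacity max-flow/min-cut over a node-split graph: every intermediate value v contributes an edge in(v) -> out(v), each use adds out(v) -> in(u), and the cut between the recomputable sources and the values required in the reverse pass selects what to cache. A self-contained sketch of the same construction, with integer ids in place of llvm::Value* and without the cost heuristics:

#include <cstdio>
#include <deque>
#include <map>
#include <set>
#include <utility>

using Node = std::pair<int, bool>;            // (value id, outgoing?)
using Graph = std::map<Node, std::set<Node>>;

static void bfs(const Graph &G, const std::set<int> &sources,
                std::map<Node, Node> &parent) {
  std::deque<Node> q;
  for (int s : sources) {
    parent.emplace(Node{s, false}, Node{-1, true});
    q.push_back({s, false});
  }
  while (!q.empty()) {
    Node u = q.front(); q.pop_front();
    auto it = G.find(u);
    if (it == G.end()) continue;
    for (const Node &v : it->second)
      if (!parent.count(v)) { parent.emplace(v, u); q.push_back(v); }
  }
}

int main() {
  // Value 0 is recomputable, value 3 is required in the reverse pass,
  // and 1, 2 are intermediates: 0 -> 1 -> 3 and 0 -> 2 -> 3.
  std::set<int> sources = {0}, required = {3};
  Graph G;
  for (int v : {0, 1, 2, 3}) G[{v, false}].insert({v, true});
  auto use = [&](int a, int b) { G[{a, true}].insert({b, false}); };
  use(0, 1); use(1, 3); use(0, 2); use(2, 3);
  Graph Orig = G;

  // Augment unit-capacity paths from the sources to any required sink.
  while (true) {
    std::map<Node, Node> parent;
    bfs(G, sources, parent);
    Node end{-1, false};
    for (int r : required)
      if (parent.count({r, true})) { end = {r, true}; break; }
    if (end.first == -1) break;
    for (Node v = end;;) {
      Node u = parent.at(v);
      G[u].erase(v);
      G[v].insert(u); // residual (reverse) edge
      if (sources.count(u.first) && !u.second) break;
      v = u;
    }
  }

  // The min cut: original edges from a reachable node to an unreachable one.
  std::map<Node, Node> parent;
  bfs(G, sources, parent);
  for (auto &p : Orig)
    if (parent.count(p.first))
      for (const Node &n : p.second)
        if (!parent.count(n))
          std::printf("cache value %d\n", n.first); // prints: cache value 0
  return 0;
}

Caching value 0 alone is cheaper than caching 3, or 1 and 2 together, which is the effect the real minCut aims for before its loop-nest and allocation heuristics adjust the choice.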
- for (auto V : Intermediates) { - auto found = gutils->rematerializableAllocations.find(V); - if (found == gutils->rematerializableAllocations.end()) - continue; - if (!found->second.nonRepeatableWritingCall) - continue; - - // We are already caching this allocation directly, we're fine - if (MinReq.count(V)) - continue; - - // If we are recomputing a load, we need to fix this. - bool needsLoad = false; - for (auto load : found->second.loads) - if (Intermediates.count(load) && !MinReq.count(load)) { - needsLoad = true; - break; - } - for (auto load : found->second.loadLikeCalls) - if (Intermediates.count(load.loadCall) && !MinReq.count(load.loadCall)) { - needsLoad = true; - break; - } - - if (!needsLoad) - continue; - - // Rewire the uses to cache the allocation directly. - // TODO: as further optimization, we can remove potentially unnecessary - // values that we are keeping for stores. - MinReq.insert(V); - } - - return; -} - -bool DifferentialUseAnalysis::callShouldNotUseDerivative( - const GradientUtils *gutils, CallBase &call) { - bool shadowReturnUsed = false; - auto smode = gutils->mode; - if (smode == DerivativeMode::ReverseModeGradient) - smode = DerivativeMode::ReverseModePrimal; - (void)gutils->getReturnDiffeType(&call, nullptr, &shadowReturnUsed, smode); - - bool useConstantFallback = - gutils->isConstantInstruction(&call) && - (gutils->isConstantValue(&call) || !shadowReturnUsed); - if (useConstantFallback && gutils->mode != DerivativeMode::ForwardMode && - gutils->mode != DerivativeMode::ForwardModeError) { - // if there is an escaping allocation, which is deduced needed in - // reverse pass, we need to do the recursive procedure to perform the - // free. - - // First test if the return is a potential pointer and needed for the - // reverse pass - bool escapingNeededAllocation = false; - - if (!isNoEscapingAllocation(&call)) { - escapingNeededAllocation = EnzymeGlobalActivity; - - std::map CacheResults; - for (auto pair : gutils->knownRecomputeHeuristic) { - if (!pair.second || gutils->unnecessaryIntermediates.count( - cast(pair.first))) { - CacheResults[UsageKey(pair.first, QueryType::Primal)] = false; - } - } - - if (!escapingNeededAllocation && - !(EnzymeJuliaAddrLoad && isSpecialPtr(call.getType()))) { - if (gutils->TR.anyPointer(&call)) { - auto found = gutils->knownRecomputeHeuristic.find(&call); - if (found != gutils->knownRecomputeHeuristic.end()) { - if (!found->second) { - CacheResults.erase(UsageKey(&call, QueryType::Primal)); - escapingNeededAllocation = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, - DerivativeMode::ReverseModeGradient, - CacheResults, gutils->notForAnalysis); - } - } else { - escapingNeededAllocation = - DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &call, - DerivativeMode::ReverseModeGradient, - CacheResults, gutils->notForAnalysis); - } - } - } - - // Next test if any allocation could be stored into one of the - // arguments. - if (!escapingNeededAllocation) - for (unsigned i = 0; i < call.arg_size(); ++i) { - Value *a = call.getOperand(i); - - if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType())) - continue; - - if (!gutils->TR.anyPointer(a)) - continue; - - auto vd = gutils->TR.query(a); - - if (!vd[{-1, -1}].isPossiblePointer()) - continue; - - if (isReadOnly(&call, i)) - continue; - - // An allocation could only be needed in the reverse pass if it - // escapes into an argument. 
However, is the parameter by which it - // escapes could capture the pointer, the rest of Enzyme's caching - // mechanisms cannot assume that the allocation itself is - // reloadable, since it may have been captured and overwritten - // elsewhere. - // TODO: this justification will need revisiting in the future as - // the caching algorithm becomes increasingly sophisticated. - if (!isNoCapture(&call, i)) - continue; - - escapingNeededAllocation = true; - } - } - - // If desired this can become even more aggressive by looking through the - // called function for any allocations. - if (auto F = getFunctionFromCall(&call)) { - SmallVector todo = {F}; - SmallPtrSet done; - bool seenAllocation = false; - while (todo.size() && !seenAllocation) { - auto cur = todo.pop_back_val(); - if (done.count(cur)) - continue; - done.insert(cur); - // assume empty functions allocate. - if (cur->empty()) { - // unless they are marked - if (isNoEscapingAllocation(cur)) - continue; - seenAllocation = true; - break; - } - auto UR = getGuaranteedUnreachable(cur); - for (auto &BB : *cur) { - if (UR.count(&BB)) - continue; - for (auto &I : BB) - if (auto CB = dyn_cast(&I)) { - if (isNoEscapingAllocation(CB)) - continue; - if (isAllocationCall(CB, gutils->TLI)) { - seenAllocation = true; - goto finish; - } - if (auto F = getFunctionFromCall(CB)) { - todo.push_back(F); - continue; - } - // Conservatively assume indirect functions allocate. - seenAllocation = true; - goto finish; - } - } - finish:; - } - if (!seenAllocation) - escapingNeededAllocation = false; - } - if (escapingNeededAllocation) - useConstantFallback = false; - } - return useConstantFallback; -} diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h deleted file mode 100644 index 56c49039c86f..000000000000 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ /dev/null @@ -1,547 +0,0 @@ -//===- DifferentialUseAnalysis.h - Determine values needed in reverse pass-===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the declaration of Differential USe Analysis -- an -// AD-specific analysis that deduces if a given value is needed in the reverse -// pass. 
-// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_DIFFERENTIALUSEANALYSIS_H_ -#define ENZYME_DIFFERENTIALUSEANALYSIS_H_ - -#include -#include - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Instruction.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" - -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" - -#include "DiffeGradientUtils.h" -#include "GradientUtils.h" -#include "LibraryFuncs.h" - -extern "C" { -extern llvm::cl::opt EnzymePrintDiffUse; -} - -extern llvm::StringMap< - std::function> - customDiffUseHandlers; - -/// Classification of what type of use is requested -enum class QueryType { - // The original value is needed for the derivative - Primal = 0, - // The shadow value is needed for the derivative - Shadow = 1, - // The primal value is needed to stand in for the shadow - // value and compute the derivative of an instruction - ShadowByConstPrimal = 2 -}; - -static inline std::string to_string(QueryType mode) { - switch (mode) { - case QueryType::Primal: - return "Primal"; - case QueryType::Shadow: - return "Shadow"; - case QueryType::ShadowByConstPrimal: - return "ShadowByConstPrimal"; - } - llvm_unreachable("illegal QueryType"); -} - -typedef std::pair UsageKey; - -namespace DifferentialUseAnalysis { - -/// Determine if a value is needed directly to compute the adjoint -/// of the given instruction user. `shadow` denotes whether we are considering -/// the shadow of the value (shadow=true) or the primal of the value -/// (shadow=false). -/// Recursive use is only usable in shadow mode. -bool is_use_directly_needed_in_reverse( - const GradientUtils *gutils, const llvm::Value *val, DerivativeMode mode, - const llvm::Instruction *user, - const llvm::SmallPtrSetImpl &oldUnreachable, - QueryType shadow, bool *recursiveUse = nullptr); - -template -inline bool is_value_needed_in_reverse( - const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, - std::map &seen, - const llvm::SmallPtrSetImpl &oldUnreachable) { - using namespace llvm; - - TypeResults const &TR = gutils->TR; - static_assert(VT == QueryType::Primal || VT == QueryType::Shadow || - VT == QueryType::ShadowByConstPrimal); - auto idx = UsageKey(inst, VT); - if (seen.find(idx) != seen.end()) - return seen[idx]; - if (auto ainst = dyn_cast(inst)) { - assert(ainst->getParent()->getParent() == gutils->oldFunc); - } - - // Inductively claim we aren't needed (and try to find contradiction) - seen[idx] = false; - - if (VT == QueryType::Primal) { - if (auto op = dyn_cast(inst)) { - if (op->getOpcode() == Instruction::FDiv) { - if (!gutils->isConstantValue(const_cast(inst)) && - !gutils->isConstantValue(op->getOperand(1))) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as is active div\n"; - return seen[idx] = true; - } - } - } - if (gutils->mode == DerivativeMode::ForwardModeError && - !gutils->isConstantValue(const_cast(inst))) { - if (EnzymePrintDiffUse) - llvm::errs() - << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as forward mode error always needs result\n"; - return seen[idx] = true; - } - } - - if (auto CI = dyn_cast(inst)) { - StringRef funcName = getFuncNameFromCall(const_cast(CI)); - if (funcName == "julia.get_pgcstack" || funcName == "julia.ptls_states") - return seen[idx] = true; - } - - bool inst_cv = gutils->isConstantValue(const_cast(inst)); - - // Consider all users of this 
value, do any of them need this in the reverse? - for (auto use : inst->users()) { - if (use == inst) - continue; - - const Instruction *user = dyn_cast(use); - - // A shadow value is only needed in reverse if it or one of its descendants - // is used in an active instruction. - // If inst is a constant value, the primal may be used in its place and - // thus required. - if (VT == QueryType::Shadow || VT == QueryType::ShadowByConstPrimal || - inst_cv) { - bool recursiveUse = false; - if (is_use_directly_needed_in_reverse( - gutils, inst, mode, user, oldUnreachable, - (VT == QueryType::Shadow) ? QueryType::Shadow - : QueryType::ShadowByConstPrimal, - &recursiveUse)) { - return seen[idx] = true; - } - - if (recursiveUse && !OneLevel) { - bool val; - if (VT == QueryType::Shadow) - val = is_value_needed_in_reverse( - gutils, user, mode, seen, oldUnreachable); - else - val = is_value_needed_in_reverse( - gutils, user, mode, seen, oldUnreachable); - if (val) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as shadow sub-need " << *user << "\n"; - return seen[idx] = true; - } - } - - if (!TR.anyFloat(const_cast(inst))) - if (auto IVI = dyn_cast(user)) { - bool inserted = false; - if (auto II = dyn_cast(IVI)) - inserted = II->getInsertedValueOperand() == inst || - II->getAggregateOperand() == inst; - if (auto II = dyn_cast(IVI)) - inserted = II->getAggregateOperand() == inst; - if (auto II = dyn_cast(IVI)) - inserted = II->getOperand(1) == inst || II->getOperand(0) == inst; - if (auto II = dyn_cast(IVI)) - inserted = II->getOperand(0) == inst; - if (inserted) { - SmallVector todo; - todo.push_back(IVI); - while (todo.size()) { - auto cur = todo.pop_back_val(); - for (auto u : cur->users()) { - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - - bool partial = false; - if (!gutils->isConstantValue(const_cast(cur))) { - partial = is_value_needed_in_reverse( - gutils, user, mode, seen, oldUnreachable); - } - if (partial) { - - if (EnzymePrintDiffUse) - llvm::errs() - << " Need (partial) direct " << to_string(VT) << " of " - << *inst << " in reverse from insertelem " << *user - << " via " << *cur << " in " << *u << "\n"; - return seen[idx] = true; - } - } - } - } - } - - if (VT != QueryType::Primal) - continue; - } - - assert(VT == QueryType::Primal); - - // If a sub user needs, we need - if (!OneLevel && is_value_needed_in_reverse(gutils, user, mode, seen, - oldUnreachable)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as sub-need " << *user << "\n"; - return seen[idx] = true; - } - - // Anything we may try to rematerialize requires its store operands for - // the reverse pass. 
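The recursion above stays well defined on cyclic use graphs (PHI nodes, loop-carried values) because the cache is seeded with an inductive "not needed" answer before any user is visited, and only flipped to true when a concrete use forces it. A toy sketch of that memoization pattern; plain integers and a stub predicate stand in for llvm::Value* and the real per-user analysis:

#include <cstdio>
#include <map>
#include <set>
#include <vector>

struct Demo {
  std::map<int, std::vector<int>> users;  // value -> values that use it
  std::set<int> directlyNeeded;           // values some adjoint reads
  std::map<int, bool> seen;

  bool neededInReverse(int v) {
    auto it = seen.find(v);
    if (it != seen.end()) return it->second;
    seen[v] = false;                      // inductive claim: not needed
    if (directlyNeeded.count(v)) return seen[v] = true;
    for (int u : users[v])                // needed if any user is needed
      if (neededInReverse(u)) return seen[v] = true;
    return false;
  }
};

int main() {
  Demo d;
  d.users = {{0, {1}}, {1, {2, 0}}, {2, {}}}; // 0 and 1 appear in each other's user lists
  d.directlyNeeded = {2};
  std::printf("%d %d\n", d.neededInReverse(0), d.neededInReverse(1)); // 1 1
}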
- if (!OneLevel) { - bool isStored = false; - if (auto SI = dyn_cast(user)) - isStored = inst == SI->getValueOperand(); - else if (auto MTI = dyn_cast(user)) { - isStored = inst == MTI->getSource() || inst == MTI->getLength(); - } else if (auto MS = dyn_cast(user)) { - isStored = inst == MS->getLength() || inst == MS->getValue(); - } else if (auto CB = dyn_cast(user)) { - auto name = getFuncNameFromCall(CB); - if (name == "julia.write_barrier" || - name == "julia.write_barrier_binding") { - auto sz = CB->arg_size(); - // First pointer is the destination - for (size_t i = 1; i < sz; i++) - isStored |= inst == CB->getArgOperand(i); - } - } - if (isStored) { - for (auto pair : gutils->rematerializableAllocations) { - // If already decided to cache the whole allocation, ignore - if (gutils->needsCacheWholeAllocation(pair.first)) - continue; - - // If caching the outer allocation and have already set that this is - // not needed return early. This is necessary to avoid unnecessarily - // deciding stored values are needed if we have already decided to - // cache the whole allocation. - auto found = seen.find(std::make_pair(pair.first, QueryType::Primal)); - if (found != seen.end() && !found->second) - continue; - - // Directly consider all the load uses to avoid an illegal inductive - // recurrence. Specifically if we're asking if the alloca is used, - // we'll set it to unused, then check the gep, then here we'll - // directly say unused by induction instead of checking the final - // loads. - if (pair.second.stores.count(user)) { - for (LoadInst *L : pair.second.loads) - if (is_value_needed_in_reverse(gutils, L, mode, seen, - oldUnreachable)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as rematload " << *L << "\n"; - return seen[idx] = true; - } - for (auto &pair : pair.second.loadLikeCalls) - if (is_use_directly_needed_in_reverse( - gutils, pair.operand, mode, pair.loadCall, oldUnreachable, - QueryType::Primal) || - is_value_needed_in_reverse(gutils, pair.loadCall, mode, - seen, oldUnreachable)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as rematloadcall " - << *pair.loadCall << "\n"; - return seen[idx] = true; - } - } - } - } - } - - // One may need to this value in the computation of loop - // bounds/comparisons/etc (which even though not active -- will be used for - // the reverse pass) - // We could potentially optimize this to avoid caching if in combined mode - // and the instruction dominates all returns - // otherwise it will use the local cache (rather than save for a separate - // backwards cache) - // We also don't need this if looking at the shadow rather than primal - { - // Proving that none of the uses (or uses' uses) are used in control flow - // allows us to safely not do this load - - // TODO save loop bounds for dynamic loop - - // TODO make this more aggressive and dont need to save loop latch - if (isa(use) || isa(use)) { - size_t num = 0; - for (auto suc : successors(cast(use)->getParent())) { - if (!oldUnreachable.count(suc)) { - num++; - } - } - if (num <= 1) - continue; - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as control-flow " << *user << "\n"; - return seen[idx] = true; - } - - if (auto CI = dyn_cast(use)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "__kmpc_for_static_init_4" || - F->getName() == "__kmpc_for_static_init_4u" || - F->getName() == 
"__kmpc_for_static_init_8" || - F->getName() == "__kmpc_for_static_init_8u") { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as omp init " << *user << "\n"; - return seen[idx] = true; - } - } - } - } - - // The following are types we know we don't need to compute adjoints - - // If a primal value is needed to compute a shadow pointer (e.g. int offset - // in gep), it needs preserving. - bool primalUsedInShadowPointer = true; - if (isa(user) || isa(user)) - primalUsedInShadowPointer = false; - if (auto CI = dyn_cast(user)) { - auto funcName = getFuncNameFromCall(CI); - if (funcName == "julia.pointer_from_objref") { - primalUsedInShadowPointer = false; - } - if (funcName == "julia.gc_loaded") { - primalUsedInShadowPointer = false; - } - if (funcName.contains("__enzyme_todense")) { - primalUsedInShadowPointer = false; - } - } - if (auto GEP = dyn_cast(user)) { - bool idxUsed = false; - for (auto &idx : GEP->indices()) { - if (idx.get() == inst) - idxUsed = true; - } - if (!idxUsed) - primalUsedInShadowPointer = false; - } - if (auto II = dyn_cast(user)) { - if (isIntelSubscriptIntrinsic(*II)) { - const std::array idxArgsIndices{{0, 1, 2, 4}}; - bool idxUsed = false; - for (auto i : idxArgsIndices) { - if (II->getOperand(i) == inst) - idxUsed = true; - } - if (!idxUsed) - primalUsedInShadowPointer = false; - } - } - // No need for insert/extractvalue since indices are unsigned - // not llvm runtime values - if (isa(user) || isa(user)) - primalUsedInShadowPointer = false; - - if (primalUsedInShadowPointer) - if (!user->getType()->isVoidTy() && - TR.anyPointer(const_cast(user))) { - if (is_value_needed_in_reverse( - gutils, user, mode, seen, oldUnreachable)) { - if (EnzymePrintDiffUse) - llvm::errs() << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as used to compute shadow ptr " - << *user << "\n"; - return seen[idx] = true; - } - } - - bool direct = is_use_directly_needed_in_reverse( - gutils, inst, mode, user, oldUnreachable, QueryType::Primal); - if (!direct) - continue; - - if (inst->getType()->isTokenTy()) { - llvm::errs() << " need " << *inst << " via " << *user << "\n"; - } - assert(!inst->getType()->isTokenTy()); - - return seen[idx] = true; - } - return false; -} - -template -static inline bool is_value_needed_in_reverse( - const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, - const llvm::SmallPtrSetImpl &oldUnreachable) { - static_assert(VT == QueryType::Primal || VT == QueryType::Shadow); - std::map seen; - return is_value_needed_in_reverse(gutils, inst, mode, seen, - oldUnreachable); -} - -struct Node { - llvm::Value *V; - bool outgoing; - Node(llvm::Value *V, bool outgoing) : V(V), outgoing(outgoing){}; - bool operator<(const Node N) const { - if (V < N.V) - return true; - return !(N.V < V) && outgoing < N.outgoing; - } - void dump() { - if (V) - llvm::errs() << "[" << *V << ", " << (int)outgoing << "]\n"; - else - llvm::errs() << "[" << V << ", " << (int)outgoing << "]\n"; - } -}; - -using Graph = std::map>; - -void dump(std::map> &G); - -/* Returns true if there is a path from source 's' to sink 't' in - residual graph. 
Also fills parent[] to store the path */ -void bfs(const std::map> &G, - const llvm::SetVector &Recompute, - std::map &parent); - -// Return 1 if next is better -// 0 if equal -// -1 if prev is better, or unknown -int cmpLoopNest(llvm::Loop *prev, llvm::Loop *next); - -void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, - const llvm::SetVector &Recomputes, - const llvm::SetVector &Intermediates, - llvm::SetVector &Required, - llvm::SetVector &MinReq, const GradientUtils *gutils, - llvm::TargetLibraryInfo &TLI); - -__attribute__((always_inline)) static inline void -forEachDirectInsertUser(llvm::function_ref f, - const GradientUtils *gutils, llvm::Instruction *IVI, - llvm::Value *val, bool useCheck) { - using namespace llvm; - if (!gutils->isConstantValue(IVI)) - return; - bool inserted = false; - if (auto II = dyn_cast(IVI)) - inserted = II->getInsertedValueOperand() == val || - II->getAggregateOperand() == val; - if (auto II = dyn_cast(IVI)) - inserted = II->getAggregateOperand() == val; - if (auto II = dyn_cast(IVI)) - inserted = II->getOperand(1) == val || II->getOperand(0) == val; - if (auto II = dyn_cast(IVI)) - inserted = II->getOperand(0) == val; - if (inserted) { - SmallVector todo; - todo.push_back(IVI); - while (todo.size()) { - auto cur = todo.pop_back_val(); - for (auto u : cur->users()) { - if (isa(u) || isa(u) || - isa(u) || isa(u)) { - auto I2 = cast(u); - bool subCheck = useCheck; - if (!subCheck) { - subCheck = is_value_needed_in_reverse( - gutils, I2, gutils->mode, gutils->notForAnalysis); - } - if (subCheck) - f(I2); - todo.push_back(I2); - continue; - } - } - } - } -} - -__attribute__((always_inline)) static inline void -forEachDifferentialUser(llvm::function_ref f, - const GradientUtils *gutils, llvm::Value *V, - bool useCheck = false) { - for (auto V2 : V->users()) { - if (auto Inst = llvm::dyn_cast(V2)) { - for (const auto &pair : gutils->rematerializableAllocations) { - if (pair.second.stores.count(Inst)) { - f(llvm::cast(pair.first)); - } - } - f(Inst); - forEachDirectInsertUser(f, gutils, Inst, V, useCheck); - } - } -} - -//! 
Return whether or not this is a constant and should use reverse pass -bool callShouldNotUseDerivative(const GradientUtils *gutils, - llvm::CallBase &orig); - -}; // namespace DifferentialUseAnalysis - -#endif diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 3b99de1665f2..a2e981fa27a2 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -26,6 +26,8 @@ #include #include +#include "llvm/ADT/StringRef.h" + #if LLVM_VERSION_MAJOR >= 16 #define private public #include "llvm/Analysis/ScalarEvolution.h" @@ -47,6 +49,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -73,71 +76,366 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" + +#include "llvm/Transforms/IPO/GlobalOpt.h" + +#include "llvm/Linker/Linker.h" +#include "llvm/IRReader/IRReader.h" + using namespace llvm; #ifdef DEBUG_TYPE #undef DEBUG_TYPE #endif #define DEBUG_TYPE "lower-reactant-intrinsic" +namespace { + + constexpr char cudaLaunchSymbolName[] = "cudaLaunchKernel"; +constexpr char kernelPrefix[] = "__mlir_launch_kernel_"; +constexpr char kernelCoercedPrefix[] = "__mlir_launch_coerced_kernel_"; + +constexpr char cudaPushConfigName[] = "__cudaPushCallConfiguration"; +constexpr char cudaPopConfigName[] = "__cudaPopCallConfiguration"; + +SmallVector gatherCallers(Function *F) { + if (!F) + return {}; + SmallVector ToHandle; + for (auto User : F->users()) + if (auto CI = dyn_cast(User)) + if (CI->getCalledFunction() == F) + ToHandle.push_back(CI); + return ToHandle; +} + +void fixup(Module &M) { + auto LaunchKernelFunc = M.getFunction(cudaLaunchSymbolName); + if (!LaunchKernelFunc) + return; + + SmallPtrSet CoercedKernels; + for (CallInst *CI : gatherCallers(LaunchKernelFunc)) { + IRBuilder<> Builder(CI); + auto FuncPtr = CI->getArgOperand(0); + auto GridDim1 = CI->getArgOperand(1); + auto GridDim2 = CI->getArgOperand(2); + auto BlockDim1 = CI->getArgOperand(3); + auto BlockDim2 = CI->getArgOperand(4); + auto SharedMemSize = CI->getArgOperand(6); + auto StreamPtr = CI->getArgOperand(7); + SmallVector Args = { + FuncPtr, GridDim1, GridDim2, BlockDim1, + BlockDim2, SharedMemSize, StreamPtr, + }; + auto StubFunc = cast(CI->getArgOperand(0)); + for (auto &Arg : StubFunc->args()) + Args.push_back(&Arg); + SmallVector ArgTypes; + for (Value *V : Args) + ArgTypes.push_back(V->getType()); + auto MlirLaunchFunc = Function::Create( + FunctionType::get(Type::getVoidTy(M.getContext()), ArgTypes, + /*isVarAtg=*/false), + llvm::GlobalValue::ExternalLinkage, + kernelCoercedPrefix + StubFunc->getName(), M); + + CoercedKernels.insert(Builder.CreateCall(MlirLaunchFunc, Args)); + CI->eraseFromParent(); + } + + SmallVector InlinedStubs; + for (CallInst *CI : CoercedKernels) { + Function *StubFunc = cast(CI->getArgOperand(0)); + for (User *callee : StubFunc->users()) { + if (auto *CI = dyn_cast(callee)) { + if (CI->getCalledFunction() == StubFunc) { + InlineFunctionInfo IFI; + InlineResult Res = + InlineFunction(*CI, IFI, /*MergeAttributes=*/false); + assert(Res.isSuccess()); + InlinedStubs.push_back(StubFunc); + continue; + } + } + } + } + for (Function *F : InlinedStubs) { + F->erase(F->begin(), F->end()); + BasicBlock *BB = BasicBlock::Create(F->getContext(), "entry", F); + ReturnInst::Create(F->getContext(), nullptr, BB->begin()); + } + + CoercedKernels.clear(); + DenseMap> FuncAllocas; + auto PushConfigFunc = 
M.getFunction(cudaPushConfigName); + for (CallInst *CI : gatherCallers(PushConfigFunc)) { + Function *TheFunc = CI->getFunction(); + IRBuilder<> IRB(&TheFunc->getEntryBlock(), + TheFunc->getEntryBlock().getFirstNonPHIOrDbgOrAlloca()); + auto Allocas = FuncAllocas.lookup(TheFunc); + if (Allocas.empty()) { + Allocas.push_back( + IRB.CreateAlloca(IRB.getInt64Ty(), nullptr, "griddim64")); + Allocas.push_back( + IRB.CreateAlloca(IRB.getInt32Ty(), nullptr, "griddim32")); + Allocas.push_back( + IRB.CreateAlloca(IRB.getInt64Ty(), nullptr, "blockdim64")); + Allocas.push_back( + IRB.CreateAlloca(IRB.getInt32Ty(), nullptr, "blockdim32")); + Allocas.push_back( + IRB.CreateAlloca(IRB.getInt64Ty(), nullptr, "shmem_size")); + Allocas.push_back(IRB.CreateAlloca(IRB.getPtrTy(), nullptr, "stream")); + FuncAllocas.insert_or_assign(TheFunc, Allocas); + } + IRB.SetInsertPoint(CI); + for (auto [Arg, Alloca] : + llvm::zip_equal(llvm::drop_end(CI->operand_values()), Allocas)) + IRB.CreateStore(Arg, Alloca); + } + auto PopConfigFunc = M.getFunction(cudaPopConfigName); + for (CallInst *PopCall : gatherCallers(PopConfigFunc)) { + Function *TheFunc = PopCall->getFunction(); + auto Allocas = FuncAllocas.lookup(TheFunc); + if (Allocas.empty()) { + continue; + } + + CallInst *KernelLaunch = PopCall; + Instruction *It = PopCall; + do { + It = It->getNextNonDebugInstruction(); + KernelLaunch = dyn_cast(It); + } while (!It->isTerminator() && + !(KernelLaunch && KernelLaunch->getCalledFunction() && + KernelLaunch->getCalledFunction()->getName().starts_with( + kernelCoercedPrefix))); + + assert(!It->isTerminator()); + + IRBuilder<> IRB(PopCall); + + for (auto [Arg, Alloca] : llvm::zip( + llvm::drop_begin(KernelLaunch->operand_values(), 1), Allocas)) { + auto Load = cast(Arg); + LoadInst *NewLoad = IRB.CreateLoad(Arg->getType(), Alloca); + Load->replaceAllUsesWith(NewLoad); + } + CoercedKernels.insert(KernelLaunch); + + It = &*PopCall->getParent()->getPrevNode()->getFirstNonPHIOrDbg(); + CallInst *PushCall = dyn_cast(It); + while (!It->isTerminator() && + !(PushCall && PushCall->getCalledFunction() && + PushCall->getCalledFunction()->getName() == cudaPushConfigName)) { + It = It->getNextNonDebugInstruction(); + PushCall = dyn_cast(It); + } + + assert(!It->isTerminator()); + + // Replace with success + PushCall->replaceAllUsesWith(IRB.getInt32(0)); + PushCall->eraseFromParent(); + PopCall->replaceAllUsesWith(IRB.getInt32(0)); + PopCall->eraseFromParent(); + } + for (CallInst *CI : CoercedKernels) { + IRBuilder<> Builder(CI); + auto FuncPtr = CI->getArgOperand(0); + auto GridDim1 = CI->getArgOperand(1); + auto GridDim2 = CI->getArgOperand(2); + auto GridDimX = Builder.CreateTrunc(GridDim1, Builder.getInt32Ty()); + auto GridDimY = Builder.CreateLShr( + GridDim1, ConstantInt::get(Builder.getInt64Ty(), 32)); + GridDimY = Builder.CreateTrunc(GridDimY, Builder.getInt32Ty()); + auto GridDimZ = GridDim2; + auto BlockDim1 = CI->getArgOperand(3); + auto BlockDim2 = CI->getArgOperand(4); + auto BlockDimX = Builder.CreateTrunc(BlockDim1, Builder.getInt32Ty()); + auto BlockDimY = Builder.CreateLShr( + BlockDim1, ConstantInt::get(Builder.getInt64Ty(), 32)); + BlockDimY = Builder.CreateTrunc(BlockDimY, Builder.getInt32Ty()); + auto BlockDimZ = BlockDim2; + auto SharedMemSize = CI->getArgOperand(5); + auto StreamPtr = CI->getArgOperand(6); + SmallVector Args = { + FuncPtr, GridDimX, GridDimY, GridDimZ, BlockDimX, + BlockDimY, BlockDimZ, SharedMemSize, StreamPtr, + }; + auto StubFunc = cast(CI->getArgOperand(0)); + for (unsigned I = 7; I < 
CI->getNumOperands() - 1; I++) + Args.push_back(CI->getArgOperand(I)); + SmallVector ArgTypes; + for (Value *V : Args) + ArgTypes.push_back(V->getType()); + auto MlirLaunchFunc = Function::Create( + FunctionType::get(Type::getVoidTy(M.getContext()), ArgTypes, + /*isVarAtg=*/false), + llvm::GlobalValue::ExternalLinkage, kernelPrefix + StubFunc->getName(), + M); + + Builder.CreateCall(MlirLaunchFunc, Args); + CI->eraseFromParent(); + } +} class ReactantBase { public: - ReactantBase(bool PostOpt) { + std::vector gpubins; + ReactantBase(const std::vector &gpubins) : gpubins(gpubins) { } bool run(Module &M) { bool changed = true; - for (Function &F : make_early_inc_range(M)) { - if (!F.empty()) continue; - if (F.getName() == "cudaMalloc") { - auto entry = BasicBlock::Create(F.getContext(), "entry", &F); - IRBuilder<> B(entry); + llvm::errs() << " pre: " << M << "\n"; + fixup(M); + llvm::errs() << "M: " << M << "\n"; + + for (auto bin : gpubins) { + llvm::errs() << " gpubin: " << bin << "\n"; + + SMDiagnostic Err; + auto mod2 = llvm::parseIRFile(bin + ".re_export", Err, M.getContext()); + if (!mod2) { + Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); + exit(1); + } + + for (std::string T : {"", "f"}) { + for (std::string name : + {"sin", "cos", "tan", "log2", "exp", "exp2", + "exp10", "cosh", "sinh", "tanh", "atan2", "atan", + "asin", "acos", "log", "log10", "log1p", "acosh", + "asinh", "atanh", "expm1", "hypot", "rhypot", "norm3d", + "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", "cbrt", + "rcbrt", "j0", "j1", "y0", "y1", "yn", + "jn", "erf", "erfinv", "erfc", "erfcx", "erfcinv", + "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", + "modf", "fmod", "remainder", "remquo", "powi", "tgamma", + "round", "fdim", "ilogb", "logb", "isinf", "pow", + "sqrt", "finite", "fabs", "fmax"}) { + std::string nvname = "__nv_" + name; + std::string llname = "llvm." 
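// Editorial note, not part of the patch: for every libm-style name in the
// list above, the loop forms the f64 variant and the "f"-suffixed f32
// variant, the libdevice symbol "__nv_<name>[f]", and the intrinsic name
// "llvm.<name>.f64" / "llvm.<name>.f32"; any such intrinsic that ended up
// with a body in the imported GPU module is stripped back to a declaration
// via deleteBody, presumably so it remains a plain intrinsic declaration
// after linking into the host module.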
+ name + "."; + std::string mathname = name; + + if (T == "f") { + mathname += "f"; + nvname += "f"; + llname += "f32"; + } else { + llname += "f64"; + } - auto entry = new BasicBlock() - F.ad + if (auto F = mod2->getFunction(llname)) { + F->deleteBody(); } } + } + { - return changed; - } -}; -class ReactantOldPM : public ReactantBase, public ModulePass { -public: - static char ID; - EnzymeOldPM(bool PostOpt = false) : ReactantBase(PostOpt), ModulePass(ID) {} + PassBuilder PB; + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + PB.registerModuleAnalyses(MAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.registerCGSCCAnalyses(CGAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); + GlobalOptPass().run(*mod2, MAM); + } + for (auto &F : *mod2) { + if (!F.empty()) + F.setLinkage(Function::LinkageTypes::InternalLinkage); + } + if (auto RF = M.getFunction("__cudaRegisterFunction")) { + for (auto U : make_early_inc_range(RF->users())) { + if (auto CI = dyn_cast(U)) { + if (CI->getCalledFunction() != RF) continue; + + Value *F2 = CI->getArgOperand(1); + Value *name = CI->getArgOperand(2); + while (auto CE = dyn_cast(F2)) { + F2 = CE->getOperand(0); + } + while (auto CE = dyn_cast(name)) { + name = CE->getOperand(0); + } + StringRef nameVal; + if (auto GV = dyn_cast(name)) + if (GV->isConstant()) + if (auto C = GV->getInitializer()) + if (auto CA = dyn_cast(C)) + if (CA->getType()->getElementType()->isIntegerTy(8) && + CA->isCString()) + nameVal = CA->getAsCString(); + auto F22 = dyn_cast(F2); + if (!F22) continue; + + if (nameVal.size()) + if (auto MF = mod2->getFunction(nameVal)) { + MF->setName(F22->getName()); + F22->deleteBody(); + MF->setCallingConv(llvm::CallingConv::C); + MF->setLinkage(Function::LinkageTypes::LinkOnceODRLinkage); + } + CI->eraseFromParent(); + } + } + } - // AU.addRequiredID(LCSSAID); + llvm::errs() << " mod2: " << *mod2 << "\n"; + + auto handler = M.getContext().getDiagnosticHandler(); + Linker L(M); + L.linkInModule(std::move(mod2)); + M.getContext().setDiagnosticHandler(std::move(handler)); + } - // LoopInfo is required to ensure that all loops have preheaders - // AU.addRequired(); + for (Function &F : make_early_inc_range(M)) { + if (!F.empty()) continue; + if (F.getName() == "cudaMalloc") { + continue; + auto entry = BasicBlock::Create(F.getContext(), "entry", &F); + IRBuilder<> B(entry); + } + } - // AU.addRequiredID(llvm::LoopSimplifyID);//(); + fixup(M); + for (auto todel : {"__cuda_register_globals", "__cuda_module_ctor", "__cuda_module_dtor"}) { + if (auto F = M.getFunction(todel)) { + F->replaceAllUsesWith(Constant::getNullValue(F->getType())); + F->eraseFromParent(); + } + } + { + PassBuilder PB; + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + PB.registerModuleAnalyses(MAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.registerCGSCCAnalyses(CGAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + GlobalOptPass().run(M, MAM); + } + llvm::errs() << "M: " << M << "\n"; + return changed; } - bool runOnModule(Module &M) override { return run(M); } }; -} // namespace - -char ReactantOldPM::ID = 0; - -static RegisterPass X("enzyme", "Enzyme Pass"); - -ModulePass *createReactantPass(bool PostOpt) { return new EnzymeOldPM(PostOpt); } +} #include #include -#include "llvm/IR/LegacyPassManager.h" - 
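// Editorial sketch, not part of the patch: the import-and-link step performed
// by ReactantBase::run above, reduced to its core. The function name
// "linkExportedModule" and the parameter "Path" are illustrative; the real
// code additionally renames device kernels to their host stub names using the
// __cudaRegisterFunction calls and runs GlobalOpt, which is omitted here.
#include <memory>
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

static bool linkExportedModule(llvm::Module &Host, llvm::StringRef Path) {
  llvm::SMDiagnostic Err;
  std::unique_ptr<llvm::Module> Dev =
      llvm::parseIRFile(Path, Err, Host.getContext());
  if (!Dev) {
    Err.print("reactant", llvm::errs());
    return false;
  }
  // Internalize definitions so a later GlobalOpt run can drop anything the
  // host module does not end up referencing.
  for (llvm::Function &F : *Dev)
    if (!F.isDeclaration())
      F.setLinkage(llvm::GlobalValue::InternalLinkage);
  llvm::Linker TheLinker(Host);
  // linkInModule returns true on error.
  return !TheLinker.linkInModule(std::move(Dev));
}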
-extern "C" void AddReactantPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createReactantPass(/*PostOpt*/ false)); -} - #include "llvm/Passes/PassPlugin.h" class ReactantNewPM final : public ReactantBase, @@ -149,9 +447,12 @@ class ReactantNewPM final : public ReactantBase, public: using Result = llvm::PreservedAnalyses; - ReactantNewPM(bool PostOpt = false) : ReactantBase(PostOpt) {} + ReactantNewPM(const std::vector &gpubins) : ReactantBase(gpubins) { + llvm::errs() << " constructing new pm\n"; +} Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { +llvm::errs() << " running on module: " << M << "\n"; return ReactantBase::run(M) ? PreservedAnalyses::none() : PreservedAnalyses::all(); } @@ -159,66 +460,57 @@ class ReactantNewPM final : public ReactantBase, static bool isRequired() { return true; } }; +class ExporterNewPM final : public AnalysisInfoMixin { + friend struct llvm::AnalysisInfoMixin; + +private: + static llvm::AnalysisKey Key; + +public: + using Result = llvm::PreservedAnalyses; + std::string firstfile; + ExporterNewPM(std::string file) : firstfile(file) {} + + Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { + std::string filename = firstfile + ".re_export"; + + std::error_code EC; + llvm::raw_fd_ostream file(filename, EC);//, llvm::sys::fs::OF_Text); + + if (EC) { + llvm::errs() << "Error opening file: " << EC.message() << "\n"; + exit(1); + } + + file << M; + llvm::errs() << " exported to: " << filename << "\n"; + return PreservedAnalyses::all(); + } + + static bool isRequired() { return true; } +}; + #undef DEBUG_TYPE AnalysisKey ReactantNewPM::Key; +AnalysisKey ExporterNewPM::Key; -#include "ActivityAnalysisPrinter.h" -#include "JLInstSimplify.h" -#include "PreserveNVVM.h" -#include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" -#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" -#include "llvm/Transforms/IPO/AlwaysInliner.h" -#include "llvm/Transforms/IPO/CalledValuePropagation.h" -#include "llvm/Transforms/IPO/ConstantMerge.h" -#include "llvm/Transforms/IPO/CrossDSOCFI.h" -#include "llvm/Transforms/IPO/DeadArgumentElimination.h" -#include "llvm/Transforms/IPO/FunctionAttrs.h" -#include "llvm/Transforms/IPO/GlobalDCE.h" -#include "llvm/Transforms/IPO/GlobalOpt.h" -#include "llvm/Transforms/IPO/GlobalSplit.h" -#include "llvm/Transforms/IPO/InferFunctionAttrs.h" -#include "llvm/Transforms/IPO/SCCP.h" -#include "llvm/Transforms/InstCombine/InstCombine.h" -#include "llvm/Transforms/Scalar/CallSiteSplitting.h" -#include "llvm/Transforms/Scalar/EarlyCSE.h" -#include "llvm/Transforms/Scalar/Float2Int.h" -#include "llvm/Transforms/Scalar/GVN.h" -#include "llvm/Transforms/Scalar/LoopDeletion.h" -#include "llvm/Transforms/Scalar/LoopRotation.h" -#include "llvm/Transforms/Scalar/LoopUnrollPass.h" -#include "llvm/Transforms/Scalar/SROA.h" -// #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h" -#include "llvm/Transforms/IPO/ArgumentPromotion.h" -#include "llvm/Transforms/Scalar/ConstraintElimination.h" -#include "llvm/Transforms/Scalar/DeadStoreElimination.h" -#include "llvm/Transforms/Scalar/JumpThreading.h" -#include "llvm/Transforms/Scalar/MemCpyOptimizer.h" -#include "llvm/Transforms/Scalar/NewGVN.h" -#include "llvm/Transforms/Scalar/TailRecursionElimination.h" -#if LLVM_VERSION_MAJOR >= 17 -#include "llvm/Transforms/Utils/MoveAutoInit.h" -#endif -#include "llvm/Transforms/Scalar/IndVarSimplify.h" -#include "llvm/Transforms/Scalar/LICM.h" -#include "llvm/Transforms/Scalar/LoopFlatten.h" 
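// Editorial sketch, not part of the patch: a reduced stand-in for the
// ExporterNewPM pass above and the registerExporter/registerReactant hooks
// below. It serializes the module as textual IR to "<prefix>.re_export" (the
// suffix the importer looks for) and registers itself at the optimizer-early
// extension point; the version guard mirrors the LLVM 20 callback-signature
// change used in this patch. All names here are illustrative.
#include <string>
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/raw_ostream.h"

namespace {
struct ExportIRPass : llvm::PassInfoMixin<ExportIRPass> {
  std::string Prefix;
  explicit ExportIRPass(std::string P) : Prefix(std::move(P)) {}

  llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) {
    std::error_code EC;
    llvm::raw_fd_ostream OS(Prefix + ".re_export", EC);
    if (EC) {
      llvm::errs() << "error opening export file: " << EC.message() << "\n";
      return llvm::PreservedAnalyses::all();
    }
    OS << M; // textual IR, which parseIRFile can read back in
    return llvm::PreservedAnalyses::all();
  }
};
} // namespace

static void registerExportPass(llvm::PassBuilder &PB, std::string Prefix) {
#if LLVM_VERSION_MAJOR >= 20
  PB.registerOptimizerEarlyEPCallback(
      [Prefix](llvm::ModulePassManager &MPM, llvm::OptimizationLevel,
               llvm::ThinOrFullLTOPhase) {
        MPM.addPass(ExportIRPass(Prefix));
      });
#else
  PB.registerOptimizerEarlyEPCallback(
      [Prefix](llvm::ModulePassManager &MPM, llvm::OptimizationLevel) {
        MPM.addPass(ExportIRPass(Prefix));
      });
#endif
}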
-#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h" -void augmentPassBuilder(llvm::PassBuilder &PB) { +extern "C" void registerExporter(llvm::PassBuilder &PB, std::string file) { #if LLVM_VERSION_MAJOR >= 20 - auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level, + auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level, ThinOrFullLTOPhase) #else - auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level) + auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level) #endif { - MPM.addPass(ReactantNewPM()); + MPM.addPass(ExporterNewPM(file)); }; // TODO need for perf reasons to move Enzyme pass to the pre vectorization. PB.registerOptimizerEarlyEPCallback(loadPass); - auto loadLTO = [preLTOPass, loadPass](ModulePassManager &MPM, + auto loadLTO = [loadPass](ModulePassManager &MPM, OptimizationLevel Level) { #if LLVM_VERSION_MAJOR >= 20 loadPass(MPM, Level, ThinOrFullLTOPhase::None); @@ -233,21 +525,29 @@ extern "C" void registerReactantAndPassPipeline(llvm::PassBuilder &PB, bool augment = false) { } -extern "C" void registerReactant(llvm::PassBuilder &PB) { +extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector gpubinaries) { - PB.registerPipelineParsingCallback( - [](llvm::StringRef Name, llvm::ModulePassManager &MPM, - llvm::ArrayRef) { - if (Name == "reactant") { - MPM.addPass(ReactantNewPM()); - return true; - } - return false; - }); - registerReactantAndPassPipeline(PB, /*augment*/ false); -} + llvm::errs() << " registering reactant\n"; +#if LLVM_VERSION_MAJOR >= 20 + auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level, + ThinOrFullLTOPhase) +#else + auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level) +#endif + { + MPM.addPass(ReactantNewPM(gpubinaries)); + }; + + // TODO need for perf reasons to move Enzyme pass to the pre vectorization. + PB.registerOptimizerEarlyEPCallback(loadPass); -extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK -llvmGetPassPluginInfo() { - return {LLVM_PLUGIN_API_VERSION, "ReactantNewPM", "v0.1", registerReactant}; + auto loadLTO = [loadPass](ModulePassManager &MPM, + OptimizationLevel Level) { +#if LLVM_VERSION_MAJOR >= 20 + loadPass(MPM, Level, ThinOrFullLTOPhase::None); +#else + loadPass(MPM, Level); +#endif + }; + PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp deleted file mode 100644 index 48db98150244..000000000000 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ /dev/null @@ -1,6610 +0,0 @@ -//===- EnzymeLogic.cpp - Implementation of forward and reverse pass generation// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file defines two functions CreatePrimalAndGradient and -// CreateAugmentedPrimal. 
CreatePrimalAndGradient takes a function, known -// TypeResults of the calling context, known activity analysis of the -// arguments. It creates a corresponding gradient -// function, computing the primal as well if requested. -// CreateAugmentedPrimal takes similar arguments and creates an augmented -// primal pass. -// -//===----------------------------------------------------------------------===// -#include "EnzymeLogic.h" -#include "ActivityAnalysis.h" -#include "AdjointGenerator.h" -#include "EnzymeLogic.h" -#include "TypeAnalysis/TypeAnalysis.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/GlobalValue.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/Support/ErrorHandling.h" -#include - -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/Analysis/DependenceAnalysis.h" -#include - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Verifier.h" - -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h" - -#include "llvm/Demangle/Demangle.h" - -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" - -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" - -#include "llvm/Support/AMDGPUMetadata.h" -#include "llvm/Support/TimeProfiler.h" - -#include "llvm/ADT/StringSet.h" - -#include "DiffeGradientUtils.h" -#include "FunctionUtils.h" -#include "GradientUtils.h" -#include "InstructionBatcher.h" -#include "LibraryFuncs.h" -#include "TraceGenerator.h" -#include "Utils.h" - -#define addAttribute addAttributeAtIndex -#define getAttribute getAttributeAtIndex -#define removeAttribute removeAttributeAtIndex - -using namespace llvm; - -extern "C" { -llvm::cl::opt - EnzymePrint("enzyme-print", cl::init(false), cl::Hidden, - cl::desc("Print before and after fns for autodiff")); - -llvm::cl::opt - EnzymePrintUnnecessary("enzyme-print-unnecessary", cl::init(false), - cl::Hidden, - cl::desc("Print unnecessary values in function")); - -cl::opt looseTypeAnalysis("enzyme-loose-types", cl::init(false), - cl::Hidden, - cl::desc("Allow looser use of types")); - -cl::opt nonmarkedglobals_inactiveloads( - "enzyme_nonmarkedglobals_inactiveloads", cl::init(true), cl::Hidden, - cl::desc("Consider loads of nonmarked globals to be inactive")); - -cl::opt EnzymeJuliaAddrLoad( - "enzyme-julia-addr-load", cl::init(false), cl::Hidden, - cl::desc("Mark all loads resulting in an addr(13)* to be legal to redo")); - -cl::opt EnzymeAssumeUnknownNoFree( - "enzyme-assume-unknown-nofree", cl::init(false), cl::Hidden, - cl::desc("Assume unknown instructions are nofree as needed")); - -LLVMValueRef (*EnzymeFixupReturn)(LLVMBuilderRef, LLVMValueRef) = nullptr; -} - -struct CacheAnalysis { - - const ValueMap> - &allocationsWithGuaranteedFree; - const ValueMap - &rematerializableAllocations; - TypeResults &TR; - AAResults &AA; - Function *oldFunc; - ScalarEvolution &SE; - LoopInfo &OrigLI; - 
DominatorTree &OrigDT; - TargetLibraryInfo &TLI; - const SmallPtrSetImpl &unnecessaryBlocks; - const bool subsequent_calls_may_write; - const std::vector &overwritten_args; - DerivativeMode mode; - std::map seen; - bool omp; - CacheAnalysis( - const ValueMap> - &allocationsWithGuaranteedFree, - const ValueMap - &rematerializableAllocations, - TypeResults &TR, AAResults &AA, Function *oldFunc, ScalarEvolution &SE, - LoopInfo &OrigLI, DominatorTree &OrigDT, TargetLibraryInfo &TLI, - const SmallPtrSetImpl &unnecessaryBlocks, - bool subsequent_calls_may_write, - const std::vector &overwritten_args, DerivativeMode mode, bool omp) - : allocationsWithGuaranteedFree(allocationsWithGuaranteedFree), - rematerializableAllocations(rematerializableAllocations), TR(TR), - AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), OrigDT(OrigDT), - TLI(TLI), unnecessaryBlocks(unnecessaryBlocks), - subsequent_calls_may_write(subsequent_calls_may_write), - overwritten_args(overwritten_args), mode(mode), omp(omp) {} - - bool is_value_mustcache_from_origin(Value *obj) { - if (seen.find(obj) != seen.end()) - return seen[obj]; - - bool mustcache = false; - - // If the pointer operand is from an argument to the function, we need to - // check if the argument - // received from the caller is uncacheable. - if (rematerializableAllocations.count(obj)) { - return false; - } else if (isa(obj) || isa(obj)) { - return false; - } else if (auto arg = dyn_cast(obj)) { - if (arg->getArgNo() >= overwritten_args.size()) { - llvm::errs() << "overwritten_args:\n"; - for (auto pair : overwritten_args) { - llvm::errs() << " + " << pair << "\n"; - } - llvm::errs() << "could not find " << *arg << " of func " - << arg->getParent()->getName() << " in args_map\n"; - llvm_unreachable("could not find arg in args_map"); - } - if (overwritten_args[arg->getArgNo()]) { - mustcache = true; - // EmitWarning("UncacheableOrigin", *arg, - // "origin arg may need caching ", *arg); - } - } else if (auto pn = dyn_cast(obj)) { - seen[pn] = false; - for (auto &val : pn->incoming_values()) { - if (is_value_mustcache_from_origin(val)) { - mustcache = true; - EmitWarning("UncacheableOrigin", *pn, "origin pn may need caching ", - *pn); - break; - } - } - } else if (auto ci = dyn_cast(obj)) { - mustcache = is_value_mustcache_from_origin(ci->getOperand(0)); - if (mustcache) { - EmitWarning("UncacheableOrigin", *ci, "origin ci may need caching ", - *ci); - } - } else if (auto gep = dyn_cast(obj)) { - mustcache = is_value_mustcache_from_origin(gep->getPointerOperand()); - if (mustcache) { - EmitWarning("UncacheableOrigin", *gep, "origin gep may need caching ", - *gep); - } - } else if (auto II = dyn_cast(obj); - II && isIntelSubscriptIntrinsic(*II)) { - mustcache = is_value_mustcache_from_origin(II->getOperand(3)); - if (mustcache) { - EmitWarning("UncacheableOrigin", *II, - "origin llvm.intel.subscript may need caching ", *II); - } - } else { - - // Pointer operands originating from call instructions that are not - // malloc/free are conservatively considered uncacheable. - if (auto obj_op = dyn_cast(obj)) { - auto n = getFuncNameFromCall(obj_op); - // If this is a known allocation which is not captured or returned, - // a caller function cannot overwrite this (since it cannot access). - // Since we don't currently perform this check, we can instead check - // if the pointer has a guaranteed free (which is a weaker form of - // the required property). 
- if (allocationsWithGuaranteedFree.find(obj_op) != - allocationsWithGuaranteedFree.end()) { - - } else if (n == "julia.get_pgcstack" || n == "julia.ptls_states" || - n == "jl_get_ptls_states") { - - } else { - // OP is a non malloc/free call so we need to cache - mustcache = true; - EmitWarning("UncacheableOrigin", *obj_op, - "origin call may need caching ", *obj_op); - } - } else if (isa(obj)) { - // No change to modref if alloca since the memory only exists in - // this function. - } else if (auto GV = dyn_cast(obj)) { - // In the absense of more fine-grained global info, assume object is - // written to in a subseqent call unless this is known to be constant - if (!GV->isConstant()) { - mustcache = true; - } - } else { - // In absence of more information, assume that the underlying object for - // pointer operand is uncacheable in caller. - mustcache = true; - if (auto I = dyn_cast(obj)) - EmitWarning("UncacheableOrigin", *I, - "unknown origin may need caching ", *obj); - } - } - - return seen[obj] = mustcache; - } - - bool is_load_uncacheable(Instruction &li) { - assert(li.getParent()->getParent() == oldFunc); - - auto Arch = llvm::Triple(oldFunc->getParent()->getTargetTriple()).getArch(); - if (Arch == Triple::amdgcn && - cast(li.getOperand(0)->getType())->getAddressSpace() == - 4) { - return false; - } - - if (hasNoCache(&li)) - return false; - - if (EnzymeJuliaAddrLoad) - if (auto PT = dyn_cast(li.getType())) - if (PT->getAddressSpace() == 13) - return false; - - // Only use invariant load data if either, we are not using Julia - // or we can guarantee that no following instruction will write to memory. - // The reason for this is that Julia - // incorrectly has invariant load info for a function, which specifies - // the load value won't change over the course of a function, but - // may change from a caller. - bool checkFunction = true; - if (li.hasMetadata(LLVMContext::MD_invariant_load)) { - if (!EnzymeJuliaAddrLoad || !subsequent_calls_may_write) - return false; - else - checkFunction = false; - } - - // Find the underlying object for the pointer operand of the load - // instruction. - auto obj = getBaseObject(li.getOperand(0)); - - if (auto obj_op = dyn_cast(obj)) { - auto n = getFuncNameFromCall(obj_op); - if (n == "julia.get_pgcstack" || n == "julia.ptls_states" || - n == "jl_get_ptls_states") - return false; - } - if (auto objli = dyn_cast(obj)) { - auto obj2 = getBaseObject(objli->getOperand(0)); - if (auto obj_op = dyn_cast(obj2)) { - auto n = getFuncNameFromCall(obj_op); - if (n == "julia.get_pgcstack" || n == "julia.ptls_states" || - n == "jl_get_ptls_states") - return false; - } - } - - // Openmp bound and local thread id are unchanging - // definitionally cacheable. - if (omp) - if (auto arg = dyn_cast(obj)) { - if (arg->getArgNo() < 2) { - return false; - } - } - - // Any load from a rematerializable allocation is definitionally - // reloadable. Notably we don't need to perform the allFollowers - // of check as the loop scope caching should allow us to ignore - // such stores. - if (rematerializableAllocations.count(obj)) - return false; - - // If not running combined, check if pointer operand is overwritten - // by a subsequent call (i.e. not this function). 
- bool can_modref = false; - if (subsequent_calls_may_write) - can_modref = is_value_mustcache_from_origin(obj); - - if (!can_modref && checkFunction) { - allFollowersOf(&li, [&](Instruction *inst2) { - if (!inst2->mayWriteToMemory()) - return false; - - if (isa(inst2)) - return false; - - if (unnecessaryBlocks.count(inst2->getParent())) { - return false; - } - if (auto CI = dyn_cast(inst2)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "__kmpc_for_static_fini") { - return false; - } - } - } - - if (!overwritesToMemoryReadBy(&TR, AA, TLI, SE, OrigLI, OrigDT, &li, - inst2)) { - return false; - } - - if (auto II = dyn_cast(inst2)) { - if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0 || - II->getIntrinsicID() == Intrinsic::amdgcn_s_barrier) { - allUnsyncdPredecessorsOf( - II, - [&](Instruction *mid) { - if (!mid->mayWriteToMemory()) - return false; - - if (isa(mid)) - return false; - - if (unnecessaryBlocks.count(mid->getParent())) { - return false; - } - - if (!writesToMemoryReadBy(&TR, AA, TLI, &li, mid)) { - return false; - } - - can_modref = true; - EmitWarning("Uncacheable", li, "Load may need caching ", li, - " due to ", *mid, " via ", *II); - return true; - }, - [&]() { - // if gone past entry - if (mode != DerivativeMode::ReverseModeCombined) { - EmitWarning("Uncacheable", li, "Load may need caching ", li, - " due to entry via ", *II); - can_modref = true; - } - }); - if (can_modref) - return true; - else - return false; - } - } - can_modref = true; - EmitWarning("Uncacheable", li, "Load may need caching ", li, " due to ", - *inst2); - // Early exit - return true; - }); - } else { - - EmitWarning("Uncacheable", li, "Load may need caching ", li, - " due to origin ", *obj); - } - - return can_modref; - } - - // Computes a map of LoadInst -> boolean for a function indicating whether - // that load is "uncacheable". - // A load is considered "uncacheable" if the data at the loaded memory - // location can be modified after the load instruction. - std::map compute_uncacheable_load_map() { - std::map can_modref_map; - for (auto &B : *oldFunc) { - if (unnecessaryBlocks.count(&B)) - continue; - for (auto &inst : B) { - // For each load instruction, determine if it is uncacheable. 
- if (isa(&inst)) { - can_modref_map[&inst] = is_load_uncacheable(inst); - } else if (isNVLoad(&inst)) { - can_modref_map[&inst] = false; - } else if (auto II = dyn_cast(&inst)) { - switch (II->getIntrinsicID()) { - case Intrinsic::masked_load: - can_modref_map[&inst] = is_load_uncacheable(inst); - break; - default: - break; - } - } - } - } - return can_modref_map; - } - - std::pair> - compute_overwritten_args_for_one_callsite(CallInst *callsite_op) { - auto Fn = getFunctionFromCall(callsite_op); - if (!Fn) - return {}; - - StringRef funcName = getFuncNameFromCall(callsite_op); - - if (funcName == "llvm.julia.gc_preserve_begin") - return {}; - - if (funcName == "llvm.julia.gc_preserve_end") - return {}; - - if (funcName == "julia.pointer_from_objref") - return {}; - - if (funcName == "julia.gc_loaded") - return {}; - - if (funcName == "julia.write_barrier") - return {}; - - if (funcName == "julia.write_barrier_binding") - return {}; - - if (funcName == "julia.safepoint") - return {}; - - if (funcName == "enzyme_zerotype") - return {}; - - if (isMemFreeLibMFunction(funcName)) { - return {}; - } - - if (isDebugFunction(callsite_op->getCalledFunction())) - return {}; - - if (isCertainPrint(funcName) || isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return {}; - } - - if (startsWith(funcName, "MPI_") || - startsWith(funcName, "enzyme_wrapmpi$$")) - return {}; - - if (funcName == "__kmpc_for_static_init_4" || - funcName == "__kmpc_for_static_init_4u" || - funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") { - return {}; - } - - SmallVector args; - SmallVector objs; - SmallVector args_safe; - - // First, we need to propagate the uncacheable status from the parent - // function to the callee. - // because memory location x modified after parent returns => x modified - // after callee returns. - for (unsigned i = 0; i < callsite_op->arg_size(); ++i) { - args.push_back(callsite_op->getArgOperand(i)); - - // If the UnderlyingObject is from one of this function's arguments, then - // we need to propagate the volatility. - Value *obj = getBaseObject(callsite_op->getArgOperand(i)); - - objs.push_back(obj); - - bool init_safe = !is_value_mustcache_from_origin(obj); - if (!init_safe) { - auto CD = TR.query(obj)[{-1}]; - if (CD == BaseType::Integer || CD.isFloat()) - init_safe = true; - } - if (!init_safe && !isa(obj) && !isa(obj) && - !isa(obj)) { - EmitWarning("UncacheableOrigin", *callsite_op, "Callsite ", - *callsite_op, " arg ", i, " ", - *callsite_op->getArgOperand(i), " uncacheable from origin ", - *obj); - } - args_safe.push_back(init_safe); - } - - bool next_subsequent_inst_may_write = subsequent_calls_may_write; - - // Second, we check for memory modifications that can occur in the - // continuation of the - // callee inside the parent function. 
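// Editorial sketch, not part of the patch: the alias-analysis query pattern
// the (deleted) cache analysis below relies on -- "may this later instruction
// modify the memory passed as argument i of this call site?". In the original
// code this check runs inside a walk over every instruction that can execute
// after the call (allFollowersOf); the helper name "mayClobberCallArg" is
// illustrative.
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/InstrTypes.h"

static bool mayClobberCallArg(llvm::AAResults &AA, llvm::Instruction *Writer,
                              llvm::CallBase *Call, unsigned ArgIdx,
                              const llvm::TargetLibraryInfo &TLI) {
  llvm::MemoryLocation Loc =
      llvm::MemoryLocation::getForArgument(Call, ArgIdx, TLI);
  // isModSet is true if Writer may write to the queried location.
  return llvm::isModSet(AA.getModRefInfo(Writer, Loc));
}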
- allFollowersOf(callsite_op, [&](Instruction *inst2) { - // Don't consider modref from malloc/free as a need to cache - if (auto obj_op = dyn_cast(inst2)) { - StringRef sfuncName = getFuncNameFromCall(obj_op); - - if (isMemFreeLibMFunction(sfuncName)) { - return false; - } - - if (isDebugFunction(obj_op->getCalledFunction())) - return false; - - if (isCertainPrint(sfuncName) || isAllocationFunction(sfuncName, TLI) || - isDeallocationFunction(sfuncName, TLI)) { - return false; - } - - if (sfuncName == "__kmpc_for_static_fini") { - return false; - } - - if (auto iasm = dyn_cast(obj_op->getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("exit")) - return false; - } - } - - if (unnecessaryBlocks.count(inst2->getParent())) - return false; - - if (!inst2->mayWriteToMemory()) - return false; - - next_subsequent_inst_may_write = true; - for (unsigned i = 0; i < args.size(); ++i) { - if (!args_safe[i]) - continue; - - // Any use of an arg from a rematerializable allocation - // is definitionally reloadable in sub. - if (rematerializableAllocations.count(objs[i])) - continue; - - auto CD = TR.query(args[i])[{-1}]; - if (CD == BaseType::Integer || CD.isFloat()) - continue; - - if (llvm::isModSet(AA.getModRefInfo( - inst2, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { - if (!isa(callsite_op->getArgOperand(i)) && - !isa(callsite_op->getArgOperand(i))) - EmitWarning("UncacheableArg", *callsite_op, "Callsite ", - *callsite_op, " arg ", i, " ", - *callsite_op->getArgOperand(i), " uncacheable due to ", - *inst2); - args_safe[i] = false; - } - } - return false; - }); - - std::vector overwritten_args; - - if (funcName == "__kmpc_fork_call") { - Value *op = callsite_op->getArgOperand(2); - Function *task = nullptr; - while (!(task = dyn_cast(op))) { - if (auto castinst = dyn_cast(op)) - if (castinst->isCast()) { - op = castinst->getOperand(0); - continue; - } - if (auto CI = dyn_cast(op)) { - op = CI->getOperand(0); - continue; - } - llvm::errs() << "op: " << *op << "\n"; - assert(0 && "unknown fork call arg"); - } - - // Global.tid is cacheable - overwritten_args.push_back(false); - - // Bound.tid is cacheable - overwritten_args.push_back(false); - - // Ignore first three arguments of fork call - for (unsigned i = 3; i < args.size(); ++i) { - overwritten_args.push_back(!args_safe[i]); - } - } else { - for (unsigned i = 0; i < args.size(); ++i) { - overwritten_args.push_back(!args_safe[i]); - } - } - - return std::make_pair(next_subsequent_inst_may_write, overwritten_args); - } - - // Given a function and the arguments passed to it by its caller that are - // uncacheable (_overwritten_args) compute - // the set of uncacheable arguments for each callsite inside the function. A - // pointer argument is uncacheable at a callsite if the memory pointed to - // might be modified after that callsite. - std::map>> - compute_overwritten_args_for_callsites() { - std::map>> - overwritten_args_map; - - for (auto &B : *oldFunc) { - if (unnecessaryBlocks.count(&B)) - continue; - for (auto &inst : B) { - if (auto op = dyn_cast(&inst)) { - - // We do not need uncacheable args for intrinsic functions. So skip - // such callsites. - if (auto II = dyn_cast(&inst)) { - if (!startsWith(II->getCalledFunction()->getName(), "llvm.julia")) - continue; - } - - // For all other calls, we compute the uncacheable args for this - // callsite. 
- overwritten_args_map.insert( - std::pair>>( - op, compute_overwritten_args_for_one_callsite(op))); - } - } - } - return overwritten_args_map; - } -}; - -void calculateUnusedValuesInFunction( - Function &func, llvm::SmallPtrSetImpl &unnecessaryValues, - llvm::SmallPtrSetImpl &unnecessaryInstructions, - bool returnValue, DerivativeMode mode, GradientUtils *gutils, - TargetLibraryInfo &TLI, ArrayRef constant_args, - const llvm::SmallPtrSetImpl &oldUnreachable) { - std::map CacheResults; - for (auto pair : gutils->knownRecomputeHeuristic) { - if (!pair.second || - gutils->unnecessaryIntermediates.count(cast(pair.first))) { - CacheResults[UsageKey(pair.first, QueryType::Primal)] = false; - } - } - std::map PrimalSeen; - if (mode == DerivativeMode::ReverseModeGradient) { - PrimalSeen = CacheResults; - } - - for (const auto &pair : gutils->allocationsWithGuaranteedFree) { - if (pair.second.size() == 0) - continue; - - bool primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, pair.first, mode, CacheResults, oldUnreachable); - - // If rematerializing a split or loop-level allocation, the primal - // allocation is not needed in the reverse. - if (gutils->needsCacheWholeAllocation(pair.first)) { - primalNeededInReverse = true; - } else if (primalNeededInReverse) { - auto found = gutils->rematerializableAllocations.find( - const_cast(pair.first)); - if (found != gutils->rematerializableAllocations.end()) { - if (mode != DerivativeMode::ReverseModeCombined) - primalNeededInReverse = false; - else if (auto inst = dyn_cast(pair.first)) - if (found->second.LI && - found->second.LI->contains(inst->getParent())) { - primalNeededInReverse = false; - } - } - } - - for (auto freeCall : pair.second) { - if (!primalNeededInReverse) - gutils->forwardDeallocations.insert(freeCall); - else - gutils->postDominatingFrees.insert(freeCall); - } - } - // Consider allocations which are being rematerialized, but do not - // have a guaranteed free. - for (const auto &rmat : gutils->rematerializableAllocations) { - if (isa(rmat.first) && - gutils->allocationsWithGuaranteedFree.count(cast(rmat.first))) - continue; - if (rmat.second.frees.size() == 0) - continue; - - bool primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, rmat.first, mode, CacheResults, oldUnreachable); - // If rematerializing a split or loop-level allocation, the primal - // allocation is not needed in the reverse. 
- if (gutils->needsCacheWholeAllocation(rmat.first)) { - primalNeededInReverse = true; - } else if (primalNeededInReverse) { - if (mode != DerivativeMode::ReverseModeCombined) - primalNeededInReverse = false; - else if (auto inst = dyn_cast(rmat.first)) - if (rmat.second.LI && rmat.second.LI->contains(inst->getParent())) { - primalNeededInReverse = false; - } - } - - for (auto freeCall : rmat.second.frees) { - if (!primalNeededInReverse) - gutils->forwardDeallocations.insert(cast(freeCall)); - } - } - - std::function isNoNeed = [&](const llvm::Value - *v) { - auto Obj = getBaseObject(v); - if (Obj != v) - return isNoNeed(Obj); - if (auto C = dyn_cast(v)) - return isNoNeed(C->getOperand(0)); - else if (auto arg = dyn_cast(v)) { - auto act = constant_args[arg->getArgNo()]; - if (act == DIFFE_TYPE::DUP_NONEED) { - return true; - } - } else if (isa(v) || isAllocationCall(v, TLI)) { - if (!gutils->isConstantValue(const_cast(v))) { - std::set done; - std::deque todo = {v}; - bool legal = true; - while (todo.size()) { - const Value *cur = todo.back(); - todo.pop_back(); - if (done.count(cur)) - continue; - done.insert(cur); - - if (unnecessaryValues.count(cur)) - continue; - - for (auto u : cur->users()) { - if (auto SI = dyn_cast(u)) { - if (SI->getValueOperand() != cur) { - continue; - } - } - if (auto I = dyn_cast(u)) { - if (unnecessaryInstructions.count(I)) { - if (!DifferentialUseAnalysis::is_use_directly_needed_in_reverse( - gutils, cur, mode, I, oldUnreachable, - QueryType::Primal)) { - continue; - } - } - if (isDeallocationCall(I, TLI)) { - continue; - } - } - if (auto II = dyn_cast(u); - II && isIntelSubscriptIntrinsic(*II)) { - todo.push_back(&*u); - continue; - } else if (auto CI = dyn_cast(u)) { - if (getFuncNameFromCall(CI) == "julia.write_barrier") { - continue; - } - if (getFuncNameFromCall(CI) == "julia.write_barrier_binding") { - continue; - } - bool writeOnlyNoCapture = true; - if (shouldDisableNoWrite(CI)) { - writeOnlyNoCapture = false; - } - for (size_t i = 0; i < CI->arg_size(); i++) { - if (cur == CI->getArgOperand(i)) { - if (!isNoCapture(CI, i)) { - writeOnlyNoCapture = false; - break; - } - if (!isWriteOnly(CI, i)) { - writeOnlyNoCapture = false; - break; - } - } - } - // Don't need the primal argument if it is write only and - // not captured - if (writeOnlyNoCapture) { - continue; - } - } - if (isa(u) || isa(u) || - isa(u)) { - todo.push_back(&*u); - continue; - } else { - legal = false; - break; - } - } - } - if (legal) { - return true; - } - } - } else if (auto II = dyn_cast(v); - II && isIntelSubscriptIntrinsic(*II)) { - unsigned int ptrArgIdx = 3; - return isNoNeed(II->getOperand(ptrArgIdx)); - } - return false; - }; - - calculateUnusedValues( - func, unnecessaryValues, unnecessaryInstructions, returnValue, - [&](const Value *val) { - bool ivn = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, val, mode, PrimalSeen, oldUnreachable); - return ivn; - }, - [&](const Instruction *inst) { - if (auto II = dyn_cast(inst)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end || - II->getIntrinsicID() == Intrinsic::stacksave || - II->getIntrinsicID() == Intrinsic::stackrestore) { - return UseReq::Cached; - } - } - - if (mode == DerivativeMode::ReverseModeGradient && - gutils->knownRecomputeHeuristic.find(inst) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[inst]) { - return UseReq::Cached; - } - } - - if (llvm::isa(inst) && returnValue) { - 
return UseReq::Need; - } - if (llvm::isa(inst) || - llvm::isa(inst)) { - size_t num = 0; - for (auto suc : successors(inst->getParent())) { - if (!oldUnreachable.count(suc)) { - num++; - } - } - if (num > 1 || mode != DerivativeMode::ReverseModeGradient) { - return UseReq::Need; - } - } - - // We still need this value if used as increment/induction variable for - // a loop - // TODO this really should be more simply replaced by doing the loop - // normalization on the original code as preprocessing - - // Below we specifically check if the instructions or any of its - // newly-generated (e.g. not in original function) uses are used in the - // loop calculation - auto NewI = gutils->getNewFromOriginal(inst); - std::set todo = {NewI}; - { - std::deque toAnalyze; - // Here we get the uses of inst from the original function - std::set UsesFromOrig; - for (auto u : inst->users()) { - if (auto I = dyn_cast(u)) { - UsesFromOrig.insert(gutils->getNewFromOriginal(I)); - } - } - // We only analyze uses that were not available in the original - // function - for (auto u : NewI->users()) { - if (auto I = dyn_cast(u)) { - if (UsesFromOrig.count(I) == 0) - toAnalyze.push_back(I); - } - } - - while (toAnalyze.size()) { - auto Next = toAnalyze.front(); - toAnalyze.pop_front(); - if (todo.count(Next)) - continue; - todo.insert(Next); - for (auto u : Next->users()) { - if (auto I = dyn_cast(u)) { - toAnalyze.push_back(I); - } - } - } - } - - for (auto I : todo) { - if (gutils->isInstructionUsedInLoopInduction(*I)) { - return UseReq::Need; - } - } - - bool mayWriteToMemory = inst->mayWriteToMemory(); - if (unnecessaryValues.count(inst) && isAllocationCall(inst, TLI)) - return UseReq::Recur; - - if (auto obj_op = dyn_cast(inst)) { - StringRef funcName = getFuncNameFromCall(obj_op); - if (isDeallocationFunction(funcName, TLI)) { - if (unnecessaryValues.count(obj_op->getArgOperand(0))) { - return UseReq::Recur; - } - - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError || - mode == DerivativeMode::ForwardModeSplit || - ((mode == DerivativeMode::ReverseModePrimal || - mode == DerivativeMode::ReverseModeCombined) && - gutils->forwardDeallocations.count(obj_op))) - return UseReq::Need; - return UseReq::Recur; - } - if (hasMetadata(obj_op, "enzyme_zerostack")) { - if (unnecessaryValues.count( - getBaseObject(obj_op->getArgOperand(0)))) { - return UseReq::Recur; - } - } - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (isMemFreeLibMFunction(funcName, &ID) || isReadOnly(obj_op)) { - mayWriteToMemory = false; - } - if (funcName == "memset" || funcName == "memset_pattern16" || - funcName == "memcpy" || funcName == "memmove") { - if (isNoNeed(obj_op->getArgOperand(0))) - return UseReq::Recur; - } - } - - if (auto si = dyn_cast(inst)) { - bool nnop = isNoNeed(si->getPointerOperand()); - if (isa(si->getValueOperand())) - return UseReq::Recur; - if (nnop) - return UseReq::Recur; - } - - if (auto msi = dyn_cast(inst)) { - if (isNoNeed(msi->getArgOperand(0))) - return UseReq::Recur; - } - - if (auto mti = dyn_cast(inst)) { - if (isNoNeed(mti->getArgOperand(0))) - return UseReq::Recur; - - auto at = getBaseObject(mti->getArgOperand(1)); - - bool newMemory = false; - if (isa(at)) - newMemory = true; - else if (isAllocationCall(at, TLI)) - newMemory = true; - if (newMemory) { - bool foundStore = false; - allInstructionsBetween( - *gutils->OrigLI, cast(at), - const_cast(mti), - [&](Instruction *I) -> bool { - if (!I->mayWriteToMemory()) - return /*earlyBreak*/ false; - if 
(unnecessaryInstructions.count(I)) - return /*earlyBreak*/ false; - if (auto CI = dyn_cast(I)) { - if (isReadOnly(CI)) - return /*earlyBreak*/ false; - } - - if (writesToMemoryReadBy( - &gutils->TR, *gutils->OrigAA, TLI, - /*maybeReader*/ const_cast(mti), - /*maybeWriter*/ I)) { - foundStore = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (!foundStore) { - return UseReq::Recur; - } - } - } - if ((mode == DerivativeMode::ReverseModePrimal || - mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError) && - mayWriteToMemory) { - return UseReq::Need; - } - // Don't erase any store that needs to be preserved for a - // rematerialization. However, if not used in a rematerialization, the - // store should be eliminated in the reverse pass - if (mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ForwardModeSplit) { - auto CI = dyn_cast(const_cast(inst)); - const Function *CF = CI ? getFunctionFromCall(CI) : nullptr; - StringRef funcName = CF ? CF->getName() : ""; - if (isa(inst) || isa(inst) || - isa(inst) || funcName == "julia.write_barrier" || - funcName == "julia.write_barrier_binding") { - for (auto pair : gutils->rematerializableAllocations) { - if (pair.second.stores.count(inst)) { - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, pair.first, mode, PrimalSeen, - oldUnreachable)) { - return UseReq::Need; - } - } - } - return UseReq::Recur; - } - } - - bool ivn = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, inst, mode, PrimalSeen, oldUnreachable); - if (ivn) { - return UseReq::Need; - } - return UseReq::Recur; - }, - [&](const Instruction *inst, const Value *val) { - if (isNoNeed(val)) { - if (auto SI = dyn_cast(inst)) - if (SI->getPointerOperand() == val) - return false; - - if (auto CI = dyn_cast(inst)) { - if (isDeallocationCall(CI, TLI)) { - if (CI->getArgOperand(0) == val) - return false; - } - - bool writeOnlyNoCapture = true; - if (shouldDisableNoWrite(CI)) { - writeOnlyNoCapture = false; - } - for (size_t i = 0; i < CI->arg_size(); i++) { - if (val == CI->getArgOperand(i)) { - if (!isNoCapture(CI, i)) { - writeOnlyNoCapture = false; - break; - } - if (!isWriteOnly(CI, i)) { - writeOnlyNoCapture = false; - break; - } - } - } - // Don't need the primal argument if it is write only and not - // captured - if (writeOnlyNoCapture) { - return false; - } - } - } - return true; - }); - if (EnzymePrintUnnecessary) { - llvm::errs() << " val use analysis of " << func.getName() - << ": mode=" << to_string(mode) << "\n"; - for (auto &BB : func) - for (auto &I : BB) { - bool ivn = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, &I, mode, PrimalSeen, oldUnreachable); - bool isn = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, &I, mode, PrimalSeen, oldUnreachable); - llvm::errs() << I << " ivn=" << (int)ivn << " isn: " << (int)isn; - auto found = gutils->knownRecomputeHeuristic.find(&I); - if (found != gutils->knownRecomputeHeuristic.end()) { - llvm::errs() << " krc=" << (int)found->second; - } - llvm::errs() << "\n"; - } - llvm::errs() << "unnecessaryValues of " << func.getName() - << ": mode=" << to_string(mode) << "\n"; - for (auto a : unnecessaryValues) { - bool ivn = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(gutils, a, mode, PrimalSeen, oldUnreachable); - bool isn = 
DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, a, mode, PrimalSeen, oldUnreachable); - llvm::errs() << *a << " ivn=" << (int)ivn << " isn: " << (int)isn; - auto found = gutils->knownRecomputeHeuristic.find(a); - if (found != gutils->knownRecomputeHeuristic.end()) { - llvm::errs() << " krc=" << (int)found->second; - } - llvm::errs() << "\n"; - } - llvm::errs() << "unnecessaryInstructions " << func.getName() << ":\n"; - for (auto a : unnecessaryInstructions) { - llvm::errs() << *a << "\n"; - } - } -} - -void calculateUnusedStoresInFunction( - Function &func, - llvm::SmallPtrSetImpl &unnecessaryStores, - const llvm::SmallPtrSetImpl &unnecessaryInstructions, - GradientUtils *gutils, TargetLibraryInfo &TLI) { - calculateUnusedStores(func, unnecessaryStores, [&](const Instruction *inst) { - if (auto si = dyn_cast(inst)) { - if (isa(si->getValueOperand())) - return false; - } - - if (auto mti = dyn_cast(inst)) { - auto at = getBaseObject(mti->getArgOperand(1)); - bool newMemory = false; - if (isa(at)) - newMemory = true; - else if (isAllocationCall(at, TLI)) - newMemory = true; - if (newMemory) { - bool foundStore = false; - allInstructionsBetween( - *gutils->OrigLI, cast(at), - const_cast(mti), [&](Instruction *I) -> bool { - if (!I->mayWriteToMemory()) - return /*earlyBreak*/ false; - if (unnecessaryStores.count(I)) - return /*earlyBreak*/ false; - - // if (I == &MTI) return; - if (writesToMemoryReadBy( - &gutils->TR, *gutils->OrigAA, TLI, - /*maybeReader*/ const_cast(mti), - /*maybeWriter*/ I)) { - foundStore = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (!foundStore) { - // performing a memcpy out of unitialized memory - return false; - } - } - } - - return true; - }); -} - -std::string to_string(Function &F, const std::vector &us) { - std::string s = "{"; - auto arg = F.arg_begin(); - for (auto y : us) { - s += arg->getName().str() + "@" + F.getName().str() + ":" + - std::to_string(y) + ","; - arg++; - } - return s + "}"; -} - -//! assuming not top level -std::pair, SmallVector> -getDefaultFunctionTypeForAugmentation(FunctionType *called, bool returnUsed, - DIFFE_TYPE retType) { - SmallVector args; - SmallVector outs; - for (auto &argType : called->params()) { - args.push_back(argType); - - if (!argType->isFPOrFPVectorTy()) { - args.push_back(argType); - } - } - - auto ret = called->getReturnType(); - // TODO CONSIDER a.getType()->isIntegerTy() && - // cast(a.getType())->getBitWidth() < 16 - outs.push_back(getInt8PtrTy(called->getContext())); - if (!ret->isVoidTy() && !ret->isEmptyTy()) { - if (returnUsed) { - outs.push_back(ret); - } - if (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED) { - outs.push_back(ret); - } - } - - return std::pair, SmallVector>(args, outs); -} - -//! assuming not top level -std::pair, SmallVector> -getDefaultFunctionTypeForGradient(FunctionType *called, DIFFE_TYPE retType, - ArrayRef tys) { - SmallVector args; - SmallVector outs; - - size_t i = 0; - for (auto &argType : called->params()) { - args.push_back(argType); - - switch (tys[i]) { - case DIFFE_TYPE::CONSTANT: - break; - case DIFFE_TYPE::OUT_DIFF: - outs.push_back(argType); - break; - case DIFFE_TYPE::DUP_ARG: - case DIFFE_TYPE::DUP_NONEED: - args.push_back(argType); - break; - } - i++; - } - - auto ret = called->getReturnType(); - - if (retType == DIFFE_TYPE::OUT_DIFF) { - args.push_back(ret); - } - - return std::pair, SmallVector>(args, outs); -} - -//! 
assuming not top level -std::pair, SmallVector> -getDefaultFunctionTypeForGradient(FunctionType *called, DIFFE_TYPE retType) { - SmallVector act; - for (auto &argType : called->params()) { - - if (argType->isFPOrFPVectorTy()) { - act.push_back(DIFFE_TYPE::OUT_DIFF); - } else { - act.push_back(DIFFE_TYPE::DUP_ARG); - } - } - return getDefaultFunctionTypeForGradient(called, retType, act); -} - -bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { - assert(op->getParent()->getParent() == gutils->oldFunc); - - Function *called = op->getCalledFunction(); - - bool modifyPrimal = !called || !isReadNone(op); - - if (modifyPrimal) { -#ifdef PRINT_AUGCALL - if (called) - llvm::errs() << "primal modified " << called->getName() - << " modified via reading from memory" - << "\n"; - else - llvm::errs() << "primal modified " << *op->getCalledValue() - << " modified via reading from memory" - << "\n"; -#endif - } - - if (!op->getType()->isFPOrFPVectorTy() && !gutils->isConstantValue(op) && - gutils->TR.anyPointer(op)) { - modifyPrimal = true; - -#ifdef PRINT_AUGCALL - if (called) - llvm::errs() << "primal modified " << called->getName() - << " modified via return" - << "\n"; - else - llvm::errs() << "primal modified " << *op->getCalledValue() - << " modified via return" - << "\n"; -#endif - } - - if (!called || called->empty()) - modifyPrimal = true; - - for (unsigned i = 0; i < op->arg_size(); ++i) { - if (gutils->isConstantValue(op->getArgOperand(i)) && called && - !called->empty()) { - continue; - } - - auto argType = op->getArgOperand(i)->getType(); - - if (!argType->isFPOrFPVectorTy() && - !gutils->isConstantValue(op->getArgOperand(i)) && - gutils->TR.anyPointer(op->getArgOperand(i))) { - if (!isReadOnly(op, i)) { - modifyPrimal = true; -#ifdef PRINT_AUGCALL - if (called) - llvm::errs() << "primal modified " << called->getName() - << " modified via arg " << i << "\n"; - else - llvm::errs() << "primal modified " << *op->getCalledValue() - << " modified via arg " << i << "\n"; -#endif - } - } - } - - // Don't need to augment calls that are certain to not hit return - if (isa(op->getParent()->getTerminator())) { - modifyPrimal = false; - } - -#ifdef PRINT_AUGCALL - llvm::errs() << "PM: " << *op << " modifyPrimal: " << modifyPrimal - << " cv: " << gutils->isConstantValue(op) << "\n"; -#endif - return modifyPrimal; -} - -bool legalCombinedForwardReverse( - CallInst *origop, - const std::map &replacedReturns, - SmallVectorImpl &postCreate, - SmallVectorImpl &userReplace, const GradientUtils *gutils, - const SmallPtrSetImpl &unnecessaryInstructions, - const SmallPtrSetImpl &oldUnreachable, - const bool subretused) { - Function *called = origop->getCalledFunction(); - Value *calledValue = origop->getCalledOperand(); - - if (isa(origop->getType())) { - bool sret = subretused; - if (!sret && !gutils->isConstantValue(origop)) { - sret = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, origop, gutils->mode, oldUnreachable); - } - - if (sret) { - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [not implemented] pointer return for combined " - "forward/reverse " - << called->getName() << "\n"; - else - llvm::errs() << " [not implemented] pointer return for combined " - "forward/reverse " - << *calledValue << "\n"; - } - return false; - } - } - - // Check any users of the returned value and determine all values that would - // be needed to be moved to reverse pass - // to ensure the forward pass would remain correct and everything computable - SmallPtrSet 
usetree; - std::deque todo{origop}; - - bool legal = true; - - // Given a function I we know must be moved to the reverse for legality - // reasons - auto propagate = [&](Instruction *I) { - // if only used in unneeded return, don't need to move this to reverse - // (unless this is the original function) - if (usetree.count(I)) - return; - if (gutils->notForAnalysis.count(I->getParent())) - return; - if (auto ri = dyn_cast(I)) { - auto find = replacedReturns.find(ri); - if (find != replacedReturns.end()) { - usetree.insert(ri); - } - return; - } - - if (isa(I) || isa(I)) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [bi] failed to replace function " - << (called->getName()) << " due to " << *I << "\n"; - else - llvm::errs() << " [bi] failed to replace function " << (*calledValue) - << " due to " << *I << "\n"; - } - return; - } - - // Even though the value `I` depends on (perhaps indirectly) the call being - // checked for, if neither `I` nor its pointer-valued shadow are used in the - // reverse pass, we can ignore the dependency as long as `I` is not going to - // have a combined forward and reverse pass. - if (I != origop && unnecessaryInstructions.count(I)) { - bool needShadow = false; - if (!gutils->isConstantValue(I)) { - needShadow = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(gutils, I, DerivativeMode::ReverseModeCombined, - oldUnreachable); - } - if (!needShadow) { - if (gutils->isConstantInstruction(I) || !isa(I)) { - userReplace.push_back(I); - return; - } - } - } - - if (isAllocationCall(I, gutils->TLI) || - isDeallocationCall(I, gutils->TLI)) { - return; - } - - if (isa(I)) { - legal = false; - - return; - } - if (isa(I)) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [phi] failed to replace function " - << (called->getName()) << " due to " << *I << "\n"; - else - llvm::errs() << " [phi] failed to replace function " << (*calledValue) - << " due to " << *I << "\n"; - } - return; - } - if (!I->getType()->isVoidTy() && - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, I, DerivativeMode::ReverseModeCombined, oldUnreachable)) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [nv] failed to replace function " - << (called->getName()) << " due to " << *I << "\n"; - else - llvm::errs() << " [nv] failed to replace function " << (*calledValue) - << " due to " << *I << "\n"; - } - return; - } - if (!I->getType()->isVoidTy() && - gutils->TR.query(I)[{-1}].isPossiblePointer() && - DifferentialUseAnalysis::is_value_needed_in_reverse( - gutils, I, DerivativeMode::ReverseModeCombined, oldUnreachable)) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [ns] failed to replace function " - << (called->getName()) << " due to " << *I << "\n"; - else - llvm::errs() << " [ns] failed to replace function " << (*calledValue) - << " due to " << *I << "\n"; - } - return; - } - if (I != origop && !isa(I) && isa(I)) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [ci] failed to replace function " - << (called->getName()) << " due to " << *I << "\n"; - else - llvm::errs() << " [ci] failed to replace function " << (*calledValue) - << " due to " << *I << "\n"; - } - return; - } - // Do not try moving an instruction that modifies memory, if we already - // moved it. We need the originalToNew check because we may have deleted - // the instruction, which wont require the failed to move. 
- if (!isa(I) || unnecessaryInstructions.count(I) == 0) - if (I->mayReadOrWriteMemory() && - gutils->originalToNewFn.find(I) != gutils->originalToNewFn.end() && - gutils->getNewFromOriginal(I)->getParent() != - gutils->getNewFromOriginal(I->getParent())) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [am] failed to replace function " - << (called->getName()) << " due to " << *I << "\n"; - else - llvm::errs() << " [am] failed to replace function " - << (*calledValue) << " due to " << *I << "\n"; - } - return; - } - - usetree.insert(I); - for (auto use : I->users()) { - todo.push_back(cast(use)); - } - }; - - while (!todo.empty()) { - auto inst = todo.front(); - todo.pop_front(); - - if (inst->mayWriteToMemory()) { - auto consider = [&](Instruction *user) { - if (!user->mayReadFromMemory()) - return false; - if (writesToMemoryReadBy(&gutils->TR, *gutils->OrigAA, gutils->TLI, - /*maybeReader*/ user, - /*maybeWriter*/ inst)) { - - propagate(user); - // Fast return if not legal - if (!legal) - return true; - } - return false; - }; - allFollowersOf(inst, consider); - if (!legal) - return false; - } - - propagate(inst); - if (!legal) - return false; - } - - // Check if any of the unmoved operations will make it illegal to move the - // instruction - - for (auto inst : usetree) { - if (!inst->mayReadFromMemory()) - continue; - allFollowersOf(inst, [&](Instruction *post) { - if (unnecessaryInstructions.count(post)) - return false; - if (!post->mayWriteToMemory()) - return false; - - if (writesToMemoryReadBy(&gutils->TR, *gutils->OrigAA, gutils->TLI, - /*maybeReader*/ inst, - /*maybeWriter*/ post)) { - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [mem] failed to replace function " - << (called->getName()) << " due to " << *post - << " usetree: " << *inst << "\n"; - else - llvm::errs() << " [mem] failed to replace function " - << (*calledValue) << " due to " << *post - << " usetree: " << *inst << "\n"; - } - legal = false; - return true; - } - return false; - }); - if (!legal) - break; - } - - allFollowersOf(origop, [&](Instruction *post) { - if (unnecessaryInstructions.count(post)) - return false; - if (!origop->mayWriteToMemory() && !origop->mayReadFromMemory()) - return false; - if (auto CI = dyn_cast(post)) { - bool noFree = false; - noFree |= CI->hasFnAttr(Attribute::NoFree); - auto called = getFunctionFromCall(CI); - StringRef funcName = getFuncNameFromCall(CI); - if (funcName == "llvm.trap") - noFree = true; - if (!noFree && called) { - noFree |= called->hasFnAttribute(Attribute::NoFree); - } - if (!noFree) { - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [freeing] failed to replace function " - << (called->getName()) << " due to freeing " << *post - << " usetree: " << *origop << "\n"; - else - llvm::errs() << " [freeing] failed to replace function " - << (*calledValue) << " due to freeing " << *post - << " usetree: " << *origop << "\n"; - } - legal = false; - return true; - } - } - return false; - }); - - if (!legal) - return false; - - allFollowersOf(origop, [&](Instruction *inst) { - if (auto ri = dyn_cast(inst)) { - auto find = replacedReturns.find(ri); - if (find != replacedReturns.end()) { - postCreate.push_back(find->second); - return false; - } - } - - if (usetree.count(inst) == 0) - return false; - if (inst->getParent() != origop->getParent()) { - // Don't move a writing instruction (may change speculatable/etc things) - if (inst->mayWriteToMemory()) { - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [nonspec] failed to 
replace function " - << (called->getName()) << " due to " << *inst << "\n"; - else - llvm::errs() << " [nonspec] failed to replace function " - << (*calledValue) << " due to " << *inst << "\n"; - } - legal = false; - // Early exit - return true; - } - } - if (isa(inst) && - gutils->originalToNewFn.find(inst) == gutils->originalToNewFn.end()) { - legal = false; - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " [premove] failed to replace function " - << (called->getName()) << " due to " << *inst << "\n"; - else - llvm::errs() << " [premove] failed to replace function " - << (*calledValue) << " due to " << *inst << "\n"; - } - // Early exit - return true; - } - postCreate.push_back(gutils->getNewFromOriginal(inst)); - return false; - }); - - if (!legal) - return false; - - if (EnzymePrintPerf) { - if (called) - llvm::errs() << " choosing to replace function " << (called->getName()) - << " and do both forward/reverse\n"; - else - llvm::errs() << " choosing to replace function " << (*calledValue) - << " and do both forward/reverse\n"; - } - - return true; -} - -void clearFunctionAttributes(Function *f) { - for (Argument &Arg : f->args()) { - if (Arg.hasAttribute(Attribute::Returned)) - Arg.removeAttr(Attribute::Returned); - if (Arg.hasAttribute(Attribute::StructRet)) - Arg.removeAttr(Attribute::StructRet); - } - - Attribute::AttrKind fnattrs[] = { -#if LLVM_VERSION_MAJOR >= 16 - Attribute::Memory, -#endif - Attribute::ReadOnly, - Attribute::ReadNone, - Attribute::WriteOnly, - Attribute::WillReturn, - Attribute::OptimizeNone - }; - for (auto attr : fnattrs) { - if (f->hasFnAttribute(attr)) { - f->removeFnAttr(attr); - } - } - - if (f->getAttributes().getRetDereferenceableBytes()) { - f->removeRetAttr(Attribute::Dereferenceable); - } - - if (f->getAttributes().getRetAlignment()) { - f->removeRetAttr(Attribute::Alignment); - } - Attribute::AttrKind attrs[] = { -#if LLVM_VERSION_MAJOR >= 19 - Attribute::Range, -#endif -#if LLVM_VERSION_MAJOR >= 17 - Attribute::NoFPClass, -#endif - Attribute::NoUndef, - Attribute::NonNull, - Attribute::ZExt, - Attribute::SExt, - Attribute::NoAlias - }; - for (auto attr : attrs) { - if (f->hasRetAttribute(attr)) { - f->removeRetAttr(attr); - } - } - for (auto attr : {"enzyme_inactive", "enzyme_type"}) { - if (f->getAttributes().hasRetAttr(attr)) { - f->removeRetAttr(attr); - } - } -} - -void cleanupInversionAllocs(DiffeGradientUtils *gutils, BasicBlock *entry) { - while (gutils->inversionAllocs->size() > 0) { - Instruction *inst = &gutils->inversionAllocs->back(); - if (isa(inst)) - inst->moveBefore(&gutils->newFunc->getEntryBlock().front()); - else - inst->moveBefore(entry->getFirstNonPHIOrDbgOrLifetime()); - } - - (IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable(); - DeleteDeadBlock(gutils->inversionAllocs); - for (auto BBs : gutils->reverseBlocks) { - if (pred_begin(BBs.second.front()) == pred_end(BBs.second.front())) { - (IRBuilder<>(BBs.second.front())).CreateUnreachable(); - DeleteDeadBlock(BBs.second.front()); - } - } -} - -void restoreCache( - DiffeGradientUtils *gutils, - const std::map, int> &mapping, - const SmallPtrSetImpl &guaranteedUnreachable) { - // One must use this temporary map to first create all the replacements - // prior to actually replacing to ensure that getSubLimits has the same - // behavior and unwrap behavior for all replacements. 
- SmallVector, 4> newIToNextI; - - for (const auto &m : mapping) { - if (m.first.second == CacheType::Self && - gutils->knownRecomputeHeuristic.count(m.first.first)) { - assert(gutils->knownRecomputeHeuristic.count(m.first.first)); - if (!isa(m.first.first)) { - auto newi = gutils->getNewFromOriginal(m.first.first); - if (auto PN = dyn_cast(newi)) - if (gutils->fictiousPHIs.count(PN)) { - assert(gutils->fictiousPHIs[PN] == m.first.first); - gutils->fictiousPHIs.erase(PN); - } - IRBuilder<> BuilderZ(newi->getNextNode()); - if (isa(m.first.first)) { - BuilderZ.SetInsertPoint( - cast(newi)->getParent()->getFirstNonPHI()); - } - Value *nexti = gutils->cacheForReverse(BuilderZ, newi, m.second, - /*replace*/ false); - newIToNextI.emplace_back(newi, nexti); - } else { - auto newi = gutils->getNewFromOriginal((Value *)m.first.first); - newIToNextI.emplace_back(newi, newi); - } - } - } - - std::map> unwrapToOrig; - for (auto pair : gutils->unwrappedLoads) - unwrapToOrig[pair.second].push_back(const_cast(pair.first)); - gutils->unwrappedLoads.clear(); - - for (auto pair : newIToNextI) { - auto newi = pair.first; - auto nexti = pair.second; - if (newi != nexti) { - gutils->replaceAWithB(newi, nexti); - } - } - - // This most occur after all the replacements have been made - // in the previous loop, lest a loop bound being unwrapped use - // a value being replaced. - for (auto pair : newIToNextI) { - auto newi = pair.first; - auto nexti = pair.second; - for (auto V : unwrapToOrig[newi]) { - ValueToValueMapTy available; - if (auto MD = hasMetadata(V, "enzyme_available")) { - for (auto &pair : MD->operands()) { - auto tup = cast(pair); - auto val = cast(tup->getOperand(1))->getValue(); - assert(val); - available[cast(tup->getOperand(0))->getValue()] = - val; - } - } - IRBuilder<> lb(V); - // This must disallow caching here as otherwise performing the loop in - // the wrong order may result in first replacing the later unwrapped - // value, caching it, then attempting to reuse it for an earlier - // replacement. - Value *nval = gutils->unwrapM(nexti, lb, available, - UnwrapMode::LegalFullUnwrapNoTapeReplace, - /*scope*/ nullptr, /*permitCache*/ false); - assert(nval); - V->replaceAllUsesWith(nval); - V->eraseFromParent(); - } - } - - // Erasure happens after to not erase the key of unwrapToOrig - for (auto pair : newIToNextI) { - auto newi = pair.first; - auto nexti = pair.second; - if (newi != nexti) { - if (auto inst = dyn_cast(newi)) - gutils->erase(inst); - } - } - - // TODO also can consider switch instance as well - // TODO can also insert to topLevel as well [note this requires putting the - // intrinsic at the correct location] - for (auto &BB : *gutils->oldFunc) { - SmallVector unreachables; - SmallVector reachables; - for (auto Succ : successors(&BB)) { - if (guaranteedUnreachable.find(Succ) != guaranteedUnreachable.end()) { - unreachables.push_back(Succ); - } else { - reachables.push_back(Succ); - } - } - - if (unreachables.size() == 0 || reachables.size() == 0) - continue; - - if (auto bi = dyn_cast(BB.getTerminator())) { - - Value *condition = gutils->getNewFromOriginal(bi->getCondition()); - - Constant *repVal = (bi->getSuccessor(0) == unreachables[0]) - ? 
ConstantInt::getFalse(condition->getContext()) - : ConstantInt::getTrue(condition->getContext()); - - for (auto UI = condition->use_begin(), E = condition->use_end(); - UI != E;) { - Use &U = *UI; - ++UI; - U.set(repVal); - } - } - if (reachables.size() == 1) - if (auto si = dyn_cast(BB.getTerminator())) { - Value *condition = gutils->getNewFromOriginal(si->getCondition()); - - Constant *repVal = nullptr; - if (si->getDefaultDest() == reachables[0]) { - std::set cases; - for (auto c : si->cases()) { - // TODO this doesnt work with unsigned 64 bit ints or higher - // integer widths - cases.insert(cast(c.getCaseValue())->getSExtValue()); - } - int64_t legalNot = 0; - while (cases.count(legalNot)) - legalNot++; - repVal = ConstantInt::getSigned(condition->getType(), legalNot); - cast(gutils->getNewFromOriginal(si)) - ->setCondition(repVal); - // knowing which input was provided for the default dest is not - // possible at compile time, give up on other use replacement - continue; - } else { - for (auto c : si->cases()) { - if (c.getCaseSuccessor() == reachables[0]) { - repVal = c.getCaseValue(); - } - } - } - assert(repVal); - for (auto UI = condition->use_begin(), E = condition->use_end(); - UI != E;) { - Use &U = *UI; - ++UI; - U.set(repVal); - } - } - } -} - -//! return structtype if recursive function -const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( - RequestContext context, Function *todiff, DIFFE_TYPE retType, - ArrayRef constant_args, TypeAnalysis &TA, bool returnUsed, - bool shadowReturnUsed, const FnTypeInfo &oldTypeInfo_, - bool subsequent_calls_may_write, const std::vector _overwritten_args, - bool forceAnonymousTape, bool runtimeActivity, unsigned width, - bool AtomicAdd, bool omp) { - - TimeTraceScope timeScope("CreateAugmentedPrimal", todiff->getName()); - - if (returnUsed) - assert(!todiff->getReturnType()->isEmptyTy() && - !todiff->getReturnType()->isVoidTy()); - if (retType != DIFFE_TYPE::CONSTANT) - assert(!todiff->getReturnType()->isEmptyTy() && - !todiff->getReturnType()->isVoidTy()); - - FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(oldTypeInfo_, todiff); - AugmentedCacheKey tup = {todiff, - retType, - constant_args, - subsequent_calls_may_write, - _overwritten_args, - returnUsed, - shadowReturnUsed, - oldTypeInfo, - forceAnonymousTape, - AtomicAdd, - omp, - width, - runtimeActivity}; - - if (_overwritten_args.size() != todiff->arg_size()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << " overwritten_args.size() [" << _overwritten_args.size() - << "] != todiff->arg_size()\n"; - ss << "todiff: " << *todiff << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *todiff << "\n"; - } - if (EmitNoDerivativeError(ss.str(), todiff, context)) { - auto newFunc = todiff; - std::map returnMapping; - returnMapping[AugmentedStruct::Return] = -1; - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, - constant_args, shadowReturnUsed)) - ->second; - } - llvm::errs() << "mod: " << *todiff->getParent() << "\n"; - llvm::errs() << *todiff << "\n"; - llvm_unreachable( - "attempting to differentiate function with wrong overwritten count"); - } - - assert(_overwritten_args.size() == todiff->arg_size()); - assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); - - auto found = AugmentedCachedFunctions.find(tup); - if (found != AugmentedCachedFunctions.end()) { - return found->second; - } - TargetLibraryInfo &TLI = PPC.FAM.getResult(*todiff); - - 
// TODO make default typing (not just constant) - - if (auto md = hasMetadata(todiff, "enzyme_augment")) { - if (!isa(md)) { - llvm::errs() << *todiff << "\n"; - llvm::errs() << *md << "\n"; - report_fatal_error( - "unknown augment for noninvertible function -- metadata incorrect"); - } - auto md2 = cast(md); - assert(md2->getNumOperands() == 1); - auto gvemd = cast(md2->getOperand(0)); - auto foundcalled = cast(gvemd->getValue()); - - bool hasconstant = false; - for (auto v : constant_args) { - if (v == DIFFE_TYPE::CONSTANT) { - hasconstant = true; - break; - } - } - - if (hasconstant) { - EmitWarning("NoCustom", *todiff, - "Massaging provided custom augmented forward pass to handle " - "constant argumented"); - SmallVector dupargs; - std::vector next_constant_args(constant_args.begin(), - constant_args.end()); - { - auto OFT = todiff->getFunctionType(); - for (size_t act_idx = 0; act_idx < constant_args.size(); act_idx++) { - dupargs.push_back(OFT->getParamType(act_idx)); - switch (constant_args[act_idx]) { - case DIFFE_TYPE::OUT_DIFF: - break; - case DIFFE_TYPE::DUP_ARG: - case DIFFE_TYPE::DUP_NONEED: - dupargs.push_back(OFT->getParamType(act_idx)); - break; - case DIFFE_TYPE::CONSTANT: - if (!OFT->getParamType(act_idx)->isFPOrFPVectorTy()) { - next_constant_args[act_idx] = DIFFE_TYPE::DUP_ARG; - } else { - next_constant_args[act_idx] = DIFFE_TYPE::OUT_DIFF; - } - break; - } - } - } - - auto &aug = CreateAugmentedPrimal( - context, todiff, retType, next_constant_args, TA, returnUsed, - shadowReturnUsed, oldTypeInfo_, subsequent_calls_may_write, - _overwritten_args, forceAnonymousTape, runtimeActivity, width, - AtomicAdd, omp); - - FunctionType *FTy = - FunctionType::get(aug.fn->getReturnType(), dupargs, - todiff->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixaugment_" + todiff->getName(), todiff->getParent()); - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - auto arg = NewF->arg_begin(); - SmallVector fwdargs; - int act_idx = 0; - while (arg != NewF->arg_end()) { - arg->setName("arg" + Twine(act_idx)); - fwdargs.push_back(arg); - switch (constant_args[act_idx]) { - case DIFFE_TYPE::OUT_DIFF: - break; - case DIFFE_TYPE::DUP_ARG: - case DIFFE_TYPE::DUP_NONEED: - arg++; - arg->setName("arg" + Twine(act_idx) + "'"); - fwdargs.push_back(arg); - break; - case DIFFE_TYPE::CONSTANT: - if (next_constant_args[act_idx] != DIFFE_TYPE::OUT_DIFF) { - fwdargs.push_back(arg); - } - break; - } - arg++; - act_idx++; - } - auto cal = bb.CreateCall(aug.fn, fwdargs); - cal->setCallingConv(aug.fn->getCallingConv()); - - if (NewF->getReturnType()->isEmptyTy()) - bb.CreateRet(UndefValue::get(NewF->getReturnType())); - else if (NewF->getReturnType()->isVoidTy()) - bb.CreateRetVoid(); - else - bb.CreateRet(cal); - - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(NewF, aug.tapeType, aug.tapeIndices, - aug.returns, aug.overwritten_args_map, - aug.can_modref_map, next_constant_args, - shadowReturnUsed)) - ->second; - } - - if (foundcalled->hasStructRetAttr() && !todiff->hasStructRetAttr()) { - SmallVector args; - Type *sretTy = nullptr; - { - size_t i = 0; - for (auto &arg : foundcalled->args()) { - if (!foundcalled->hasParamAttribute(i, Attribute::StructRet)) - args.push_back(arg.getType()); - else { - sretTy = foundcalled->getParamAttribute(0, Attribute::StructRet) - .getValueAsType(); - } - i++; - } - } - assert(foundcalled->getReturnType()->isVoidTy()); - 
FunctionType *FTy = FunctionType::get( - sretTy, args, foundcalled->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixaugment_" + foundcalled->getName(), foundcalled->getParent()); - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - auto AI = bb.CreateAlloca(sretTy); - SmallVector argVs; - auto arg = NewF->arg_begin(); - size_t realidx = 0; - for (size_t i = 0; i < foundcalled->arg_size(); i++) { - if (!foundcalled->hasParamAttribute(i, Attribute::StructRet)) { - arg->setName("arg" + Twine(realidx)); - realidx++; - argVs.push_back(arg); - ++arg; - } else - argVs.push_back(AI); - } - auto cal = bb.CreateCall(foundcalled, argVs); - cal->setCallingConv(foundcalled->getCallingConv()); - - Value *res = bb.CreateLoad(sretTy, AI); - bb.CreateRet(res); - - todiff->setMetadata( - "enzyme_augment", - llvm::MDTuple::get(todiff->getContext(), - {llvm::ValueAsMetadata::get(NewF)})); - foundcalled = NewF; - } - - if (foundcalled->getReturnType() == todiff->getReturnType()) { - std::map returnMapping; - returnMapping[AugmentedStruct::Return] = -1; - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args, shadowReturnUsed)) - ->second; - } - - if (auto ST = dyn_cast(foundcalled->getReturnType())) { - if (ST->getNumElements() == 3) { - std::map returnMapping; - returnMapping[AugmentedStruct::Tape] = 0; - returnMapping[AugmentedStruct::Return] = 1; - returnMapping[AugmentedStruct::DifferentialReturn] = 2; - if (ST->getTypeAtIndex(1) != todiff->getReturnType() || - ST->getTypeAtIndex(2) != todiff->getReturnType()) { - Type *retTys[] = {ST->getTypeAtIndex((unsigned)0), - todiff->getReturnType(), todiff->getReturnType()}; - auto RT = - StructType::get(ST->getContext(), retTys, /*isPacked*/ false); - FunctionType *FTy = - FunctionType::get(RT, foundcalled->getFunctionType()->params(), - foundcalled->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixaugment_" + foundcalled->getName(), foundcalled->getParent()); - - BasicBlock *BB = - BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - SmallVector argVs; - size_t realidx = 0; - for (auto &a : NewF->args()) { - a.setName("arg" + Twine(realidx)); - realidx++; - argVs.push_back(&a); - } - auto cal = bb.CreateCall(foundcalled, argVs); - cal->setCallingConv(foundcalled->getCallingConv()); - - Value *res = UndefValue::get(RT); - res = bb.CreateInsertValue(res, bb.CreateExtractValue(cal, {0}), {0}); - for (unsigned i = 1; i <= 2; i++) { - auto AI = bb.CreateAlloca(todiff->getReturnType()); - bb.CreateStore( - bb.CreateExtractValue(cal, {i}), - bb.CreatePointerCast( - AI, PointerType::getUnqual(ST->getTypeAtIndex(i)))); - Value *vres = bb.CreateLoad(todiff->getReturnType(), AI); - res = bb.CreateInsertValue(res, vres, {i}); - } - bb.CreateRet(res); - - todiff->setMetadata( - "enzyme_augment", - llvm::MDTuple::get(todiff->getContext(), - {llvm::ValueAsMetadata::get(NewF)})); - foundcalled = NewF; - } - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args, shadowReturnUsed)) - ->second; - } - if (ST->getNumElements() == 2 && - ST->getElementType(0) == ST->getElementType(1)) { - std::map returnMapping; - returnMapping[AugmentedStruct::Return] = 0; - 
returnMapping[AugmentedStruct::DifferentialReturn] = 1; - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args, shadowReturnUsed)) - ->second; - } - if (ST->getNumElements() == 2) { - std::map returnMapping; - returnMapping[AugmentedStruct::Tape] = 0; - returnMapping[AugmentedStruct::Return] = 1; - if (ST->getTypeAtIndex(1) != todiff->getReturnType()) { - Type *retTys[] = {ST->getTypeAtIndex((unsigned)0), - todiff->getReturnType()}; - auto RT = - StructType::get(ST->getContext(), retTys, /*isPacked*/ false); - FunctionType *FTy = - FunctionType::get(RT, foundcalled->getFunctionType()->params(), - foundcalled->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixaugment_" + foundcalled->getName(), foundcalled->getParent()); - - BasicBlock *BB = - BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - SmallVector argVs; - size_t realidx = 0; - for (auto &a : NewF->args()) { - a.setName("arg" + Twine(realidx)); - realidx++; - argVs.push_back(&a); - } - auto cal = bb.CreateCall(foundcalled, argVs); - cal->setCallingConv(foundcalled->getCallingConv()); - - Value *res = UndefValue::get(RT); - res = bb.CreateInsertValue(res, bb.CreateExtractValue(cal, {0}), {0}); - for (unsigned i = 1; i <= 1; i++) { - auto AI = bb.CreateAlloca(todiff->getReturnType()); - bb.CreateStore( - bb.CreateExtractValue(cal, {i}), - bb.CreatePointerCast( - AI, PointerType::getUnqual(ST->getTypeAtIndex(i)))); - Value *vres = bb.CreateLoad(todiff->getReturnType(), AI); - res = bb.CreateInsertValue(res, vres, {i}); - } - bb.CreateRet(res); - - todiff->setMetadata( - "enzyme_augment", - llvm::MDTuple::get(todiff->getContext(), - {llvm::ValueAsMetadata::get(NewF)})); - foundcalled = NewF; - } - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args, shadowReturnUsed)) - ->second; - } - } - - std::map returnMapping; - if (!foundcalled->getReturnType()->isVoidTy()) - returnMapping[AugmentedStruct::Tape] = -1; - - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, {}, - constant_args, shadowReturnUsed)) - ->second; // dyn_cast(st->getElementType(0))); - } - - std::map returnMapping; - - GradientUtils *gutils = GradientUtils::CreateFromClone( - *this, runtimeActivity, width, todiff, TLI, TA, oldTypeInfo, retType, - constant_args, - /*returnUsed*/ returnUsed, /*shadowReturnUsed*/ shadowReturnUsed, - returnMapping, omp); - - if (todiff->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No augmented forward pass found for " + todiff->getName(); - { - std::string demangledName = llvm::demangle(todiff->getName().str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledName.find("> >", start)) != std::string::npos) { - demangledName.replace(start, 3, ">>"); - } - if (demangledName != todiff->getName()) - ss << "(" << demangledName << ")"; - } - ss << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *todiff << "\n"; - } - (IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable(); - DeleteDeadBlock(gutils->inversionAllocs); - clearFunctionAttributes(gutils->newFunc); - if (EmitNoDerivativeError(ss.str(), todiff, context)) { - auto newFunc = gutils->newFunc; - delete gutils; - IRBuilder<> 
b(&*newFunc->getEntryBlock().begin()); - RequestContext context2{nullptr, &b}; - EmitNoDerivativeError(ss.str(), todiff, context2); - return insert_or_assign( - AugmentedCachedFunctions, tup, - AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, - constant_args, shadowReturnUsed)) - ->second; - } - llvm::errs() << "mod: " << *todiff->getParent() << "\n"; - llvm::errs() << *todiff << "\n"; - llvm_unreachable("attempting to differentiate function without definition"); - } - gutils->AtomicAdd = AtomicAdd; - const SmallPtrSet guaranteedUnreachable = - getGuaranteedUnreachable(gutils->oldFunc); - - // Convert uncacheable args from the input function to the preprocessed - // function - std::vector _overwritten_argsPP = _overwritten_args; - - gutils->forceActiveDetection(); - gutils->computeGuaranteedFrees(); - - CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, - gutils->rematerializableAllocations, gutils->TR, - *gutils->OrigAA, gutils->oldFunc, - PPC.FAM.getResult(*gutils->oldFunc), - *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, - subsequent_calls_may_write, _overwritten_argsPP, - DerivativeMode::ReverseModePrimal, omp); - const std::map>> - overwritten_args_map = CA.compute_overwritten_args_for_callsites(); - gutils->overwritten_args_map_ptr = &overwritten_args_map; - - const std::map can_modref_map = - CA.compute_uncacheable_load_map(); - gutils->can_modref_map = &can_modref_map; - - // requires is_value_needed_in_reverse, that needs unnecessaryValues - // sets knownRecomputeHeuristic - gutils->computeMinCache(); - - // Requires knownRecomputeCache to be set as call to getContext - // itself calls createCacheForScope - gutils->forceAugmentedReturns(); - - SmallPtrSet unnecessaryValues; - SmallPtrSet unnecessaryInstructions; - calculateUnusedValuesInFunction(*gutils->oldFunc, unnecessaryValues, - unnecessaryInstructions, returnUsed, - DerivativeMode::ReverseModePrimal, gutils, - TLI, constant_args, guaranteedUnreachable); - gutils->unnecessaryValuesP = &unnecessaryValues; - - SmallPtrSet unnecessaryStores; - calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, - unnecessaryInstructions, gutils, TLI); - - insert_or_assign(AugmentedCachedFunctions, tup, - AugmentedReturn(gutils->newFunc, nullptr, {}, returnMapping, - overwritten_args_map, can_modref_map, - constant_args, shadowReturnUsed)); - - auto getIndex = [&](Instruction *I, CacheType u, IRBuilder<> &B) -> unsigned { - return gutils->getIndex( - std::make_pair(I, u), - AugmentedCachedFunctions.find(tup)->second.tapeIndices, B); - }; - - //! Explicitly handle all returns first to ensure that all instructions know - //! whether or not they are used - SmallPtrSet returnuses; - - for (BasicBlock &BB : *gutils->oldFunc) { - if (auto orig_ri = dyn_cast(BB.getTerminator())) { - auto ri = gutils->getNewFromOriginal(orig_ri); - Value *orig_oldval = orig_ri->getReturnValue(); - Value *oldval = - orig_oldval ? 
gutils->getNewFromOriginal(orig_oldval) : nullptr; - IRBuilder<> ib(ri); - Value *rt = UndefValue::get(gutils->newFunc->getReturnType()); - if (oldval && returnUsed) { - assert(returnMapping.find(AugmentedStruct::Return) != - returnMapping.end()); - auto idx = returnMapping.find(AugmentedStruct::Return)->second; - if (idx < 0) - rt = oldval; - else - rt = ib.CreateInsertValue(rt, oldval, {(unsigned)idx}); - if (Instruction *inst = dyn_cast(rt)) { - returnuses.insert(inst); - } - } - - auto newri = ib.CreateRet(rt); - gutils->originalToNewFn[orig_ri] = newri; - gutils->newToOriginalFn.erase(ri); - gutils->newToOriginalFn[newri] = orig_ri; - gutils->erase(ri); - } - } - - AdjointGenerator maker(DerivativeMode::ReverseModePrimal, gutils, - constant_args, retType, getIndex, overwritten_args_map, - &AugmentedCachedFunctions.find(tup)->second, nullptr, - unnecessaryValues, unnecessaryInstructions, - unnecessaryStores, guaranteedUnreachable); - - for (BasicBlock &oBB : *gutils->oldFunc) { - auto term = oBB.getTerminator(); - assert(term); - - // Don't create derivatives for code that results in termination - if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { - SmallVector toerase; - - // For having the prints still exist on bugs, check if indeed unused - for (auto &I : oBB) { - toerase.push_back(&I); - } - for (auto I : toerase) { - maker.eraseIfUnused(*I, /*erase*/ true, /*check*/ true); - } - auto newBB = cast(gutils->getNewFromOriginal(&oBB)); - if (!newBB->getTerminator()) { - for (auto next : successors(&oBB)) { - auto sucBB = cast(gutils->getNewFromOriginal(next)); - sucBB->removePredecessor(newBB, /*KeepOneInputPHIs*/ true); - } - IRBuilder<> builder(newBB); - builder.CreateUnreachable(); - } - continue; - } - - if (!isa(term) && !isa(term) && - !isa(term)) { - llvm::errs() << *oBB.getParent() << "\n"; - llvm::errs() << "unknown terminator instance " << *term << "\n"; - assert(0 && "unknown terminator inst"); - llvm_unreachable("unknown terminator inst"); - } - - BasicBlock::reverse_iterator I = oBB.rbegin(), E = oBB.rend(); - ++I; - for (; I != E; ++I) { - maker.visit(&*I); - assert(oBB.rend() == E); - } - } - - if (gutils->knownRecomputeHeuristic.size()) { - // Even though we could simply iterate through the heuristic map, - // we explicity iterate in order of the instructions to maintain - // a deterministic cache ordering. - for (auto &BB : *gutils->oldFunc) - for (auto &I : BB) { - auto found = gutils->knownRecomputeHeuristic.find(&I); - if (found != gutils->knownRecomputeHeuristic.end()) { - if (!found->second && !isa(&I)) { - auto newi = gutils->getNewFromOriginal(&I); - IRBuilder<> BuilderZ(cast(newi)->getNextNode()); - if (isa(newi)) { - BuilderZ.SetInsertPoint( - cast(newi)->getParent()->getFirstNonPHI()); - } - gutils->cacheForReverse(BuilderZ, newi, - getIndex(&I, CacheType::Self, BuilderZ)); - } - } - } - } - - auto nf = gutils->newFunc; - - while (gutils->inversionAllocs->size() > 0) { - gutils->inversionAllocs->back().moveBefore( - gutils->newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime()); - } - - //! 
Keep track of inverted pointers we may need to return - ValueToValueMapTy invertedRetPs; - if (shadowReturnUsed) { - for (BasicBlock &BB : *gutils->oldFunc) { - if (auto ri = dyn_cast(BB.getTerminator())) { - if (Value *orig_oldval = ri->getReturnValue()) { - auto newri = gutils->getNewFromOriginal(ri); - IRBuilder<> BuilderZ(newri); - Value *invertri = nullptr; - if (gutils->isConstantValue(orig_oldval)) { - if (!gutils->runtimeActivity && - gutils->TR.query(orig_oldval)[{-1}].isPossiblePointer()) { - if (!isa(orig_oldval) && - !isa(orig_oldval)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *ri - << " const val: " << *orig_oldval; - if (CustomErrorHandler) - invertri = unwrap(CustomErrorHandler( - str.c_str(), wrap(ri), ErrorType::MixedActivityError, - gutils, wrap(orig_oldval), wrap(&BuilderZ))); - else - EmitWarning("MixedActivityError", *ri, ss.str()); - } - } - } - if (!invertri) - invertri = gutils->invertPointerM(orig_oldval, BuilderZ, - /*nullShadow*/ true); - invertedRetPs[newri] = invertri; - } - } - } - } - - (IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable(); - DeleteDeadBlock(gutils->inversionAllocs); - - for (Argument &Arg : gutils->newFunc->args()) { - if (Arg.hasAttribute(Attribute::Returned)) - Arg.removeAttr(Attribute::Returned); - if (Arg.hasAttribute(Attribute::StructRet)) - Arg.removeAttr(Attribute::StructRet); - } - - if (gutils->newFunc->hasFnAttribute(Attribute::OptimizeNone)) - gutils->newFunc->removeFnAttr(Attribute::OptimizeNone); - - if (gutils->newFunc->getAttributes().getRetDereferenceableBytes()) { - gutils->newFunc->removeRetAttr(Attribute::Dereferenceable); - } - - // TODO could keep nonnull if returning value -1 - if (gutils->newFunc->getAttributes().getRetAlignment()) { - gutils->newFunc->removeRetAttr(Attribute::Alignment); - } - - llvm::Attribute::AttrKind attrs[] = { -#if LLVM_VERSION_MAJOR >= 19 - llvm::Attribute::Range, -#endif -#if LLVM_VERSION_MAJOR >= 17 - llvm::Attribute::NoFPClass, -#endif - llvm::Attribute::NoAlias, - llvm::Attribute::NoUndef, - llvm::Attribute::NonNull, - llvm::Attribute::ZExt, - llvm::Attribute::SExt, - }; - for (auto attr : attrs) { - if (gutils->newFunc->hasRetAttribute(attr)) { - gutils->newFunc->removeRetAttr(attr); - } - } - for (auto attr : {"enzyme_inactive", "enzyme_type"}) { - if (gutils->newFunc->getAttributes().hasRetAttr(attr)) { - gutils->newFunc->removeRetAttr(attr); - } - } - - gutils->eraseFictiousPHIs(); - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (2)"); - } - - SmallVector MallocTypes; - - bool nonRecursiveUse = false; - - for (auto a : gutils->getTapeValues()) { - MallocTypes.push_back(a->getType()); - if (auto ei = dyn_cast(a)) { - auto tidx = returnMapping.find(AugmentedStruct::Tape)->second; - if (ei->getIndices().size() == 1 && ei->getIndices()[0] == (unsigned)tidx) - if (auto cb = dyn_cast(ei->getOperand(0))) - if (gutils->newFunc == cb->getCalledFunction()) - continue; - } - nonRecursiveUse = true; - } - if (MallocTypes.size() == 0) - nonRecursiveUse = true; - if (!nonRecursiveUse) - MallocTypes.clear(); - - Type *tapeType = StructType::get(nf->getContext(), MallocTypes); - - bool removeTapeStruct = MallocTypes.size() == 1; - if (removeTapeStruct) { - tapeType = MallocTypes[0]; - - for (auto &a : AugmentedCachedFunctions.find(tup)->second.tapeIndices) { - a.second = -1; - } - } - - bool 
recursive = - AugmentedCachedFunctions.find(tup)->second.fn->getNumUses() > 0 || - forceAnonymousTape; - bool noTape = MallocTypes.size() == 0 && !forceAnonymousTape; - - StructType *sty = cast(gutils->newFunc->getReturnType()); - SmallVector RetTypes(sty->elements().begin(), - sty->elements().end()); - if (!noTape) { - if (recursive && !omp) { - auto size = - gutils->newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits( - tapeType); - if (size != 0) { - RetTypes[returnMapping.find(AugmentedStruct::Tape)->second] = - getDefaultAnonymousTapeType(gutils->newFunc->getContext()); - } - } - } - - int oldretIdx = -1; - if (returnMapping.find(AugmentedStruct::Return) != returnMapping.end()) { - oldretIdx = returnMapping[AugmentedStruct::Return]; - } - - if (noTape || omp) { - auto tidx = returnMapping.find(AugmentedStruct::Tape)->second; - if (noTape) - returnMapping.erase(AugmentedStruct::Tape); - if (noTape) - AugmentedCachedFunctions.find(tup)->second.returns.erase( - AugmentedStruct::Tape); - if (returnMapping.find(AugmentedStruct::Return) != returnMapping.end()) { - AugmentedCachedFunctions.find(tup) - ->second.returns[AugmentedStruct::Return] -= - (returnMapping[AugmentedStruct::Return] > tidx) ? 1 : 0; - returnMapping[AugmentedStruct::Return] -= - (returnMapping[AugmentedStruct::Return] > tidx) ? 1 : 0; - } - if (returnMapping.find(AugmentedStruct::DifferentialReturn) != - returnMapping.end()) { - AugmentedCachedFunctions.find(tup) - ->second.returns[AugmentedStruct::DifferentialReturn] -= - (returnMapping[AugmentedStruct::DifferentialReturn] > tidx) ? 1 : 0; - returnMapping[AugmentedStruct::DifferentialReturn] -= - (returnMapping[AugmentedStruct::DifferentialReturn] > tidx) ? 1 : 0; - } - RetTypes.erase(RetTypes.begin() + tidx); - } else if (recursive) { - } else { - RetTypes[returnMapping.find(AugmentedStruct::Tape)->second] = tapeType; - } - - bool noReturn = RetTypes.size() == 0; - Type *RetType = StructType::get(nf->getContext(), RetTypes); - if (noReturn) - RetType = Type::getVoidTy(RetType->getContext()); - if (noReturn) - assert(noTape || omp); - - bool removeStruct = RetTypes.size() == 1; - - if (removeStruct) { - RetType = RetTypes[0]; - for (auto &a : returnMapping) { - a.second = -1; - } - for (auto &a : AugmentedCachedFunctions.find(tup)->second.returns) { - a.second = -1; - } - } - - ValueToValueMapTy VMap; - SmallVector ArgTypes; - for (const Argument &I : nf->args()) { - ArgTypes.push_back(I.getType()); - } - - if (omp && !noTape) { - // see lack of struct type for openmp - ArgTypes.push_back(PointerType::getUnqual(tapeType)); - // ArgTypes.push_back(tapeType); - } - - // Create a new function type... - FunctionType *FTy = - FunctionType::get(RetType, ArgTypes, nf->getFunctionType()->isVarArg()); - - // Create the new function... 
- Function *NewF = Function::Create( - FTy, nf->getLinkage(), "augmented_" + todiff->getName(), nf->getParent()); - - unsigned attrIndex = 0; - auto i = nf->arg_begin(), j = NewF->arg_begin(); - while (i != nf->arg_end()) { - VMap[i] = j; -#if LLVM_VERSION_MAJOR > 20 - if (nf->hasParamAttribute(attrIndex, Attribute::Captures)) { - NewF->addParamAttr(attrIndex, - nf->getParamAttribute(attrIndex, Attribute::Captures)); - } -#else - if (nf->hasParamAttribute(attrIndex, Attribute::NoCapture)) { - NewF->addParamAttr( - attrIndex, nf->getParamAttribute(attrIndex, Attribute::NoCapture)); - } -#endif - if (nf->hasParamAttribute(attrIndex, Attribute::NoAlias)) { - NewF->addParamAttr(attrIndex, Attribute::NoAlias); - } - for (auto name : {"enzyme_sret", "enzyme_sret_v", "enzymejl_returnRoots", - "enzymejl_returnRoots_v", "enzymejl_parmtype", - "enzymejl_parmtype_ref", "enzyme_type"}) - if (nf->getAttributes().hasParamAttr(attrIndex, name)) { - NewF->addParamAttr(attrIndex, - nf->getAttributes().getParamAttr(attrIndex, name)); - } - - j->setName(i->getName()); - ++j; - ++i; - ++attrIndex; - } - - for (auto attr : {"enzyme_ta_norecur"}) - if (nf->getAttributes().hasAttributeAtIndex(AttributeList::FunctionIndex, - attr)) { - NewF->addFnAttr( - nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); - } - - for (auto attr : - {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) - if (nf->getAttributes().hasAttributeAtIndex(AttributeList::ReturnIndex, - attr)) { - NewF->addAttribute( - AttributeList::ReturnIndex, - nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); - } - - SmallVector Returns; - CloneFunctionInto(NewF, nf, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - - IRBuilder<> ib(NewF->getEntryBlock().getFirstNonPHI()); - - AllocaInst *ret = noReturn ? nullptr : ib.CreateAlloca(RetType); - - if (!noTape) { - Value *tapeMemory; - if (recursive && !omp) { - auto i64 = Type::getInt64Ty(NewF->getContext()); - auto size = - NewF->getParent()->getDataLayout().getTypeAllocSizeInBits(tapeType); - Value *memory; - if (size != 0) { - CallInst *malloccall = nullptr; - Instruction *zero = nullptr; - tapeMemory = CreateAllocation( - ib, tapeType, ConstantInt::get(i64, 1), "tapemem", &malloccall, - EnzymeZeroCache ? 
&zero : nullptr, /*isDefault*/ true); - memory = malloccall; - } else { - memory = ConstantPointerNull::get( - getDefaultAnonymousTapeType(NewF->getContext())); - } - Value *Idxs[] = { - ib.getInt32(0), - ib.getInt32(returnMapping.find(AugmentedStruct::Tape)->second), - }; - assert(memory); - assert(ret); - Value *gep = ret; - if (!removeStruct) { - gep = ib.CreateGEP(RetType, ret, Idxs, ""); - cast(gep)->setIsInBounds(true); - } - auto storeinst = ib.CreateStore(memory, gep); - PostCacheStore(storeinst, ib); - } else if (omp) { - j->setName("tape"); - tapeMemory = j; - // if structs were supported by openmp we could do this, but alas, no - // IRBuilder<> B(NewF->getEntryBlock().getFirstNonPHI()); - // tapeMemory = B.CreateAlloca(j->getType()); - // B.CreateStore(j, tapeMemory); - } else { - Value *Idxs[] = { - ib.getInt32(0), - ib.getInt32(returnMapping.find(AugmentedStruct::Tape)->second), - }; - tapeMemory = ret; - if (!removeStruct) { - tapeMemory = ib.CreateGEP(RetType, ret, Idxs, ""); - cast(tapeMemory)->setIsInBounds(true); - } - if (EnzymeZeroCache) { - ZeroMemory(ib, tapeType, tapeMemory, - /*isTape*/ true); - } - } - - unsigned i = 0; - for (auto v : gutils->getTapeValues()) { - if (!isa(v)) { - if (!isa(VMap[v])) { - llvm::errs() << " non constant for vmap[v=" << *v - << " ]= " << *VMap[v] << "\n"; - } - auto inst = cast(VMap[v]); - IRBuilder<> ib(inst->getNextNode()); - if (isa(inst)) - ib.SetInsertPoint(inst->getParent()->getFirstNonPHI()); - Value *Idxs[] = {ib.getInt32(0), ib.getInt32(i)}; - Value *gep = tapeMemory; - if (!removeTapeStruct) { - gep = ib.CreateGEP(tapeType, tapeMemory, Idxs, ""); - cast(gep)->setIsInBounds(true); - } - auto storeinst = ib.CreateStore(VMap[v], gep); - PostCacheStore(storeinst, ib); - } - ++i; - } - } else if (!nonRecursiveUse) { - for (auto v : gutils->getTapeValues()) { - if (isa(v)) - continue; - auto EV = cast(v); - auto EV2 = cast(VMap[v]); - assert(EV->use_empty()); - EV->eraseFromParent(); - assert(EV2->use_empty()); - EV2->eraseFromParent(); - } - } - - for (BasicBlock &BB : *nf) { - auto ri = dyn_cast(BB.getTerminator()); - if (ri == nullptr) - continue; - ReturnInst *rim = cast(VMap[ri]); - IRBuilder<> ib(rim); - if (returnUsed) { - Value *rv = rim->getReturnValue(); - assert(rv); - Value *actualrv = nullptr; - if (auto iv = dyn_cast(rv)) { - if (iv->getNumIndices() == 1 && (int)iv->getIndices()[0] == oldretIdx) { - actualrv = iv->getInsertedValueOperand(); - } - } - if (actualrv == nullptr) { - if (oldretIdx < 0) - actualrv = rv; - else - actualrv = ib.CreateExtractValue(rv, {(unsigned)oldretIdx}); - } - Value *gep = - removeStruct - ? ret - : ib.CreateConstGEP2_32( - RetType, ret, 0, - returnMapping.find(AugmentedStruct::Return)->second, ""); - if (auto ggep = dyn_cast(gep)) { - ggep->setIsInBounds(true); - } - if (EnzymeFixupReturn) - actualrv = unwrap(EnzymeFixupReturn(wrap(&ib), wrap(actualrv))); - auto storeinst = ib.CreateStore(actualrv, gep); - PostCacheStore(storeinst, ib); - } - - if (shadowReturnUsed) { - assert(invertedRetPs[ri]); - Value *shadowRV = invertedRetPs[ri]; - - if (!isa(shadowRV)) { - Value *gep = - removeStruct - ? 
ret - : ib.CreateConstGEP2_32( - RetType, ret, 0, - returnMapping.find(AugmentedStruct::DifferentialReturn) - ->second, - ""); - if (auto ggep = dyn_cast(gep)) { - ggep->setIsInBounds(true); - } - if (!(isa(shadowRV) || isa(shadowRV) || - isa(shadowRV) || - isa(shadowRV))) { - auto found = VMap.find(shadowRV); - assert(found != VMap.end()); - shadowRV = found->second; - } - if (EnzymeFixupReturn) - shadowRV = unwrap(EnzymeFixupReturn(wrap(&ib), wrap(shadowRV))); - auto storeinst = ib.CreateStore(shadowRV, gep); - PostCacheStore(storeinst, ib); - } - } - if (noReturn) - ib.CreateRetVoid(); - else { - ib.CreateRet(ib.CreateLoad(RetType, ret)); - } - cast(VMap[ri])->eraseFromParent(); - } - - clearFunctionAttributes(NewF); - PPC.LowerAllocAddr(NewF); - - if (llvm::verifyFunction(*NewF, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *NewF << "\n"; - report_fatal_error("augmented function failed verification (3)"); - } - { - PreservedAnalyses PA; - PPC.FAM.invalidate(*NewF, PA); - } - - SmallVector fnusers; - SmallVector, 1> gfnusers; - for (auto user : AugmentedCachedFunctions.find(tup)->second.fn->users()) { - if (auto CI = dyn_cast(user)) { - fnusers.push_back(CI); - } else { - if (auto CS = dyn_cast(user)) { - for (auto cuser : CS->users()) { - if (auto G = dyn_cast(cuser)) { - if (("_enzyme_reverse_" + todiff->getName() + "'").str() == - G->getName()) { - gfnusers.emplace_back(G, DerivativeMode::ReverseModeGradient); - continue; - } - if (("_enzyme_forwardsplit_" + todiff->getName() + "'").str() == - G->getName()) { - gfnusers.emplace_back(G, DerivativeMode::ForwardModeSplit); - continue; - } - } - llvm::errs() << *gutils->newFunc->getParent() << "\n"; - llvm::errs() << *cuser << "\n"; - llvm::errs() << *user << "\n"; - llvm_unreachable("Bad cuser of staging augmented forward fn"); - } - continue; - } - llvm::errs() << *gutils->newFunc->getParent() << "\n"; - llvm::errs() << *user << "\n"; - llvm_unreachable("Bad user of staging augmented forward fn"); - } - } - for (auto user : fnusers) { - if (removeStruct || !nonRecursiveUse) { - IRBuilder<> B(user); - SmallVector args(user->arg_begin(), user->arg_end()); - auto rep = B.CreateCall(NewF, args); - if (!rep->getType()->isVoidTy()) - rep->takeName(user); - rep->copyIRFlags(user); - rep->setAttributes(user->getAttributes()); - rep->setCallingConv(user->getCallingConv()); - rep->setTailCallKind(user->getTailCallKind()); - rep->setDebugLoc(gutils->getNewFromOriginal(user->getDebugLoc())); - assert(user); - SmallVector torep; - for (auto u : user->users()) { - assert(u); - if (auto ei = dyn_cast(u)) { - torep.push_back(ei); - } - } - for (auto ei : torep) { - ei->replaceAllUsesWith(rep); - ei->eraseFromParent(); - } - if (user->getParent()->getParent() == gutils->newFunc) - gutils->erase(user); - else - user->eraseFromParent(); - } else { - user->setCalledFunction(NewF); - } - } - PPC.AlwaysInline(NewF); - auto Arch = llvm::Triple(NewF->getParent()->getTargetTriple()).getArch(); - if (Arch == Triple::nvptx || Arch == Triple::nvptx64) - PPC.ReplaceReallocs(NewF, /*mem2reg*/ true); - - AugmentedCachedFunctions.find(tup)->second.fn = NewF; - if ((recursive && nonRecursiveUse) || (omp && !noTape)) - AugmentedCachedFunctions.find(tup)->second.tapeType = tapeType; - AugmentedCachedFunctions.find(tup)->second.isComplete = true; - - for (auto pair : gfnusers) { - auto GV = pair.first; - GV->setName("_tmp"); - auto R = gutils->GetOrCreateShadowFunction( - context, *this, TLI, TA, todiff, pair.second, 
gutils->runtimeActivity, - width, gutils->AtomicAdd); - SmallVector, 1> users; - GV->replaceAllUsesWith(ConstantExpr::getPointerCast(R, GV->getType())); - GV->eraseFromParent(); - } - - { - PreservedAnalyses PA; - PPC.FAM.invalidate(*gutils->newFunc, PA); - } - - Function *tempFunc = gutils->newFunc; - delete gutils; - tempFunc->eraseFromParent(); - - // Do not run post processing optimizations if the body of an openmp - // parallel so the adjointgenerator can successfully extract the allocation - // and frees and hoist them into the parent. Optimizing before then may - // make the IR different to traverse, and thus impossible to find the allocs. - if (PostOpt && !omp) - PPC.optimizeIntermediate(NewF); - if (EnzymePrint) - llvm::errs() << *NewF << "\n"; - return AugmentedCachedFunctions.find(tup)->second; -} - -void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB, - DIFFE_TYPE retType, ReturnType retVal) { - TypeResults &TR = gutils->TR; - ReturnInst *inst = dyn_cast(oBB->getTerminator()); - // In forward mode we only need to update the return value - if (inst == nullptr) - return; - - ReturnInst *newInst = cast(gutils->getNewFromOriginal(inst)); - BasicBlock *nBB = newInst->getParent(); - assert(nBB); - IRBuilder<> nBuilder(nBB); - nBuilder.setFastMathFlags(getFast()); - - SmallVector retargs; - - Value *toret = UndefValue::get(gutils->newFunc->getReturnType()); - - Value *invertedPtr = nullptr; - - if (retType != DIFFE_TYPE::CONSTANT) { - auto ret = inst->getOperand(0); - Type *rt = ret->getType(); - while (auto AT = dyn_cast(rt)) - rt = AT->getElementType(); - bool floatLike = rt->isFPOrFPVectorTy(); - if (!floatLike && TR.getReturnAnalysis().Inner0().isPossiblePointer()) { - if (gutils->isConstantValue(ret)) { - if (!gutils->runtimeActivity && - TR.query(ret)[{-1}].isPossiblePointer()) { - if (!isa(ret) && !isa(ret)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *inst - << " const val: " << *ret; - if (CustomErrorHandler) - invertedPtr = unwrap(CustomErrorHandler( - str.c_str(), wrap(inst), ErrorType::MixedActivityError, - gutils, wrap(ret), wrap(&nBuilder))); - else - EmitWarning("MixedActivityError", *inst, ss.str()); - } - } - } - } - } - - switch (retVal) { - case ReturnType::Return: { - auto ret = inst->getOperand(0); - - Type *rt = ret->getType(); - while (auto AT = dyn_cast(rt)) - rt = AT->getElementType(); - bool floatLike = rt->isFPOrFPVectorTy(); - - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!floatLike && - TR.getReturnAnalysis().Inner0().isPossiblePointer()) { - toret = invertedPtr ? invertedPtr : gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - assert(!invertedPtr); - toret = gutils->diffe(ret, nBuilder); - } else { - toret = invertedPtr - ? invertedPtr - : gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true); - } - - break; - } - case ReturnType::TwoReturns: { - if (retType == DIFFE_TYPE::CONSTANT) - assert(false && "Invalid return type"); - auto ret = inst->getOperand(0); - - Type *rt = ret->getType(); - while (auto AT = dyn_cast(rt)) - rt = AT->getElementType(); - bool floatLike = rt->isFPOrFPVectorTy(); - - toret = - nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0); - - if (!floatLike && TR.getReturnAnalysis().Inner0().isPossiblePointer()) { - toret = nBuilder.CreateInsertValue( - toret, - invertedPtr ? 
invertedPtr : gutils->invertPointerM(ret, nBuilder), 1); - } else if (!gutils->isConstantValue(ret)) { - assert(!invertedPtr); - toret = - nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1); - } else { - toret = nBuilder.CreateInsertValue( - toret, - invertedPtr - ? invertedPtr - : gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true), - 1); - } - break; - } - case ReturnType::Void: { - gutils->erase(gutils->getNewFromOriginal(inst)); - nBuilder.CreateRetVoid(); - return; - } - default: { - llvm::errs() << "Invalid return type: " << to_string(retVal) - << "for function: \n" - << gutils->newFunc << "\n"; - assert(false && "Invalid return type for function"); - return; - } - } - - gutils->erase(newInst); - nBuilder.CreateRet(toret); - return; -} - -Value *selectByWidth(IRBuilder<> &B, DiffeGradientUtils *gutils, Value *cond, - Value *tval, Value *fval) { - auto width = gutils->getWidth(); - if (width == 1) { - return B.CreateSelect(cond, tval, fval); - } - Value *res = UndefValue::get(tval->getType()); - - for (unsigned int i = 0; i < width; ++i) { - auto ntval = GradientUtils::extractMeta(B, tval, i); - auto nfval = GradientUtils::extractMeta(B, fval, i); - res = B.CreateInsertValue(res, B.CreateSelect(cond, ntval, nfval), {i}); - } - return res; -} - -void createInvertedTerminator(DiffeGradientUtils *gutils, - ArrayRef argTypes, BasicBlock *oBB, - AllocaInst *retAlloca, AllocaInst *dretAlloca, - unsigned extraArgs, DIFFE_TYPE retType) { - LoopContext loopContext; - BasicBlock *BB = cast(gutils->getNewFromOriginal(oBB)); - bool inLoop = gutils->getContext(BB, loopContext); - BasicBlock *BB2 = gutils->reverseBlocks[BB].back(); - assert(BB2); - IRBuilder<> Builder(BB2); - Builder.setFastMathFlags(getFast()); - - std::map> targetToPreds; - for (auto pred : predecessors(BB)) { - targetToPreds[gutils->getReverseOrLatchMerge(pred, BB)].push_back(pred); - } - - if (targetToPreds.size() == 0) { - SmallVector retargs; - - if (retAlloca) { - auto result = Builder.CreateLoad(retAlloca->getAllocatedType(), retAlloca, - "retreload"); - // TODO reintroduce invariant load/group - // result->setMetadata(LLVMContext::MD_invariant_load, - // MDNode::get(retAlloca->getContext(), {})); - retargs.push_back(result); - } - - if (dretAlloca) { - auto result = Builder.CreateLoad(dretAlloca->getAllocatedType(), - dretAlloca, "dretreload"); - // TODO reintroduce invariant load/group - // result->setMetadata(LLVMContext::MD_invariant_load, - // MDNode::get(dretAlloca->getContext(), {})); - retargs.push_back(result); - } - - for (auto &I : gutils->oldFunc->args()) { - if (!gutils->isConstantValue(&I) && - argTypes[I.getArgNo()] == DIFFE_TYPE::OUT_DIFF) { - retargs.push_back(gutils->diffe(&I, Builder)); - } - } - - if (gutils->newFunc->getReturnType()->isVoidTy()) { - assert(retargs.size() == 0); - Builder.CreateRetVoid(); - return; - } - - Value *toret = UndefValue::get(gutils->newFunc->getReturnType()); - for (unsigned i = 0; i < retargs.size(); ++i) { - unsigned idx[] = {i}; - toret = Builder.CreateInsertValue(toret, retargs[i], idx); - } - Builder.CreateRet(toret); - return; - } - - // PHINodes to replace that will contain true iff the predecessor was given - // basicblock - std::map replacePHIs; - SmallVector selects; - - IRBuilder<> phibuilder(BB2); - bool setphi = false; - - // We force loads of all phi nodes at once, to ensure correct results in the - // case that one phi node depends on others. 
- SmallVector, 1> phis; - - // Ensure phi values have their derivatives propagated - for (auto I = oBB->begin(), E = oBB->end(); I != E; ++I) { - PHINode *orig = dyn_cast(&*I); - if (orig == nullptr) - break; - if (gutils->isConstantInstruction(orig)) - continue; - - size_t size = 1; - if (orig->getType()->isSized()) - size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - orig->getType()) + - 7) / - 8; - - auto PNtypeT = gutils->TR.query(orig); - auto PNtype = PNtypeT[{-1}]; - - // TODO remove explicit type check and only use PNtype - if (!gutils->TR.anyFloat(orig, /*anythingIsFloat*/ false) || - orig->getType()->isPointerTy()) - continue; - - Type *PNfloatType = PNtype.isFloat(); - if (!PNfloatType) { - // Try to use the 0-th elem for all elems - PNtype = PNtypeT[{0}]; - bool legal = true; - for (size_t i = 1; i < size; i++) { - if (!PNtypeT[{(int)i}].isFloat()) - continue; - PNtype.checkedOrIn(PNtypeT[{(int)i}], /*pointerIntSame*/ true, legal); - if (!legal) { - break; - } - } - if (legal) { - PNfloatType = PNtype.isFloat(); - if (!PNfloatType) { - if (looseTypeAnalysis) { - if (orig->getType()->isFPOrFPVectorTy()) - PNfloatType = orig->getType()->getScalarType(); - if (orig->getType()->isIntOrIntVectorTy()) - continue; - } - } - } - } - if (!PNfloatType) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of phi " << *orig << PNtypeT.str() - << " sz: " << size << "\n"; - EmitNoTypeError(ss.str(), *orig, gutils, Builder); - continue; - } - - auto prediff = gutils->diffe(orig, Builder); - bool handled = false; - - SmallVector activeUses; - for (auto u : orig->users()) { - if (!gutils->isConstantInstruction(cast(u))) - activeUses.push_back(cast(u)); - else if (retType == DIFFE_TYPE::OUT_DIFF && isa(u)) - activeUses.push_back(cast(u)); - } - if (activeUses.size() == 1 && inLoop && - gutils->getNewFromOriginal(orig->getParent()) == loopContext.header && - loopContext.exitBlocks.size() == 1) { - SmallVector Latches; - gutils->OrigLI->getLoopFor(orig->getParent())->getLoopLatches(Latches); - bool allIncoming = true; - for (auto Latch : Latches) { - if (activeUses[0] != orig->getIncomingValueForBlock(Latch)) { - allIncoming = false; - break; - } - } - if (allIncoming) { - if (auto SI = dyn_cast(activeUses[0])) { - for (int i = 0; i < 2; i++) { - if (SI->getOperand(i + 1) == orig) { - auto oval = orig->getIncomingValueForBlock( - gutils->getOriginalFromNew(loopContext.preheader)); - BasicBlock *pred = loopContext.preheader; - if (replacePHIs.find(pred) == replacePHIs.end()) { - replacePHIs[pred] = Builder.CreatePHI( - Type::getInt1Ty(pred->getContext()), 1, "replacePHI"); - if (!setphi) { - phibuilder.SetInsertPoint(replacePHIs[pred]); - setphi = true; - } - } - - auto ddiff = gutils->diffe(SI, Builder); - gutils->setDiffe( - SI, - selectByWidth(Builder, gutils, replacePHIs[pred], - Constant::getNullValue(prediff->getType()), - ddiff), - Builder); - handled = true; - if (!gutils->isConstantValue(oval)) { - BasicBlock *REB = - gutils->reverseBlocks[*loopContext.exitBlocks.begin()] - .back(); - IRBuilder<> EB(REB); - if (REB->getTerminator()) - EB.SetInsertPoint(REB->getTerminator()); - - auto index = gutils->getOrInsertConditionalIndex( - gutils->getNewFromOriginal(SI->getOperand(0)), loopContext, - i == 1); - Value *sdif = selectByWidth( - Builder, gutils, - Builder.CreateICmpEQ( - gutils->lookupM(index, EB), - Constant::getNullValue(index->getType())), - ddiff, Constant::getNullValue(ddiff->getType())); - - auto dif = - selectByWidth(Builder, 
gutils, replacePHIs[pred], sdif, - Constant::getNullValue(prediff->getType())); - auto addedSelects = - gutils->addToDiffe(oval, dif, Builder, PNfloatType); - - for (auto select : addedSelects) - selects.push_back(select); - } - break; - } - } - } - if (auto BO = dyn_cast(activeUses[0])) { - - if (BO->getOpcode() == Instruction::FDiv && - BO->getOperand(0) == orig) { - - auto oval = orig->getIncomingValueForBlock( - gutils->getOriginalFromNew(loopContext.preheader)); - BasicBlock *pred = loopContext.preheader; - if (replacePHIs.find(pred) == replacePHIs.end()) { - replacePHIs[pred] = Builder.CreatePHI( - Type::getInt1Ty(pred->getContext()), 1, "replacePHI"); - if (!setphi) { - phibuilder.SetInsertPoint(replacePHIs[pred]); - setphi = true; - } - } - - auto ddiff = gutils->diffe(BO, Builder); - gutils->setDiffe( - BO, - selectByWidth(Builder, gutils, replacePHIs[pred], - Constant::getNullValue(prediff->getType()), - ddiff), - Builder); - handled = true; - - if (!gutils->isConstantValue(oval)) { - - BasicBlock *REB = - gutils->reverseBlocks[*loopContext.exitBlocks.begin()].back(); - IRBuilder<> EB(REB); - if (REB->getTerminator()) - EB.SetInsertPoint(REB->getTerminator()); - - auto product = gutils->getOrInsertTotalMultiplicativeProduct( - gutils->getNewFromOriginal(BO->getOperand(1)), loopContext); - - auto dif = selectByWidth( - Builder, gutils, replacePHIs[pred], - Builder.CreateFDiv(ddiff, gutils->lookupM(product, EB)), - Constant::getNullValue(prediff->getType())); - auto addedSelects = - gutils->addToDiffe(oval, dif, Builder, PNfloatType); - - for (auto select : addedSelects) - selects.push_back(select); - } - } - } - } - } - if (!handled) { - gutils->setDiffe( - orig, Constant::getNullValue(gutils->getShadowType(orig->getType())), - Builder); - phis.emplace_back(orig, prediff, PNfloatType); - } - } - - for (auto &&[orig, prediff, PNfloatType] : phis) { - - for (BasicBlock *opred : predecessors(oBB)) { - auto oval = orig->getIncomingValueForBlock(opred); - if (gutils->isConstantValue(oval)) { - continue; - } - - if (orig->getNumIncomingValues() == 1) { - gutils->addToDiffe(oval, prediff, Builder, PNfloatType); - } else { - BasicBlock *pred = cast(gutils->getNewFromOriginal(opred)); - if (replacePHIs.find(pred) == replacePHIs.end()) { - replacePHIs[pred] = Builder.CreatePHI( - Type::getInt1Ty(pred->getContext()), 1, "replacePHI"); - if (!setphi) { - phibuilder.SetInsertPoint(replacePHIs[pred]); - setphi = true; - } - } - auto dif = selectByWidth(Builder, gutils, replacePHIs[pred], prediff, - Constant::getNullValue(prediff->getType())); - auto addedSelects = gutils->addToDiffe(oval, dif, Builder, PNfloatType); - - for (auto select : addedSelects) - selects.push_back(select); - } - } - } - if (!setphi) { - phibuilder.SetInsertPoint(Builder.GetInsertBlock(), - Builder.GetInsertPoint()); - } - - if (inLoop && BB == loopContext.header) { - std::map> targetToPreds; - for (auto pred : predecessors(BB)) { - if (pred == loopContext.preheader) - continue; - targetToPreds[gutils->getReverseOrLatchMerge(pred, BB)].push_back(pred); - } - - assert(targetToPreds.size() && - "only loops with one backedge are presently supported"); - - Value *av = phibuilder.CreateLoad(loopContext.var->getType(), - loopContext.antivaralloc); - Value *phi = - phibuilder.CreateICmpEQ(av, Constant::getNullValue(av->getType())); - Value *nphi = phibuilder.CreateNot(phi); - - for (auto pair : replacePHIs) { - Value *replaceWith = nullptr; - - if (pair.first == loopContext.preheader) { - replaceWith = phi; - } else { - 
replaceWith = nphi; - } - - pair.second->replaceAllUsesWith(replaceWith); - pair.second->eraseFromParent(); - } - BB2 = gutils->reverseBlocks[BB].back(); - Builder.SetInsertPoint(BB2); - - Builder.CreateCondBr( - phi, gutils->getReverseOrLatchMerge(loopContext.preheader, BB), - targetToPreds.begin()->first); - - } else { - std::map>> - phiTargetToPreds; - for (auto pair : replacePHIs) { - phiTargetToPreds[pair.first].emplace_back(pair.first, BB); - } - BasicBlock *fakeTarget = nullptr; - for (auto pred : predecessors(BB)) { - if (phiTargetToPreds.find(pred) != phiTargetToPreds.end()) - continue; - if (fakeTarget == nullptr) - fakeTarget = pred; - phiTargetToPreds[fakeTarget].emplace_back(pred, BB); - } - gutils->branchToCorrespondingTarget(BB, phibuilder, phiTargetToPreds, - &replacePHIs); - - std::map>> - targetToPreds; - for (auto pred : predecessors(BB)) { - targetToPreds[gutils->getReverseOrLatchMerge(pred, BB)].emplace_back(pred, - BB); - } - BB2 = gutils->reverseBlocks[BB].back(); - Builder.SetInsertPoint(BB2); - gutils->branchToCorrespondingTarget(BB, Builder, targetToPreds); - } - - // Optimize select of not to just be a select with operands switched - for (SelectInst *select : selects) { - if (BinaryOperator *bo = dyn_cast(select->getCondition())) { - if (bo->getOpcode() == BinaryOperator::Xor) { - if (isa(bo->getOperand(0)) && - cast(bo->getOperand(0))->isOne()) { - select->setCondition(bo->getOperand(1)); - auto tmp = select->getTrueValue(); - select->setTrueValue(select->getFalseValue()); - select->setFalseValue(tmp); - if (bo->getNumUses() == 0) - bo->eraseFromParent(); - } else if (isa(bo->getOperand(1)) && - cast(bo->getOperand(1))->isOne()) { - select->setCondition(bo->getOperand(0)); - auto tmp = select->getTrueValue(); - select->setTrueValue(select->getFalseValue()); - select->setFalseValue(tmp); - if (bo->getNumUses() == 0) - bo->eraseFromParent(); - } - } - } - } -} - -Function *EnzymeLogic::CreatePrimalAndGradient( - RequestContext context, const ReverseCacheKey &&prevkey, TypeAnalysis &TA, - const AugmentedReturn *augmenteddata, bool omp) { - - TimeTraceScope timeScope("CreatePrimalAndGradient", - prevkey.todiff->getName()); - - assert(prevkey.mode == DerivativeMode::ReverseModeCombined || - prevkey.mode == DerivativeMode::ReverseModeGradient); - - FnTypeInfo oldTypeInfo = - preventTypeAnalysisLoops(prevkey.typeInfo, prevkey.todiff); - auto key = prevkey.replaceTypeInfo(oldTypeInfo); - - if (key.retType != DIFFE_TYPE::CONSTANT) - assert(!key.todiff->getReturnType()->isVoidTy()); - - if (!isMemFreeLibMFunction(getFuncName(key.todiff))) - assert(key.overwritten_args.size() == key.todiff->arg_size()); - - Function *prevFunction = nullptr; - if (ReverseCachedFunctions.find(key) != ReverseCachedFunctions.end()) { - prevFunction = ReverseCachedFunctions.find(key)->second; - if (!hasMetadata(prevFunction, "enzyme_placeholder")) - return prevFunction; - if (augmenteddata && !augmenteddata->isComplete) - return prevFunction; - } - - if (key.returnUsed) - assert(key.mode == DerivativeMode::ReverseModeCombined); - - TargetLibraryInfo &TLI = - PPC.FAM.getResult(*key.todiff); - - // TODO change this to go by default function type assumptions - bool hasconstant = false; - for (auto v : key.constant_args) { - if (v == DIFFE_TYPE::CONSTANT) { - hasconstant = true; - break; - } - } - - if (hasMetadata(key.todiff, "enzyme_gradient")) { - std::set seen; - DIFFE_TYPE subretType = whatType(key.todiff->getReturnType(), - DerivativeMode::ReverseModeGradient, - /*intAreConstant*/ false, 
seen); - if (key.todiff->getReturnType()->isVoidTy() || - key.todiff->getReturnType()->isEmptyTy()) - subretType = DIFFE_TYPE::CONSTANT; - if (subretType != key.retType) { - std::string str; - raw_string_ostream ss(str); - ss << "The required return activity calling into function: " - << key.todiff->getName() << " was " << to_string(key.retType) - << " but the assumed (default) return activity was " - << to_string(subretType) << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } - if (EmitNoDerivativeError(ss.str(), key.todiff, context)) { - return nullptr; - } - } - - if (key.mode == DerivativeMode::ReverseModeCombined) { - auto res = getDefaultFunctionTypeForGradient( - key.todiff->getFunctionType(), - /*retType*/ key.retType, key.constant_args); - - Type *FRetTy = - res.second.empty() - ? Type::getVoidTy(key.todiff->getContext()) - : StructType::get(key.todiff->getContext(), {res.second}); - - FunctionType *FTy = FunctionType::get( - FRetTy, res.first, key.todiff->getFunctionType()->isVarArg()); - - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixgradient_" + key.todiff->getName(), key.todiff->getParent()); - - size_t argnum = 0; - for (Argument &Arg : NewF->args()) { - Arg.setName("arg" + Twine(argnum)); - ++argnum; - } - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - - auto &aug = CreateAugmentedPrimal( - context, key.todiff, key.retType, key.constant_args, TA, - key.returnUsed, key.shadowReturnUsed, key.typeInfo, - key.subsequent_calls_may_write, key.overwritten_args, - /*forceAnonymousTape*/ false, key.runtimeActivity, key.width, - key.AtomicAdd, omp); - - SmallVector fwdargs; - for (auto &a : NewF->args()) - fwdargs.push_back(&a); - if (key.retType == DIFFE_TYPE::OUT_DIFF) - fwdargs.pop_back(); - auto cal = bb.CreateCall(aug.fn, fwdargs); - cal->setCallingConv(aug.fn->getCallingConv()); - - llvm::Value *tape = nullptr; - - if (aug.returns.find(AugmentedStruct::Tape) != aug.returns.end()) { - auto tapeIdx = aug.returns.find(AugmentedStruct::Tape)->second; - tape = (tapeIdx == -1) ? cal : bb.CreateExtractValue(cal, tapeIdx); - if (tape->getType()->isEmptyTy()) - tape = UndefValue::get(tape->getType()); - } - - if (aug.tapeType) { - assert(tape); - auto tapep = bb.CreatePointerCast( - tape, PointerType::get( - aug.tapeType, - cast(tape->getType())->getAddressSpace())); - auto truetape = bb.CreateLoad(aug.tapeType, tapep, "tapeld"); - truetape->setMetadata("enzyme_mustcache", - MDNode::get(truetape->getContext(), {})); - - if (key.freeMemory) { - auto size = NewF->getParent()->getDataLayout().getTypeAllocSizeInBits( - aug.tapeType); - if (size != 0) { - CreateDealloc(bb, tape); - } - } - tape = truetape; - } - - auto revfn = CreatePrimalAndGradient( - context, - (ReverseCacheKey){ - .todiff = key.todiff, - .retType = key.retType, - .constant_args = key.constant_args, - .overwritten_args = key.overwritten_args, - .returnUsed = false, - .shadowReturnUsed = false, - .mode = DerivativeMode::ReverseModeGradient, - .width = key.width, - .freeMemory = key.freeMemory, - .AtomicAdd = key.AtomicAdd, - .additionalType = tape ? 
tape->getType() : nullptr, - .forceAnonymousTape = key.forceAnonymousTape, - .typeInfo = key.typeInfo, - .runtimeActivity = key.runtimeActivity, - }, - TA, &aug, omp); - - SmallVector revargs; - for (auto &a : NewF->args()) { - revargs.push_back(&a); - } - if (tape) { - revargs.push_back(tape); - } - auto revcal = bb.CreateCall(revfn, revargs); - revcal->setCallingConv(revfn->getCallingConv()); - - if (NewF->getReturnType()->isEmptyTy()) { - bb.CreateRet(UndefValue::get(NewF->getReturnType())); - } else if (NewF->getReturnType()->isVoidTy()) { - bb.CreateRetVoid(); - } else { - bb.CreateRet(revcal); - } - assert(!key.returnUsed); - - return insert_or_assign2( - ReverseCachedFunctions, key, NewF) - ->second; - } - - auto md = key.todiff->getMetadata("enzyme_gradient"); - if (!isa(md)) { - llvm::errs() << *key.todiff << "\n"; - llvm::errs() << *md << "\n"; - report_fatal_error( - "unknown gradient for noninvertible function -- metadata incorrect"); - } - auto md2 = cast(md); - assert(md2->getNumOperands() == 1); - auto gvemd = cast(md2->getOperand(0)); - auto foundcalled = cast(gvemd->getValue()); - - if (hasconstant) { - EmitWarning("NoCustom", *key.todiff, - "Massaging provided custom reverse pass"); - SmallVector dupargs; - std::vector next_constant_args(key.constant_args); - { - auto OFT = key.todiff->getFunctionType(); - for (size_t act_idx = 0; act_idx < key.constant_args.size(); - act_idx++) { - dupargs.push_back(OFT->getParamType(act_idx)); - switch (key.constant_args[act_idx]) { - case DIFFE_TYPE::OUT_DIFF: - break; - case DIFFE_TYPE::DUP_ARG: - case DIFFE_TYPE::DUP_NONEED: - dupargs.push_back(OFT->getParamType(act_idx)); - break; - case DIFFE_TYPE::CONSTANT: - if (!OFT->getParamType(act_idx)->isFPOrFPVectorTy()) { - next_constant_args[act_idx] = DIFFE_TYPE::DUP_ARG; - } else { - next_constant_args[act_idx] = DIFFE_TYPE::OUT_DIFF; - } - break; - } - } - } - - auto revfn = CreatePrimalAndGradient( - context, - (ReverseCacheKey){ - .todiff = key.todiff, - .retType = key.retType, - .constant_args = next_constant_args, - .overwritten_args = key.overwritten_args, - .returnUsed = key.returnUsed, - .shadowReturnUsed = false, - .mode = DerivativeMode::ReverseModeGradient, - .width = key.width, - .freeMemory = key.freeMemory, - .AtomicAdd = key.AtomicAdd, - .additionalType = nullptr, - .forceAnonymousTape = key.forceAnonymousTape, - .typeInfo = key.typeInfo, - .runtimeActivity = key.runtimeActivity, - }, - TA, augmenteddata, omp); - - { - auto arg = revfn->arg_begin(); - for (auto cidx : next_constant_args) { - arg++; - if (cidx == DIFFE_TYPE::DUP_ARG || cidx == DIFFE_TYPE::DUP_NONEED) - arg++; - } - while (arg != revfn->arg_end()) { - dupargs.push_back(arg->getType()); - arg++; - } - } - - FunctionType *FTy = - FunctionType::get(revfn->getReturnType(), dupargs, - revfn->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixgradient_" + key.todiff->getName(), key.todiff->getParent()); - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - auto arg = NewF->arg_begin(); - SmallVector revargs; - size_t act_idx = 0; - while (act_idx != key.constant_args.size()) { - arg->setName("arg" + Twine(act_idx)); - revargs.push_back(arg); - switch (key.constant_args[act_idx]) { - case DIFFE_TYPE::OUT_DIFF: - break; - case DIFFE_TYPE::DUP_ARG: - case DIFFE_TYPE::DUP_NONEED: - arg++; - arg->setName("arg" + Twine(act_idx) + "'"); - revargs.push_back(arg); - break; - case DIFFE_TYPE::CONSTANT: - if 
(next_constant_args[act_idx] != DIFFE_TYPE::OUT_DIFF) { - revargs.push_back(arg); - } - break; - } - arg++; - act_idx++; - } - size_t pa = 0; - while (arg != NewF->arg_end()) { - revargs.push_back(arg); - arg->setName("postarg" + Twine(pa)); - pa++; - arg++; - } - auto cal = bb.CreateCall(revfn, revargs); - cal->setCallingConv(revfn->getCallingConv()); - - if (NewF->getReturnType()->isEmptyTy()) - bb.CreateRet(UndefValue::get(NewF->getReturnType())); - else if (NewF->getReturnType()->isVoidTy()) - bb.CreateRetVoid(); - else - bb.CreateRet(cal); - - return insert_or_assign2( - ReverseCachedFunctions, key, NewF) - ->second; - } - - if (!key.returnUsed && key.freeMemory) { - auto res = - getDefaultFunctionTypeForGradient(key.todiff->getFunctionType(), - /*retType*/ key.retType); - assert(augmenteddata); - bool badDiffRet = false; - bool hasTape = true; - if (foundcalled->arg_size() == res.first.size() + 1 /*tape*/) { - auto lastarg = foundcalled->arg_end(); - lastarg--; - res.first.push_back(lastarg->getType()); - if (key.retType == DIFFE_TYPE::OUT_DIFF) { - lastarg--; - if (lastarg->getType() != key.todiff->getReturnType()) - badDiffRet = true; - } - } else if (foundcalled->arg_size() == res.first.size()) { - if (key.retType == DIFFE_TYPE::OUT_DIFF) { - auto lastarg = foundcalled->arg_end(); - lastarg--; - if (lastarg->getType() != key.todiff->getReturnType()) - badDiffRet = true; - } - hasTape = false; - // res.first.push_back(StructType::get(todiff->getContext(), {})); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Bad function type of custom reverse pass for function " - << key.todiff->getName() << " of type " - << *key.todiff->getFunctionType() << "\n"; - ss << " expected gradient function to have argument types ["; - bool seen = false; - for (auto a : res.first) { - if (seen) - ss << ", "; - seen = true; - ss << *a; - } - ss << "]\n"; - ss << " Instead found " << foundcalled->getName() << " of type " - << *foundcalled->getFunctionType() << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *key.todiff << "\n"; - } - if (!EmitNoDerivativeError(ss.str(), key.todiff, context)) { - assert(0 && "bad type for custom gradient"); - llvm_unreachable("bad type for custom gradient"); - } - } - - auto st = dyn_cast(foundcalled->getReturnType()); - bool wrongRet = - st == nullptr && !foundcalled->getReturnType()->isVoidTy(); - if (wrongRet || badDiffRet) { - // if (wrongRet || !hasTape) { - Type *FRetTy = - res.second.empty() - ? 
Type::getVoidTy(key.todiff->getContext()) - : StructType::get(key.todiff->getContext(), {res.second}); - - FunctionType *FTy = FunctionType::get( - FRetTy, res.first, key.todiff->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixgradient_" + key.todiff->getName(), key.todiff->getParent()); - NewF->setAttributes(foundcalled->getAttributes()); - if (NewF->hasFnAttribute(Attribute::NoInline)) { - NewF->removeFnAttr(Attribute::NoInline); - } - if (NewF->hasFnAttribute(Attribute::OptimizeNone)) { - NewF->removeFnAttr(Attribute::OptimizeNone); - } - size_t argnum = 0; - for (Argument &Arg : NewF->args()) { - if (Arg.hasAttribute(Attribute::Returned)) - Arg.removeAttr(Attribute::Returned); - if (Arg.hasAttribute(Attribute::StructRet)) - Arg.removeAttr(Attribute::StructRet); - Arg.setName("arg" + Twine(argnum)); - ++argnum; - } - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - SmallVector args; - for (auto &a : NewF->args()) - args.push_back(&a); - if (badDiffRet) { - auto idx = hasTape ? (args.size() - 2) : (args.size() - 1); - Type *T = (foundcalled->arg_begin() + idx)->getType(); - - auto AI = bb.CreateAlloca(T); - bb.CreateStore(args[idx], - bb.CreatePointerCast( - AI, PointerType::getUnqual(args[idx]->getType()))); - Value *vres = bb.CreateLoad(T, AI); - args[idx] = vres; - } - // if (!hasTape) { - // args.pop_back(); - //} - auto cal = bb.CreateCall(foundcalled, args); - cal->setCallingConv(foundcalled->getCallingConv()); - Value *val = cal; - if (wrongRet) { - auto ut = UndefValue::get(NewF->getReturnType()); - if (val->getType()->isEmptyTy() && res.second.size() == 0) { - val = ut; - } else if (res.second.size() == 1 && - res.second[0] == val->getType()) { - val = bb.CreateInsertValue(ut, cal, {0u}); - } else { - llvm::errs() << *foundcalled << "\n"; - assert(0 && "illegal type for reverse"); - llvm_unreachable("illegal type for reverse"); - } - } - if (val->getType()->isVoidTy()) - bb.CreateRetVoid(); - else - bb.CreateRet(val); - foundcalled = NewF; - } - return insert_or_assign2( - ReverseCachedFunctions, key, foundcalled) - ->second; - } - - EmitWarning("NoCustom", *key.todiff, - "Not using provided custom reverse pass as require either " - "return or non-constant"); - } - - if (augmenteddata && augmenteddata->constant_args != key.constant_args) { - llvm::errs() << " sz: " << augmenteddata->constant_args.size() << " " - << key.constant_args.size() << "\n"; - for (size_t i = 0; i < key.constant_args.size(); ++i) { - llvm::errs() << " i: " << i << " " - << to_string(augmenteddata->constant_args[i]) << " " - << to_string(key.constant_args[i]) << "\n"; - } - assert(augmenteddata->constant_args.size() == key.constant_args.size()); - assert(augmenteddata->constant_args == key.constant_args); - } - - ReturnType retVal = - key.returnUsed ? (key.shadowReturnUsed ? ReturnType::ArgsWithTwoReturns - : ReturnType::ArgsWithReturn) - : (key.shadowReturnUsed ? ReturnType::ArgsWithReturn - : ReturnType::Args); - - bool diffeReturnArg = key.retType == DIFFE_TYPE::OUT_DIFF; - - DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone( - *this, key.mode, key.runtimeActivity, key.width, key.todiff, TLI, TA, - oldTypeInfo, key.retType, - augmenteddata ? 
augmenteddata->shadowReturnUsed : key.shadowReturnUsed, - diffeReturnArg, key.constant_args, retVal, key.additionalType, omp); - - gutils->AtomicAdd = key.AtomicAdd; - gutils->FreeMemory = key.freeMemory; - insert_or_assign2(ReverseCachedFunctions, key, - gutils->newFunc); - - if (key.todiff->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No reverse pass found for " + key.todiff->getName() << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *key.todiff << "\n"; - } - BasicBlock *entry = &gutils->newFunc->getEntryBlock(); - cleanupInversionAllocs(gutils, entry); - clearFunctionAttributes(gutils->newFunc); - if (EmitNoDerivativeError(ss.str(), key.todiff, context)) { - auto newFunc = gutils->newFunc; - delete gutils; - IRBuilder<> b(&*newFunc->getEntryBlock().begin()); - RequestContext context2{nullptr, &b}; - EmitNoDerivativeError(ss.str(), key.todiff, context2); - return newFunc; - } - llvm::errs() << "mod: " << *key.todiff->getParent() << "\n"; - llvm::errs() << *key.todiff << "\n"; - llvm_unreachable("attempting to differentiate function without definition"); - } - - if (augmenteddata && !augmenteddata->isComplete) { - auto nf = gutils->newFunc; - delete gutils; - assert(!prevFunction); - nf->setMetadata("enzyme_placeholder", MDTuple::get(nf->getContext(), {})); - return nf; - } - - const SmallPtrSet guaranteedUnreachable = - getGuaranteedUnreachable(gutils->oldFunc); - - // Convert uncacheable args from the input function to the preprocessed - // function - const std::vector &_overwritten_argsPP = key.overwritten_args; - - gutils->forceActiveDetection(); - - // requires is_value_needed_in_reverse, that needs unnecessaryValues - // sets backwardsOnlyShadows, rematerializableAllocations, and - // allocationsWithGuaranteedFrees - gutils->computeGuaranteedFrees(); - CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, - gutils->rematerializableAllocations, gutils->TR, - *gutils->OrigAA, gutils->oldFunc, - PPC.FAM.getResult(*gutils->oldFunc), - *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, - key.subsequent_calls_may_write, _overwritten_argsPP, - key.mode, omp); - const std::map>> - overwritten_args_map = - (augmenteddata) ? augmenteddata->overwritten_args_map - : CA.compute_overwritten_args_for_callsites(); - gutils->overwritten_args_map_ptr = &overwritten_args_map; - - const std::map can_modref_map = - augmenteddata ? 
augmenteddata->can_modref_map - : CA.compute_uncacheable_load_map(); - gutils->can_modref_map = &can_modref_map; - - std::map, int> mapping; - if (augmenteddata) - mapping = augmenteddata->tapeIndices; - - auto getIndex = [&](Instruction *I, CacheType u, IRBuilder<> &B) -> unsigned { - return gutils->getIndex(std::make_pair(I, u), mapping, B); - }; - - // requires is_value_needed_in_reverse, that needs unnecessaryValues - // sets knownRecomputeHeuristic - gutils->computeMinCache(); - - // Requires knownRecomputeCache to be set as call to getContext - // itself calls createCacheForScope - gutils->forceAugmentedReturns(); - - SmallPtrSet unnecessaryValues; - SmallPtrSet unnecessaryInstructions; - calculateUnusedValuesInFunction(*gutils->oldFunc, unnecessaryValues, - unnecessaryInstructions, key.returnUsed, - key.mode, gutils, TLI, key.constant_args, - guaranteedUnreachable); - gutils->unnecessaryValuesP = &unnecessaryValues; - - SmallPtrSet unnecessaryStores; - calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, - unnecessaryInstructions, gutils, TLI); - - Value *additionalValue = nullptr; - if (key.additionalType) { - auto v = gutils->newFunc->arg_end(); - v--; - additionalValue = v; - assert(key.mode != DerivativeMode::ReverseModeCombined); - assert(augmenteddata); - - // TODO VERIFY THIS - if (augmenteddata->tapeType && (omp || key.forceAnonymousTape)) { - IRBuilder<> BuilderZ(gutils->inversionAllocs); - if (!augmenteddata->tapeType->isEmptyTy()) { - auto tapep = BuilderZ.CreatePointerCast( - additionalValue, - PointerType::get(augmenteddata->tapeType, - cast(additionalValue->getType()) - ->getAddressSpace())); - LoadInst *truetape = - BuilderZ.CreateLoad(augmenteddata->tapeType, tapep, "truetape"); - truetape->setMetadata("enzyme_mustcache", - MDNode::get(truetape->getContext(), {})); - - if (!omp && gutils->FreeMemory) { - CreateDealloc(BuilderZ, additionalValue); - } - additionalValue = truetape; - } else { - if (gutils->FreeMemory) { - CreateDealloc(BuilderZ, additionalValue); - } - additionalValue = UndefValue::get(augmenteddata->tapeType); - } - } - - // TODO here finish up making recursive structs simply pass in i8* - gutils->setTape(additionalValue); - } - - Argument *differetval = nullptr; - if (key.retType == DIFFE_TYPE::OUT_DIFF) { - auto endarg = gutils->newFunc->arg_end(); - endarg--; - if (key.additionalType) - endarg--; - differetval = endarg; - - if (!key.todiff->getReturnType()->isVoidTy()) { - if (!(differetval->getType() == - gutils->getShadowType(key.todiff->getReturnType()))) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - } - assert(differetval->getType() == - gutils->getShadowType(key.todiff->getReturnType())); - } - } - - // Explicitly handle all returns first to ensure that return instructions know - // if they are used or not before - // processessing instructions - std::map replacedReturns; - llvm::AllocaInst *retAlloca = nullptr; - llvm::AllocaInst *dretAlloca = nullptr; - if (key.returnUsed) { - retAlloca = - IRBuilder<>(&gutils->newFunc->getEntryBlock().front()) - .CreateAlloca(key.todiff->getReturnType(), nullptr, "toreturn"); - } - if (key.shadowReturnUsed) { - assert(key.retType == DIFFE_TYPE::DUP_ARG || - key.retType == DIFFE_TYPE::DUP_NONEED); - assert(key.mode == DerivativeMode::ReverseModeCombined); - dretAlloca = - IRBuilder<>(&gutils->newFunc->getEntryBlock().front()) - .CreateAlloca(key.todiff->getReturnType(), nullptr, "dtoreturn"); - } - if (key.mode == 
DerivativeMode::ReverseModeCombined || - key.mode == DerivativeMode::ReverseModeGradient) { - for (BasicBlock &oBB : *gutils->oldFunc) { - if (ReturnInst *orig = dyn_cast(oBB.getTerminator())) { - ReturnInst *op = cast(gutils->getNewFromOriginal(orig)); - BasicBlock *BB = op->getParent(); - IRBuilder<> rb(op); - rb.setFastMathFlags(getFast()); - - if (retAlloca) { - StoreInst *si = rb.CreateStore( - gutils->getNewFromOriginal(orig->getReturnValue()), retAlloca); - replacedReturns[orig] = si; - } - - if (key.retType == DIFFE_TYPE::DUP_ARG || - key.retType == DIFFE_TYPE::DUP_NONEED) { - if (dretAlloca) { - rb.CreateStore(gutils->invertPointerM(orig->getReturnValue(), rb), - dretAlloca); - } - } else if (key.retType == DIFFE_TYPE::OUT_DIFF) { - assert(orig->getReturnValue()); - assert(differetval); - if (!gutils->isConstantValue(orig->getReturnValue())) { - IRBuilder<> reverseB(gutils->reverseBlocks[BB].back()); - gutils->setDiffe(orig->getReturnValue(), differetval, reverseB); - } - } else { - assert(dretAlloca == nullptr); - } - - rb.CreateBr(gutils->reverseBlocks[BB].front()); - gutils->erase(op); - } - } - } - - AdjointGenerator maker( - key.mode, gutils, key.constant_args, key.retType, getIndex, - overwritten_args_map, augmenteddata, &replacedReturns, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable); - - for (BasicBlock &oBB : *gutils->oldFunc) { - // Don't create derivatives for code that results in termination - if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { - auto newBB = cast(gutils->getNewFromOriginal(&oBB)); - SmallVector toRemove; - if (key.mode != DerivativeMode::ReverseModeCombined) { - if (auto II = dyn_cast(oBB.getTerminator())) { - toRemove.push_back(cast( - gutils->getNewFromOriginal(II->getNormalDest()))); - } else { - for (auto next : successors(&oBB)) { - auto sucBB = cast(gutils->getNewFromOriginal(next)); - toRemove.push_back(sucBB); - } - } - } - - for (auto sucBB : toRemove) { - if (sucBB->empty() || !isa(sucBB->begin())) - continue; - - SmallVector phis; - for (PHINode &Phi : sucBB->phis()) { - phis.push_back(&Phi); - } - for (PHINode *Phi : phis) { - unsigned NumPreds = Phi->getNumIncomingValues(); - if (NumPreds == 0) - continue; - Phi->removeIncomingValue(newBB); - } - } - - SmallVector toerase; - for (auto &I : oBB) { - toerase.push_back(&I); - } - for (auto I : llvm::reverse(toerase)) { - maker.eraseIfUnused(*I, /*erase*/ true, - /*check*/ key.mode == - DerivativeMode::ReverseModeCombined); - } - - if (key.mode != DerivativeMode::ReverseModeCombined) { - if (newBB->getTerminator()) - gutils->erase(newBB->getTerminator()); - IRBuilder<> builder(newBB); - builder.CreateUnreachable(); - } - continue; - } - - auto term = oBB.getTerminator(); - assert(term); - if (!isa(term) && !isa(term) && - !isa(term)) { - llvm::errs() << *oBB.getParent() << "\n"; - llvm::errs() << "unknown terminator instance " << *term << "\n"; - assert(0 && "unknown terminator inst"); - } - - BasicBlock::reverse_iterator I = oBB.rbegin(), E = oBB.rend(); - ++I; - for (; I != E; ++I) { - maker.visit(&*I); - assert(oBB.rend() == E); - } - - createInvertedTerminator(gutils, key.constant_args, &oBB, retAlloca, - dretAlloca, - 0 + (key.additionalType ? 1 : 0) + - ((key.retType == DIFFE_TYPE::DUP_ARG || - key.retType == DIFFE_TYPE::DUP_NONEED) - ? 
1 - : 0), - key.retType); - } - - if (key.mode == DerivativeMode::ReverseModeGradient) - restoreCache(gutils, mapping, guaranteedUnreachable); - - gutils->eraseFictiousPHIs(); - - BasicBlock *entry = &gutils->newFunc->getEntryBlock(); - - auto Arch = - llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch(); - unsigned int SharedAddrSpace = - Arch == Triple::amdgcn ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local - : 3; - - if (key.mode == DerivativeMode::ReverseModeCombined) { - BasicBlock *sharedBlock = nullptr; - for (auto &g : gutils->newFunc->getParent()->globals()) { - if (hasMetadata(&g, "enzyme_internalshadowglobal")) { - IRBuilder<> entryBuilder(gutils->inversionAllocs, - gutils->inversionAllocs->begin()); - - if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 || - Arch == Triple::amdgcn) && - g.getType()->getAddressSpace() == SharedAddrSpace) { - if (sharedBlock == nullptr) - sharedBlock = BasicBlock::Create(entry->getContext(), "shblock", - gutils->newFunc); - entryBuilder.SetInsertPoint(sharedBlock); - } - auto store = entryBuilder.CreateStore( - Constant::getNullValue(g.getValueType()), &g); - if (g.getAlign()) - store->setAlignment(*g.getAlign()); - } - } - if (sharedBlock) { - BasicBlock *OldEntryInsts = entry->splitBasicBlock(entry->begin()); - entry->getTerminator()->eraseFromParent(); - IRBuilder<> ebuilder(entry); - - Value *tx, *ty, *tz; - if (Arch == Triple::nvptx || Arch == Triple::nvptx64) { - tx = ebuilder.CreateCall(getIntrinsicDeclaration( - gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_x)); - ty = ebuilder.CreateCall(getIntrinsicDeclaration( - gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_y)); - tz = ebuilder.CreateCall(getIntrinsicDeclaration( - gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_z)); - } else if (Arch == Triple::amdgcn) { - tx = ebuilder.CreateCall(getIntrinsicDeclaration( - gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_x)); - ty = ebuilder.CreateCall(getIntrinsicDeclaration( - gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_y)); - tz = ebuilder.CreateCall(getIntrinsicDeclaration( - gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_z)); - } else { - llvm_unreachable("unknown gpu architecture"); - } - Value *OrVal = ebuilder.CreateOr(ebuilder.CreateOr(tx, ty), tz); - - ebuilder.CreateCondBr( - ebuilder.CreateICmpEQ(OrVal, ConstantInt::get(OrVal->getType(), 0)), - sharedBlock, OldEntryInsts); - - IRBuilder<> instbuilder(OldEntryInsts, OldEntryInsts->begin()); - - auto BarrierInst = Arch == Triple::amdgcn - ? 
(llvm::Intrinsic::ID)Intrinsic::amdgcn_s_barrier - : (llvm::Intrinsic::ID)Intrinsic::nvvm_barrier0; - instbuilder.CreateCall( - getIntrinsicDeclaration(gutils->newFunc->getParent(), BarrierInst), - {}); - OldEntryInsts->moveAfter(entry); - sharedBlock->moveAfter(entry); - IRBuilder<> sbuilder(sharedBlock); - sbuilder.CreateBr(OldEntryInsts); - SmallVector AIs; - for (auto &I : *OldEntryInsts) { - if (auto AI = dyn_cast(&I)) - AIs.push_back(AI); - } - for (auto AI : AIs) - AI->moveBefore(entry->getFirstNonPHIOrDbgOrLifetime()); - entry = OldEntryInsts; - } - } - - cleanupInversionAllocs(gutils, entry); - clearFunctionAttributes(gutils->newFunc); - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (4)"); - } - - auto nf = gutils->newFunc; - delete gutils; - - PPC.LowerAllocAddr(nf); - - { - PreservedAnalyses PA; - PPC.FAM.invalidate(*nf, PA); - } - PPC.AlwaysInline(nf); - if (Arch == Triple::nvptx || Arch == Triple::nvptx64) - PPC.ReplaceReallocs(nf, /*mem2reg*/ true); - - if (prevFunction) { - prevFunction->replaceAllUsesWith(nf); - prevFunction->eraseFromParent(); - } - - // Do not run post processing optimizations if the body of an openmp - // parallel so the adjointgenerator can successfully extract the allocation - // and frees and hoist them into the parent. Optimizing before then may - // make the IR different to traverse, and thus impossible to find the allocs. - if (PostOpt && !omp) - PPC.optimizeIntermediate(nf); - if (EnzymePrint) { - llvm::errs() << *nf << "\n"; - } - return nf; -} - -Function *EnzymeLogic::CreateForwardDiff( - RequestContext context, Function *todiff, DIFFE_TYPE retType, - ArrayRef constant_args, TypeAnalysis &TA, bool returnUsed, - DerivativeMode mode, bool freeMemory, bool runtimeActivity, unsigned width, - llvm::Type *additionalArg, const FnTypeInfo &oldTypeInfo_, - bool subsequent_calls_may_write, const std::vector _overwritten_args, - const AugmentedReturn *augmenteddata, bool omp) { - - TimeTraceScope timeScope("CreateForwardDiff", todiff->getName()); - - assert(retType != DIFFE_TYPE::OUT_DIFF); - - assert(mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError); - - FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(oldTypeInfo_, todiff); - - if (retType != DIFFE_TYPE::CONSTANT) - assert(!todiff->getReturnType()->isVoidTy()); - - if (returnUsed) - assert(!todiff->getReturnType()->isVoidTy()); - - if (mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError) - assert(_overwritten_args.size() == todiff->arg_size()); - - ForwardCacheKey tup = {todiff, - retType, - constant_args, - subsequent_calls_may_write, - _overwritten_args, - returnUsed, - mode, - width, - additionalArg, - oldTypeInfo, - runtimeActivity}; - - if (ForwardCachedFunctions.find(tup) != ForwardCachedFunctions.end()) { - return ForwardCachedFunctions.find(tup)->second; - } - - TargetLibraryInfo &TLI = PPC.FAM.getResult(*todiff); - - // TODO change this to go by default function type assumptions - bool hasconstant = false; - for (auto v : constant_args) { - assert(v != DIFFE_TYPE::OUT_DIFF); - if (v == DIFFE_TYPE::CONSTANT) { - hasconstant = true; - break; - } - } - - if (auto md = hasMetadata(todiff, (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError) - ? 
"enzyme_derivative" - : "enzyme_splitderivative")) { - if (!isa(md)) { - llvm::errs() << *todiff << "\n"; - llvm::errs() << *md << "\n"; - report_fatal_error( - "unknown derivative for function -- metadata incorrect"); - } - auto md2 = cast(md); - assert(md2); - assert(md2->getNumOperands() == 1); - if (!md2->getOperand(0)) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Failed to use custom forward mode derivative for " - << todiff->getName() << "\n"; - ss << " found metadata (but null op0) " << *md2 << "\n"; - EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, - ss.str()); - return ForwardCachedFunctions[tup] = nullptr; - } - if (!isa(md2->getOperand(0))) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Failed to use custom forward mode derivative for " - << todiff->getName() << "\n"; - ss << " found metadata (but not constantasmetadata) " - << *md2->getOperand(0) << "\n"; - EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, - ss.str()); - return ForwardCachedFunctions[tup] = nullptr; - } - auto gvemd = cast(md2->getOperand(0)); - auto foundcalled = cast(gvemd->getValue()); - - if ((foundcalled->getReturnType()->isVoidTy() || - retType != DIFFE_TYPE::CONSTANT) && - !hasconstant && returnUsed) - return foundcalled; - - if (!foundcalled->getReturnType()->isVoidTy() && !hasconstant) { - if (returnUsed && retType == DIFFE_TYPE::CONSTANT) { - } - if (!returnUsed && retType != DIFFE_TYPE::CONSTANT && !hasconstant) { - FunctionType *FTy = FunctionType::get( - todiff->getReturnType(), foundcalled->getFunctionType()->params(), - foundcalled->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixderivative_" + todiff->getName(), todiff->getParent()); - for (auto pair : llvm::zip(NewF->args(), foundcalled->args())) { - std::get<0>(pair).setName(std::get<1>(pair).getName()); - } - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - SmallVector args; - for (auto &a : NewF->args()) - args.push_back(&a); - auto cal = bb.CreateCall(foundcalled, args); - cal->setCallingConv(foundcalled->getCallingConv()); - - bb.CreateRet(bb.CreateExtractValue(cal, 1)); - return ForwardCachedFunctions[tup] = NewF; - } - assert(returnUsed); - } - - SmallVector curTypes; - bool legal = true; - SmallVector nextConstantArgs; - for (auto tup : llvm::zip(todiff->args(), constant_args)) { - auto &arg = std::get<0>(tup); - curTypes.push_back(arg.getType()); - if (std::get<1>(tup) != DIFFE_TYPE::CONSTANT) { - curTypes.push_back(arg.getType()); - nextConstantArgs.push_back(std::get<1>(tup)); - continue; - } - auto TT = oldTypeInfo.Arguments.find(&arg)->second[{-1}]; - if (TT.isFloat()) { - nextConstantArgs.push_back(DIFFE_TYPE::DUP_ARG); - continue; - } else if (TT == BaseType::Integer) { - nextConstantArgs.push_back(DIFFE_TYPE::DUP_ARG); - continue; - } else { - legal = false; - break; - } - } - if (augmenteddata && augmenteddata->returns.find(AugmentedStruct::Tape) != - augmenteddata->returns.end()) { - assert(additionalArg); - curTypes.push_back(additionalArg); - } - if (legal) { - Type *RT = todiff->getReturnType(); - if (returnUsed && retType != DIFFE_TYPE::CONSTANT) { - RT = StructType::get(RT->getContext(), {RT, RT}); - } - if (!returnUsed && retType == DIFFE_TYPE::CONSTANT) { - RT = Type::getVoidTy(RT->getContext()); - } - - FunctionType *FTy = FunctionType::get( - RT, curTypes, todiff->getFunctionType()->isVarArg()); - - Function *NewF = 
Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixderivative_" + todiff->getName(), todiff->getParent()); - - auto foundArg = NewF->arg_begin(); - SmallVector nextArgs; - for (auto tup : llvm::zip(todiff->args(), constant_args)) { - nextArgs.push_back(foundArg); - auto &arg = std::get<0>(tup); - foundArg->setName(arg.getName()); - foundArg++; - if (std::get<1>(tup) != DIFFE_TYPE::CONSTANT) { - foundArg->setName(arg.getName() + "'"); - nextConstantArgs.push_back(std::get<1>(tup)); - nextArgs.push_back(foundArg); - foundArg++; - continue; - } - auto TT = oldTypeInfo.Arguments.find(&arg)->second[{-1}]; - if (TT.isFloat()) { - nextArgs.push_back(Constant::getNullValue(arg.getType())); - nextConstantArgs.push_back(DIFFE_TYPE::DUP_ARG); - continue; - } else if (TT == BaseType::Integer) { - nextArgs.push_back(nextArgs.back()); - nextConstantArgs.push_back(DIFFE_TYPE::DUP_ARG); - continue; - } else { - legal = false; - break; - } - } - if (augmenteddata && augmenteddata->returns.find(AugmentedStruct::Tape) != - augmenteddata->returns.end()) { - foundArg->setName("tapeArg"); - nextArgs.push_back(foundArg); - foundArg++; - } - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - auto cal = bb.CreateCall(foundcalled, nextArgs); - cal->setCallingConv(foundcalled->getCallingConv()); - - if (returnUsed && retType != DIFFE_TYPE::CONSTANT) { - bb.CreateRet(cal); - } else if (returnUsed) { - bb.CreateRet(bb.CreateExtractValue(cal, 0)); - } else if (retType != DIFFE_TYPE::CONSTANT) { - bb.CreateRet(bb.CreateExtractValue(cal, 1)); - } else { - bb.CreateRetVoid(); - } - - return ForwardCachedFunctions[tup] = NewF; - } - - EmitWarning("NoCustom", *todiff, - "Cannot use provided custom derivative pass"); - } - - bool retActive = retType != DIFFE_TYPE::CONSTANT; - - ReturnType retVal = - returnUsed ? (retActive ? ReturnType::TwoReturns : ReturnType::Return) - : (retActive ? 
ReturnType::Return : ReturnType::Void); - - bool diffeReturnArg = false; - - DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone( - *this, mode, runtimeActivity, width, todiff, TLI, TA, oldTypeInfo, - retType, - /*shadowReturn*/ retActive, diffeReturnArg, constant_args, retVal, - additionalArg, omp); - - insert_or_assign2(ForwardCachedFunctions, tup, - gutils->newFunc); - - if (todiff->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - if (mode == DerivativeMode::ForwardModeError) { - ss << "No forward mode error function found for " + todiff->getName() - << "\n"; - } else { - ss << "No forward mode derivative found for " + todiff->getName() << "\n"; - } - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *todiff << "\n"; - } - BasicBlock *entry = &gutils->newFunc->getEntryBlock(); - cleanupInversionAllocs(gutils, entry); - clearFunctionAttributes(gutils->newFunc); - if (EmitNoDerivativeError(ss.str(), todiff, context)) { - auto newFunc = gutils->newFunc; - delete gutils; - return newFunc; - } - llvm::errs() << "mod: " << *todiff->getParent() << "\n"; - llvm::errs() << *todiff << "\n"; - llvm_unreachable("attempting to differentiate function without definition"); - } - gutils->FreeMemory = freeMemory; - - const SmallPtrSet guaranteedUnreachable = - getGuaranteedUnreachable(gutils->oldFunc); - - gutils->forceActiveDetection(); - - // TODO populate with actual unnecessaryInstructions once the dependency - // cycle with activity analysis is removed - SmallPtrSet unnecessaryInstructionsTmp; - for (auto BB : guaranteedUnreachable) { - for (auto &I : *BB) - unnecessaryInstructionsTmp.insert(&I); - } - if (mode == DerivativeMode::ForwardModeSplit) - gutils->computeGuaranteedFrees(); - - SmallPtrSet unnecessaryValues; - SmallPtrSet unnecessaryInstructions; - SmallPtrSet unnecessaryStores; - - AdjointGenerator *maker; - - std::unique_ptr> can_modref_map; - if (mode == DerivativeMode::ForwardModeSplit) { - std::vector _overwritten_argsPP = _overwritten_args; - - gutils->computeGuaranteedFrees(); - CacheAnalysis CA( - gutils->allocationsWithGuaranteedFree, - gutils->rematerializableAllocations, gutils->TR, *gutils->OrigAA, - gutils->oldFunc, - PPC.FAM.getResult(*gutils->oldFunc), - *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, - subsequent_calls_may_write, _overwritten_argsPP, mode, omp); - const std::map>> - overwritten_args_map = CA.compute_overwritten_args_for_callsites(); - gutils->overwritten_args_map_ptr = &overwritten_args_map; - can_modref_map = std::make_unique>( - CA.compute_uncacheable_load_map()); - gutils->can_modref_map = can_modref_map.get(); - - // requires is_value_needed_in_reverse, that needs unnecessaryValues - // sets knownRecomputeHeuristic - gutils->computeMinCache(); - - // Requires knownRecomputeCache to be set as call to getContext - // itself calls createCacheForScope - gutils->forceAugmentedReturns(); - - auto getIndex = [&](Instruction *I, CacheType u, - IRBuilder<> &B) -> unsigned { - assert(augmenteddata); - return gutils->getIndex(std::make_pair(I, u), augmenteddata->tapeIndices, - B); - }; - - calculateUnusedValuesInFunction( - *gutils->oldFunc, unnecessaryValues, unnecessaryInstructions, - returnUsed, mode, gutils, TLI, constant_args, guaranteedUnreachable); - gutils->unnecessaryValuesP = &unnecessaryValues; - - calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, - unnecessaryInstructions, gutils, TLI); - - maker = new AdjointGenerator(mode, gutils, constant_args, retType, 
getIndex, - overwritten_args_map, augmenteddata, nullptr, - unnecessaryValues, unnecessaryInstructions, - unnecessaryStores, guaranteedUnreachable); - - if (additionalArg) { - auto v = gutils->newFunc->arg_end(); - v--; - Value *additionalValue = v; - assert(augmenteddata); - - // TODO VERIFY THIS - if (augmenteddata->tapeType && - augmenteddata->tapeType != additionalValue->getType()) { - IRBuilder<> BuilderZ(gutils->inversionAllocs); - if (!augmenteddata->tapeType->isEmptyTy()) { - auto tapep = BuilderZ.CreatePointerCast( - additionalValue, PointerType::getUnqual(augmenteddata->tapeType)); - LoadInst *truetape = - BuilderZ.CreateLoad(augmenteddata->tapeType, tapep, "truetape"); - truetape->setMetadata("enzyme_mustcache", - MDNode::get(truetape->getContext(), {})); - - if (!omp && gutils->FreeMemory) { - CreateDealloc(BuilderZ, additionalValue); - } - additionalValue = truetape; - } else { - if (gutils->FreeMemory) { - auto size = gutils->newFunc->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(augmenteddata->tapeType); - if (size != 0) { - CreateDealloc(BuilderZ, additionalValue); - } - } - additionalValue = UndefValue::get(augmenteddata->tapeType); - } - } - - // TODO here finish up making recursive structs simply pass in i8* - gutils->setTape(additionalValue); - } - } else { - gutils->forceAugmentedReturns(); - calculateUnusedValuesInFunction( - *gutils->oldFunc, unnecessaryValues, unnecessaryInstructions, - returnUsed, mode, gutils, TLI, constant_args, guaranteedUnreachable); - gutils->unnecessaryValuesP = &unnecessaryValues; - - calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, - unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator(mode, gutils, constant_args, retType, nullptr, - {}, nullptr, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, - guaranteedUnreachable); - } - - for (BasicBlock &oBB : *gutils->oldFunc) { - // Don't create derivatives for code that results in termination - if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { - for (auto &I : oBB) { - maker->eraseIfUnused(I, /*erase*/ true, /*check*/ true); - } - continue; - } - - auto term = oBB.getTerminator(); - assert(term); - if (!isa(term) && !isa(term) && - !isa(term)) { - llvm::errs() << *oBB.getParent() << "\n"; - llvm::errs() << "unknown terminator instance " << *term << "\n"; - assert(0 && "unknown terminator inst"); - } - - auto first = oBB.begin(); - auto last = oBB.empty() ? 
oBB.end() : std::prev(oBB.end()); - for (auto it = first; it != last; ++it) { - maker->visit(&*it); - } - - createTerminator(gutils, &oBB, retType, retVal); - } - - if (mode == DerivativeMode::ForwardModeSplit && augmenteddata) - restoreCache(gutils, augmenteddata->tapeIndices, guaranteedUnreachable); - - gutils->eraseFictiousPHIs(); - - BasicBlock *entry = &gutils->newFunc->getEntryBlock(); - - auto Arch = - llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch(); - - cleanupInversionAllocs(gutils, entry); - clearFunctionAttributes(gutils->newFunc); - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (4)"); - } - - auto nf = gutils->newFunc; - delete gutils; - delete maker; - - PPC.LowerAllocAddr(nf); - - { - PreservedAnalyses PA; - PPC.FAM.invalidate(*nf, PA); - } - PPC.AlwaysInline(nf); - if (Arch == Triple::nvptx || Arch == Triple::nvptx64) - PPC.ReplaceReallocs(nf, /*mem2reg*/ true); - - if (PostOpt) - PPC.optimizeIntermediate(nf); - if (EnzymePrint) { - llvm::errs() << *nf << "\n"; - } - return nf; -} - -static Value *floatValTruncate(IRBuilderBase &B, Value *v, - FloatTruncation truncation) { - if (truncation.isToFPRT()) - return v; - - Type *toTy = truncation.getToType(B.getContext()); - if (auto vty = dyn_cast(v->getType())) - toTy = VectorType::get(toTy, vty->getElementCount()); - return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); -} - -static Value *floatValExpand(IRBuilderBase &B, Value *v, - FloatTruncation truncation) { - if (truncation.isToFPRT()) - return v; - - Type *fromTy = truncation.getFromType(B.getContext()); - if (auto vty = dyn_cast(v->getType())) - fromTy = VectorType::get(fromTy, vty->getElementCount()); - return B.CreateFPExt(v, fromTy, "enzyme_exp"); -} - -static Value *floatMemTruncate(IRBuilderBase &B, Value *v, - FloatTruncation truncation) { - if (isa(v->getType())) - report_fatal_error("vector operations not allowed in mem trunc mode"); - - Type *toTy = truncation.getToType(B.getContext()); - return B.CreateBitCast(v, toTy); -} - -static Value *floatMemExpand(IRBuilderBase &B, Value *v, - FloatTruncation truncation) { - if (isa(v->getType())) - report_fatal_error("vector operations not allowed in mem trunc mode"); - - Type *fromTy = truncation.getFromType(B.getContext()); - return B.CreateBitCast(v, fromTy); -} - -class TruncateUtils { -protected: - FloatTruncation truncation; - llvm::Module *M; - Type *fromType; - Type *toType; - LLVMContext &ctx; - -private: - std::string getOriginalFPRTName(std::string Name) { - return std::string(EnzymeFPRTOriginalPrefix) + truncation.mangleFrom() + - "_" + Name; - } - std::string getFPRTName(std::string Name) { - return std::string(EnzymeFPRTPrefix) + truncation.mangleFrom() + "_" + Name; - } - - // Creates a function which contains the original floating point operation. - // The user can use this to compare results against. 
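-  // Illustrative sketch only (the runtime prefix and mangling spelled out
-  // here are assumptions, not taken from this file): for an instruction
-  // `%r = fadd double %a, %b`, the emitted comparison function is roughly
-  //   define double @"<EnzymeFPRTOriginalPrefix><from-mangling>_binop_fadd"(double, double) {
-  //     %r = fadd double %0, %1
-  //     ret double %r
-  //   }
-  // i.e. the original operation cloned behind an external symbol that callers
-  // can invoke to compare against the truncated result.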
- void createOriginalFPRTFunc(Instruction &I, std::string Name, - SmallVectorImpl &Args, - llvm::Type *RetTy) { - auto MangledName = getOriginalFPRTName(Name); - auto F = M->getFunction(MangledName); - if (!F) { - SmallVector ArgTypes; - for (auto Arg : Args) - ArgTypes.push_back(Arg->getType()); - FunctionType *FnTy = - FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); - F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, M); - } - if (F->isDeclaration()) { - BasicBlock *Entry = BasicBlock::Create(F->getContext(), "entry", F); - auto ClonedI = I.clone(); - for (unsigned It = 0; It < Args.size(); It++) - ClonedI->setOperand(It, F->getArg(It)); - auto Return = ReturnInst::Create(F->getContext(), ClonedI, Entry); - ClonedI->insertBefore(Return); - } - } - - Function *getFPRTFunc(std::string Name, SmallVectorImpl &Args, - llvm::Type *RetTy) { - auto MangledName = getFPRTName(Name); - auto F = M->getFunction(MangledName); - if (!F) { - SmallVector ArgTypes; - for (auto Arg : Args) - ArgTypes.push_back(Arg->getType()); - FunctionType *FnTy = - FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); - F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, M); - } - return F; - } - - CallInst *createFPRTGeneric(llvm::IRBuilderBase &B, std::string Name, - const SmallVectorImpl &ArgsIn, - llvm::Type *RetTy) { - SmallVector Args(ArgsIn.begin(), ArgsIn.end()); - Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); - Args.push_back(B.getInt64(truncation.getTo().significandWidth)); - Args.push_back(B.getInt64(truncation.getMode())); - auto FprtFunc = getFPRTFunc(Name, Args, RetTy); - return cast(B.CreateCall(FprtFunc, Args)); - } - -public: - TruncateUtils(FloatTruncation truncation, Module *M) - : truncation(truncation), M(M), ctx(M->getContext()) { - fromType = truncation.getFromType(ctx); - toType = truncation.getToType(ctx); - if (fromType == toType) - assert(truncation.isToFPRT()); - } - - Type *getFromType() { return fromType; } - - Type *getToType() { return toType; } - - CallInst *createFPRTConstCall(llvm::IRBuilderBase &B, Value *V) { - assert(V->getType() == getFromType()); - SmallVector Args; - Args.push_back(V); - return createFPRTGeneric(B, "const", Args, getToType()); - } - CallInst *createFPRTNewCall(llvm::IRBuilderBase &B, Value *V) { - assert(V->getType() == getFromType()); - SmallVector Args; - Args.push_back(V); - return createFPRTGeneric(B, "new", Args, getToType()); - } - CallInst *createFPRTGetCall(llvm::IRBuilderBase &B, Value *V) { - SmallVector Args; - Args.push_back(V); - return createFPRTGeneric(B, "get", Args, getToType()); - } - CallInst *createFPRTDeleteCall(llvm::IRBuilderBase &B, Value *V) { - SmallVector Args; - Args.push_back(V); - return createFPRTGeneric(B, "delete", Args, B.getVoidTy()); - } - CallInst *createFPRTOpCall(llvm::IRBuilderBase &B, llvm::Instruction &I, - llvm::Type *RetTy, - SmallVectorImpl &ArgsIn) { - std::string Name; - if (auto BO = dyn_cast(&I)) { - Name = "binop_" + std::string(BO->getOpcodeName()); - } else if (auto II = dyn_cast(&I)) { - auto FOp = II->getCalledFunction(); - assert(FOp); - Name = "intr_" + std::string(FOp->getName()); - for (auto &C : Name) - if (C == '.') - C = '_'; - } else if (auto CI = dyn_cast(&I)) { - if (auto F = CI->getCalledFunction()) - Name = "func_" + std::string(F->getName()); - else - llvm_unreachable( - "Unexpected indirect call inst for conversion to FPRT"); - } else if (auto CI = dyn_cast(&I)) { - Name = "fcmp_" + std::string(CI->getPredicateName(CI->getPredicate())); 
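-      // The Name computed above becomes the runtime callee suffix, e.g.
-      // "binop_fadd", "intr_llvm_sqrt_f64" (intrinsic dots rewritten to '_'),
-      // "func_<callee>", or "fcmp_oeq"; the concrete spellings here are
-      // illustrative.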
- } else { - llvm_unreachable("Unexpected instruction for conversion to FPRT"); - } - createOriginalFPRTFunc(I, Name, ArgsIn, RetTy); - return createFPRTGeneric(B, Name, ArgsIn, RetTy); - } -}; - -class TruncateGenerator : public llvm::InstVisitor, - public TruncateUtils { -private: - ValueToValueMapTy &originalToNewFn; - FloatTruncation truncation; - Function *oldFunc; - Function *newFunc; - TruncateMode mode; - EnzymeLogic &Logic; - LLVMContext &ctx; - -public: - TruncateGenerator(ValueToValueMapTy &originalToNewFn, - FloatTruncation truncation, Function *oldFunc, - Function *newFunc, EnzymeLogic &Logic) - : TruncateUtils(truncation, newFunc->getParent()), - originalToNewFn(originalToNewFn), truncation(truncation), - oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()), - Logic(Logic), ctx(newFunc->getContext()) {} - - void checkHandled(llvm::Instruction &inst) { - // TODO - // if (all_of(inst.getOperandList(), - // [&](Use *use) { return use->get()->getType() == fromType; })) - // todo(inst); - } - - // TODO - void handleTrunc(); - void hendleIntToFloat(); - void handleFloatToInt(); - - void visitInstruction(llvm::Instruction &inst) { - using namespace llvm; - - // TODO explicitly handle all instructions rather than using the catch all - // below - - switch (inst.getOpcode()) { - // #include "InstructionDerivatives.inc" - default: - break; - } - - checkHandled(inst); - } - - Value *truncate(IRBuilder<> &B, Value *v) { - switch (mode) { - case TruncMemMode: - if (isa(v)) - return createFPRTConstCall(B, v); - return floatMemTruncate(B, v, truncation); - case TruncOpMode: - case TruncOpFullModuleMode: - return floatValTruncate(B, v, truncation); - } - llvm_unreachable("Unknown trunc mode"); - } - - Value *expand(IRBuilder<> &B, Value *v) { - switch (mode) { - case TruncMemMode: - return floatMemExpand(B, v, truncation); - case TruncOpMode: - case TruncOpFullModuleMode: - return floatValExpand(B, v, truncation); - } - llvm_unreachable("Unknown trunc mode"); - } - - void todo(llvm::Instruction &I) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot handle unknown instruction\n" << I; - if (CustomErrorHandler) { - IRBuilder<> Builder2(getNewFromOriginal(&I)); - CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, - this, nullptr, wrap(&Builder2)); - return; - } else { - EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); - return; - } - } - - void visitAllocaInst(llvm::AllocaInst &I) { return; } - void visitICmpInst(llvm::ICmpInst &I) { return; } - void visitFCmpInst(llvm::FCmpInst &CI) { - switch (mode) { - case TruncMemMode: { - auto LHS = getNewFromOriginal(CI.getOperand(0)); - auto RHS = getNewFromOriginal(CI.getOperand(1)); - if (LHS->getType() != getFromType()) - return; - - auto newI = getNewFromOriginal(&CI); - IRBuilder<> B(newI); - auto truncLHS = truncate(B, LHS); - auto truncRHS = truncate(B, RHS); - - SmallVector Args; - Args.push_back(LHS); - Args.push_back(RHS); - Instruction *nres; - if (truncation.isToFPRT()) - nres = createFPRTOpCall(B, CI, B.getInt1Ty(), Args); - else - nres = - cast(B.CreateFCmp(CI.getPredicate(), truncLHS, truncRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(nres); - newI->eraseFromParent(); - return; - } - case TruncOpMode: - case TruncOpFullModuleMode: - return; - } - } - void visitLoadInst(llvm::LoadInst &LI) { - auto alignment = LI.getAlign(); - visitLoadLike(LI, alignment); - } - void visitStoreInst(llvm::StoreInst &SI) { - auto align = SI.getAlign(); - 
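-    // Loads and stores are left untouched: visitLoadLike / visitCommonStore
-    // below simply return, since this generator rewrites the values flowing
-    // through registers (binops, compares, selects, intrinsics) rather than
-    // the memory accesses themselves.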
visitCommonStore(SI, SI.getPointerOperand(), SI.getValueOperand(), align, - SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), - /*mask=*/nullptr); - } - void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } - void visitPHINode(llvm::PHINode &phi) { return; } - void visitCastInst(llvm::CastInst &CI) { - switch (mode) { - case TruncMemMode: { - if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType()) - todo(CI); - return; - } - case TruncOpMode: - case TruncOpFullModuleMode: - return; - } - } - void visitSelectInst(llvm::SelectInst &SI) { - switch (mode) { - case TruncMemMode: { - auto newI = getNewFromOriginal(&SI); - IRBuilder<> B(newI); - auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); - auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); - auto nres = cast( - B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - return; - } - case TruncOpMode: - case TruncOpFullModuleMode: - return; - } - llvm_unreachable(""); - } - void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } - void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } - void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } - void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } - void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } - void visitBinaryOperator(llvm::BinaryOperator &BO) { - auto oldLHS = BO.getOperand(0); - auto oldRHS = BO.getOperand(1); - - if (oldLHS->getType() != getFromType() && - oldRHS->getType() != getFromType()) - return; - - switch (BO.getOpcode()) { - default: - break; - case BinaryOperator::Add: - case BinaryOperator::Sub: - case BinaryOperator::Mul: - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - case BinaryOperator::Shl: - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - assert(0 && "Invalid binop opcode for float arg"); - return; - } - - auto newI = getNewFromOriginal(&BO); - IRBuilder<> B(newI); - auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); - auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); - Instruction *nres = nullptr; - if (truncation.isToFPRT()) { - SmallVector Args({newLHS, newRHS}); - nres = createFPRTOpCall(B, BO, truncation.getToType(ctx), Args); - } else { - nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); - } - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - return; - } - void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } - void visitMemSetCommon(llvm::CallInst &MS) { return; } - void visitMemTransferInst(llvm::MemTransferInst &MTI) { - using namespace llvm; - Value *isVolatile = getNewFromOriginal(MTI.getOperand(3)); - auto srcAlign = MTI.getSourceAlign(); - auto dstAlign = MTI.getDestAlign(); - visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI, - MTI.getOperand(0), MTI.getOperand(1), - getNewFromOriginal(MTI.getOperand(2)), isVolatile); - } - void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, - llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, - llvm::Value *orig_dst, llvm::Value *orig_src, - llvm::Value *new_size, llvm::Value *isVolatile) { - return; - } - void visitFenceInst(llvm::FenceInst 
&FI) { return; } - - bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { - if (isDbgInfoIntrinsic(ID)) - return true; - - auto newI = cast(getNewFromOriginal(&CI)); - IRBuilder<> B(newI); - - SmallVector orig_ops(CI.arg_size()); - for (unsigned i = 0; i < CI.arg_size(); ++i) - orig_ops[i] = CI.getOperand(i); - - bool hasFromType = false; - SmallVector new_ops(CI.arg_size()); - for (unsigned i = 0; i < CI.arg_size(); ++i) { - if (orig_ops[i]->getType() == getFromType()) { - new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i])); - hasFromType = true; - } else { - new_ops[i] = getNewFromOriginal(orig_ops[i]); - } - } - Type *retTy = CI.getType(); - if (CI.getType() == getFromType()) { - hasFromType = true; - retTy = getToType(); - } - - if (!hasFromType) - return false; - - Instruction *intr = nullptr; - Value *nres = nullptr; - if (truncation.isToFPRT()) { - nres = intr = createFPRTOpCall(B, CI, retTy, new_ops); - } else { - // TODO check that the intrinsic is overloaded - nres = intr = - createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); - } - if (newI->getType() == getFromType()) - nres = expand(B, nres); - intr->copyIRFlags(newI); - newI->replaceAllUsesWith(nres); - newI->eraseFromParent(); - return true; - } - void visitIntrinsicInst(llvm::IntrinsicInst &II) { - handleIntrinsic(II, II.getIntrinsicID()); - } - - void visitReturnInst(llvm::ReturnInst &I) { return; } - - void visitBranchInst(llvm::BranchInst &I) { return; } - void visitSwitchInst(llvm::SwitchInst &I) { return; } - void visitUnreachableInst(llvm::UnreachableInst &I) { return; } - void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment, - llvm::Value *mask = nullptr, - llvm::Value *orig_maskInit = nullptr) { - return; - } - - void visitCommonStore(llvm::Instruction &I, llvm::Value *orig_ptr, - llvm::Value *orig_val, llvm::MaybeAlign prevalign, - bool isVolatile, llvm::AtomicOrdering ordering, - llvm::SyncScope::ID syncScope, llvm::Value *mask) { - return; - } - - bool - handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, - llvm::SmallVectorImpl &orig_ops) { - using namespace llvm; - - if (isNVLoad(&I)) { - auto CI = cast(I.getOperand(1)); - visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue())); - return true; - } - - if (ID == Intrinsic::masked_store) { - auto align0 = cast(I.getOperand(2))->getZExtValue(); - auto align = MaybeAlign(align0); - visitCommonStore(I, /*orig_ptr*/ I.getOperand(1), - /*orig_val*/ I.getOperand(0), align, - /*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic, - SyncScope::SingleThread, - /*mask*/ getNewFromOriginal(I.getOperand(3))); - return true; - } - if (ID == Intrinsic::masked_load) { - auto align0 = cast(I.getOperand(1))->getZExtValue(); - auto align = MaybeAlign(align0); - visitLoadLike(I, align, - /*mask*/ getNewFromOriginal(I.getOperand(2)), - /*orig_maskInit*/ I.getOperand(3)); - return true; - } - - return false; - } - - llvm::Value *getNewFromOriginal(llvm::Value *v) { - auto found = originalToNewFn.find(v); - assert(found != originalToNewFn.end()); - return found->second; - } - - llvm::Instruction *getNewFromOriginal(llvm::Instruction *v) { - return cast(getNewFromOriginal((llvm::Value *)v)); - } - - bool handleKnownCalls(llvm::CallInst &call, llvm::Function *called, - llvm::StringRef funcName, - llvm::CallInst *const newCall) { - return false; - } - - Value *GetShadow(RequestContext &ctx, Value *v) { - if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, truncation, mode); - llvm::errs() << " unknown get 
truncated func: " << *v << "\n"; - llvm_unreachable("unknown get truncated func"); - return v; - } - // Return - void visitCallInst(llvm::CallInst &CI) { - Intrinsic::ID ID; - StringRef funcName = getFuncNameFromCall(const_cast(&CI)); - if (isMemFreeLibMFunction(funcName, &ID)) - if (handleIntrinsic(CI, ID)) - return; - - using namespace llvm; - - CallInst *const newCall = cast(getNewFromOriginal(&CI)); - IRBuilder<> BuilderZ(newCall); - - if (auto called = CI.getCalledFunction()) - if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall)) - return; - - if (mode != TruncOpFullModuleMode) { - RequestContext ctx(&CI, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); - newCall->setCalledOperand(val); - } - return; - } - void visitFPTruncInst(FPTruncInst &I) { return; } - void visitFPExtInst(FPExtInst &I) { return; } - void visitFPToUIInst(FPToUIInst &I) { return; } - void visitFPToSIInst(FPToSIInst &I) { return; } - void visitUIToFPInst(UIToFPInst &I) { return; } - void visitSIToFPInst(SIToFPInst &I) { return; } -}; - -bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, - FloatRepresentation from, - FloatRepresentation to, bool isTruncate) { - assert(context.req && context.ip); - - IRBuilderBase &B = *context.ip; - - Value *converted = nullptr; - auto truncation = FloatTruncation(from, to, TruncMemMode); - TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent()); - if (isTruncate) - converted = TU.createFPRTNewCall(B, v); - else - converted = TU.createFPRTGetCall(B, v); - assert(converted); - - context.req->replaceAllUsesWith(converted); - context.req->eraseFromParent(); - - return true; -} - -llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, - llvm::Function *totrunc, - FloatTruncation truncation, - TruncateMode mode) { - TruncateCacheKey tup(totrunc, truncation, mode); - if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { - return TruncateCachedFunctions.find(tup)->second; - } - - FunctionType *orig_FTy = totrunc->getFunctionType(); - SmallVector params; - - for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { - params.push_back(orig_FTy->getParamType(i)); - } - - Type *NewTy = totrunc->getReturnType(); - - FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - std::string truncName = - std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + - "_func_" + truncation.mangleTruncation() + "_" + totrunc->getName().str(); - Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, - totrunc->getParent()); - - if (mode != TruncOpFullModuleMode) - NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - - TruncateCachedFunctions[tup] = NewF; - - if (totrunc->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No truncate mode found for " + totrunc->getName() << "\n"; - llvm::Value *toshow = totrunc; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *totrunc << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(totrunc), - wrap(context.ip)); - return NewF; - } - if (context.req) { - EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, - ss.str()); - return NewF; - } - llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; - llvm::errs() << *totrunc << "\n"; - llvm_unreachable("attempting to truncate function without definition"); - } - - 
ValueToValueMapTy originalToNewFn; - - for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); - i != totrunc->arg_end();) { - originalToNewFn[i] = j; - j->setName(i->getName()); - ++j; - ++i; - } - - SmallVector Returns; - CloneFunctionInto(NewF, totrunc, originalToNewFn, - CloneFunctionChangeType::LocalChangesOnly, Returns, "", - nullptr); - - NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - - TruncateGenerator handle(originalToNewFn, truncation, totrunc, NewF, *this); - for (auto &BB : *totrunc) - for (auto &I : BB) - handle.visit(&I); - - if (llvm::verifyFunction(*NewF, &llvm::errs())) { - llvm::errs() << *totrunc << "\n"; - llvm::errs() << *NewF << "\n"; - report_fatal_error("function failed verification (5)"); - } - - return NewF; -} - -llvm::Function *EnzymeLogic::CreateBatch(RequestContext context, - Function *tobatch, unsigned width, - ArrayRef arg_types, - BATCH_TYPE ret_type) { - - BatchCacheKey tup = std::make_tuple(tobatch, width, arg_types, ret_type); - if (BatchCachedFunctions.find(tup) != BatchCachedFunctions.end()) { - return BatchCachedFunctions.find(tup)->second; - } - - FunctionType *orig_FTy = tobatch->getFunctionType(); - SmallVector params; - unsigned long numVecParams = - std::count(arg_types.begin(), arg_types.end(), BATCH_TYPE::VECTOR); - - for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { - if (arg_types[i] == BATCH_TYPE::VECTOR) { - Type *ty = GradientUtils::getShadowType(orig_FTy->getParamType(i), width); - params.push_back(ty); - } else { - params.push_back(orig_FTy->getParamType(i)); - } - } - - Type *NewTy = GradientUtils::getShadowType(tobatch->getReturnType(), width); - - FunctionType *FTy = FunctionType::get(NewTy, params, tobatch->isVarArg()); - Function *NewF = - Function::Create(FTy, tobatch->getLinkage(), - "batch_" + tobatch->getName(), tobatch->getParent()); - - if (tobatch->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No batch mode found for " + tobatch->getName() << "\n"; - llvm::Value *toshow = tobatch; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *tobatch << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(tobatch), - wrap(context.ip)); - return NewF; - } - if (context.req) { - EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, - ss.str()); - return NewF; - } - llvm::errs() << "mod: " << *tobatch->getParent() << "\n"; - llvm::errs() << *tobatch << "\n"; - llvm_unreachable("attempting to batch function without definition"); - } - - NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - - ValueToValueMapTy originalToNewFn; - - // Create placeholder for the old arguments - BasicBlock *placeholderBB = - BasicBlock::Create(NewF->getContext(), "placeholders", NewF); - - IRBuilder<> PlaceholderBuilder(placeholderBB); - PlaceholderBuilder.SetCurrentDebugLocation(DebugLoc()); - ValueToValueMapTy vmap; - auto DestArg = NewF->arg_begin(); - auto SrcArg = tobatch->arg_begin(); - - for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { - Argument *arg = SrcArg; - if (arg_types[i] == BATCH_TYPE::VECTOR) { - auto placeholder = PlaceholderBuilder.CreatePHI( - arg->getType(), 0, "placeholder." 
+ arg->getName()); - vmap[arg] = placeholder; - } else { - vmap[arg] = DestArg; - } - DestArg->setName(arg->getName()); - DestArg++; - SrcArg++; - } - - SmallVector Returns; - CloneFunctionInto(NewF, tobatch, vmap, - CloneFunctionChangeType::LocalChangesOnly, Returns, "", - nullptr); - - NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - - // find instructions to vectorize (going up / overestimation) - SmallPtrSet toVectorize; - SetVector> refinelist; - - for (unsigned i = 0; i < tobatch->getFunctionType()->getNumParams(); i++) { - if (arg_types[i] == BATCH_TYPE::VECTOR) { - Argument *arg = tobatch->arg_begin() + i; - toVectorize.insert(arg); - } - } - - for (auto &BB : *tobatch) - for (auto &Inst : BB) { - toVectorize.insert(&Inst); - refinelist.insert(&Inst); - } - - // find scalar instructions - while (!refinelist.empty()) { - Value *todo = *refinelist.begin(); - refinelist.erase(refinelist.begin()); - - if (isa(todo) && ret_type == BATCH_TYPE::VECTOR) - continue; - - if (auto branch_inst = dyn_cast(todo)) { - if (!branch_inst->isConditional()) { - toVectorize.erase(todo); - continue; - } - } - - if (auto call_inst = dyn_cast(todo)) { - if (call_inst->getFunctionType()->isVoidTy() && - call_inst->getFunctionType()->getNumParams() == 0) - toVectorize.erase(todo); - continue; - } - - if (auto todo_inst = dyn_cast(todo)) { - - if (todo_inst->mayReadOrWriteMemory()) - continue; - - if (isa(todo_inst)) - continue; - - SetVector> toCheck; - toCheck.insert(todo_inst->op_begin(), todo_inst->op_end()); - SmallPtrSet safe; - bool legal = true; - while (!toCheck.empty()) { - Value *cur = *toCheck.begin(); - toCheck.erase(toCheck.begin()); - - if (!std::get<1>(safe.insert(cur))) - continue; - - if (toVectorize.count(cur) == 0) - continue; - - if (Instruction *cur_inst = dyn_cast(cur)) { - if (!isa(cur_inst) && !cur_inst->mayReadOrWriteMemory()) { - for (auto &op : cur_inst->operands()) - toCheck.insert(op); - continue; - } - } - - legal = false; - break; - } - - if (legal) - if (toVectorize.erase(todo)) - for (auto user : todo_inst->users()) - refinelist.insert(user); - } - } - - // unwrap arguments - ValueMap> vectorizedValues; - auto entry = std::next(NewF->begin()); - IRBuilder<> Builder2(entry->getFirstNonPHI()); - Builder2.SetCurrentDebugLocation(DebugLoc()); - for (unsigned i = 0; i < FTy->getNumParams(); ++i) { - Argument *orig_arg = tobatch->arg_begin() + i; - Argument *arg = NewF->arg_begin() + i; - - if (arg_types[i] == BATCH_TYPE::SCALAR) { - originalToNewFn[tobatch->arg_begin() + i] = arg; - continue; - } - - Instruction *placeholder = cast(vmap[orig_arg]); - - for (unsigned j = 0; j < width; ++j) { - ExtractValueInst *argVecElem = - cast(Builder2.CreateExtractValue( - arg, {j}, - "unwrap" + (orig_arg->hasName() - ? "." 
+ orig_arg->getName() + Twine(j) - : ""))); - if (j == 0) { - placeholder->replaceAllUsesWith(argVecElem); - placeholder->eraseFromParent(); - } - vectorizedValues[orig_arg].push_back(argVecElem); - } - } - - placeholderBB->eraseFromParent(); - - // update mapping with cloned basic blocks - for (auto i = tobatch->begin(), j = NewF->begin(); - i != tobatch->end() && j != NewF->end(); ++i, ++j) { - originalToNewFn[&*i] = &*j; - } - - // update mapping with cloned scalar values and the first vectorized values - auto J = inst_begin(NewF); - // skip the unwrapped vector params - std::advance(J, width * numVecParams); - for (auto I = inst_begin(tobatch); - I != inst_end(tobatch) && J != inst_end(NewF); ++I) { - if (toVectorize.count(&*I) != 0) { - vectorizedValues[&*I].push_back(&*J); - ++J; - } else { - originalToNewFn[&*I] = &*J; - ++J; - } - } - - // create placeholders for vector instructions 1..isVoidTy()) - continue; - - auto found = vectorizedValues.find(&I); - if (found != vectorizedValues.end()) { - Instruction *new_val_1 = cast(found->second.front()); - if (I.hasName()) - new_val_1->setName(I.getName() + "0"); - Instruction *insertPoint = - new_val_1->getNextNode() ? new_val_1->getNextNode() : new_val_1; - IRBuilder<> Builder2(insertPoint); - Builder2.SetCurrentDebugLocation(DebugLoc()); -#if LLVM_VERSION_MAJOR >= 18 - auto It = Builder2.GetInsertPoint(); - It.setHeadBit(true); - Builder2.SetInsertPoint(It); -#endif - for (unsigned i = 1; i < width; ++i) { - PHINode *placeholder = Builder2.CreatePHI(I.getType(), 0); - vectorizedValues[&I].push_back(placeholder); - if (I.hasName()) - placeholder->setName("placeholder." + I.getName() + Twine(i)); - } - } - } - } - - InstructionBatcher *batcher = - new InstructionBatcher(tobatch, NewF, width, vectorizedValues, - originalToNewFn, toVectorize, *this); - - for (auto val : toVectorize) { - if (auto inst = dyn_cast(val)) { - batcher->visit(inst); - if (batcher->hasError) - break; - } - } - - if (batcher->hasError) { - delete batcher; - NewF->eraseFromParent(); - return BatchCachedFunctions[tup] = nullptr; - } - - if (llvm::verifyFunction(*NewF, &llvm::errs())) { - llvm::errs() << *tobatch << "\n"; - llvm::errs() << *NewF << "\n"; - report_fatal_error("function failed verification (4)"); - } - - delete batcher; - - return BatchCachedFunctions[tup] = NewF; -}; - -llvm::Function * -EnzymeLogic::CreateTrace(RequestContext context, llvm::Function *totrace, - const SmallPtrSetImpl &sampleFunctions, - const SmallPtrSetImpl &observeFunctions, - const StringSet<> &ActiveRandomVariables, - ProbProgMode mode, bool autodiff, - TraceInterface *interface) { - TraceCacheKey tup(totrace, mode, autodiff, interface); - if (TraceCachedFunctions.find(tup) != TraceCachedFunctions.end()) { - return TraceCachedFunctions.find(tup)->second; - } - - // Determine generative functions - SmallPtrSet GenerativeFunctions; - SetVector> workList; - workList.insert(sampleFunctions.begin(), sampleFunctions.end()); - workList.insert(observeFunctions.begin(), observeFunctions.end()); - GenerativeFunctions.insert(sampleFunctions.begin(), sampleFunctions.end()); - GenerativeFunctions.insert(observeFunctions.begin(), observeFunctions.end()); - - while (!workList.empty()) { - auto todo = *workList.begin(); - workList.erase(workList.begin()); - - for (auto &&U : todo->uses()) { - if (auto &&call = dyn_cast(U.getUser())) { - auto &&fun = call->getParent()->getParent(); - auto &&[it, inserted] = GenerativeFunctions.insert(fun); - if (inserted) - workList.insert(fun); - } - } - } - - 
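[Editor's note, sketch only] The worklist loop that just closed computes GenerativeFunctions as the transitive callers of the sample/observe functions. Below is a compact standalone sketch of the same fixed-point pattern using plain strings instead of llvm::Function*; all names and the example call graph are hypothetical.

#include <map>
#include <set>
#include <string>
#include <vector>

// callersOf[f] lists the functions that contain a call to f.
using CallersMap = std::map<std::string, std::vector<std::string>>;

// Start from the seed (sample/observe) functions and repeatedly add their
// callers until a fixed point: everything reached is "generative".
static std::set<std::string>
transitiveCallers(const CallersMap &callersOf, std::set<std::string> generative) {
  std::vector<std::string> worklist(generative.begin(), generative.end());
  while (!worklist.empty()) {
    std::string todo = worklist.back();
    worklist.pop_back();
    auto it = callersOf.find(todo);
    if (it == callersOf.end())
      continue;
    for (const std::string &caller : it->second)
      if (generative.insert(caller).second) // newly seen: keep expanding
        worklist.push_back(caller);
  }
  return generative;
}

int main() {
  CallersMap callers = {{"sample", {"model"}}, {"model", {"run"}}};
  auto gen = transitiveCallers(callers, {"sample"});
  return gen.count("run") ? 0 : 1; // "run" is generative via "model"
}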
ValueToValueMapTy originalToNewFn; - TraceUtils *tutils = - TraceUtils::FromClone(mode, sampleFunctions, observeFunctions, interface, - totrace, originalToNewFn); - TraceGenerator *tracer = - new TraceGenerator(*this, tutils, autodiff, originalToNewFn, - GenerativeFunctions, ActiveRandomVariables); - - if (totrace->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No tracer found for " + totrace->getName() << "\n"; - llvm::Value *toshow = totrace; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *totrace << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(totrace), - wrap(context.ip)); - auto newFunc = tutils->newFunc; - delete tracer; - delete tutils; - return newFunc; - } - if (context.req) { - EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, - ss.str()); - auto newFunc = tutils->newFunc; - delete tracer; - delete tutils; - return newFunc; - } - llvm::errs() << "mod: " << *totrace->getParent() << "\n"; - llvm::errs() << *totrace << "\n"; - llvm_unreachable("attempting to trace function without definition"); - } - - tracer->visit(totrace); - - if (verifyFunction(*tutils->newFunc, &errs())) { - errs() << *totrace << "\n"; - errs() << *tutils->newFunc << "\n"; - report_fatal_error("function failed verification (4)"); - } - - Function *NewF = tutils->newFunc; - - delete tracer; - delete tutils; - - if (!autodiff) { - PPC.AlwaysInline(NewF); - - if (PostOpt) - PPC.optimizeIntermediate(NewF); - if (EnzymePrint) { - errs() << *NewF << "\n"; - } - } - - return TraceCachedFunctions[tup] = NewF; -} - -llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, - llvm::Value *todiff) { - if (isa(todiff)) - return todiff; - else if (auto F = dyn_cast(todiff)) - return CreateNoFree(context, F); - if (auto castinst = dyn_cast(todiff)) - if (castinst->isCast()) { - llvm::Constant *reps[] = { - cast(CreateNoFree(context, castinst->getOperand(0)))}; - return castinst->getWithOperands(reps); - } - - // Alloca/allocations are unsafe here since one could store freeing functions - // into them. For now we will be unsafe regarding indirect function call - // frees. 
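[Editor's note, sketch only] The demangled-name checks below (and several earlier spots) inline the same loop to canonicalize demangled C++ names by rewriting "> >" as ">>" before string comparison. Here is that loop pulled into a tiny standalone helper; the helper name is an assumption for illustration, the replacement logic is copied from the patch.

#include <cassert>
#include <string>

// Canonicalize a demangled name: older demanglers print nested template
// closers as "> >", newer ones as ">>"; normalize so comparisons match.
static std::string normalizeDemangled(std::string name) {
  size_t start = 0;
  while ((start = name.find("> >", start)) != std::string::npos)
    name.replace(start, 3, ">>");
  return name;
}

int main() {
  assert(normalizeDemangled("std::vector<std::vector<int> >") ==
         "std::vector<std::vector<int>>");
  return 0;
}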
- if (isa(todiff)) - return todiff; - - std::string demangledCall; - - { - Value *mdiff = todiff; - while (auto LI = dyn_cast(mdiff)) { - mdiff = LI->getPointerOperand(); - } - - if (auto CI = dyn_cast(todiff)) { - if (auto F = CI->getCalledFunction()) { - - // clang-format off - const char* NoFreeDemanglesStartsWith[] = { - "std::__u::locale::use_facet(std::__u::locale::id&) const", - }; - // clang-format on - - demangledCall = llvm::demangle(F->getName().str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledCall.find("> >", start)) != - std::string::npos) { - demangledCall.replace(start, 3, ">>"); - } - - for (auto Name : NoFreeDemanglesStartsWith) - if (startsWith(demangledCall, Name)) - return CI; - } - } - } - - // clang-format off - const char* NoFreeDemanglesStartsWith[] = { - "std::basic_ostream>& std::__ostream_insert>", - "std::basic_ostream>::operator<<", - "std::ostream::operator<<", - "std::ostream& std::ostream::_M_insert", - "std::basic_ostream>& std::__ostream_insert", - }; - // clang-format on - - if (auto CI = dyn_cast(todiff)) { - TargetLibraryInfo &TLI = - PPC.FAM.getResult(*CI->getParent()->getParent()); - if (isAllocationFunction(getFuncNameFromCall(CI), TLI)) - return CI; - if (auto F = CI->getCalledFunction()) { - - demangledCall = llvm::demangle(F->getName().str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledCall.find("> >", start)) != std::string::npos) { - demangledCall.replace(start, 3, ">>"); - } - - for (auto Name : NoFreeDemanglesStartsWith) - if (startsWith(demangledCall, Name)) - return CI; - } - } - if (auto PN = dyn_cast(todiff)) { - Value *illegal = nullptr; - for (auto &op : PN->incoming_values()) { - - if (auto CI = dyn_cast(op)) { - TargetLibraryInfo &TLI = PPC.FAM.getResult( - *CI->getParent()->getParent()); - if (isAllocationFunction(getFuncNameFromCall(CI), TLI)) - continue; - if (auto F = CI->getCalledFunction()) { - - demangledCall = llvm::demangle(F->getName().str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledCall.find("> >", start)) != - std::string::npos) { - demangledCall.replace(start, 3, ">>"); - } - - bool legal = false; - for (auto Name : NoFreeDemanglesStartsWith) - if (startsWith(demangledCall, Name)) { - legal = true; - break; - } - if (!legal) { - illegal = op; - break; - } - } - continue; - } - demangledCall = ""; - illegal = op; - break; - } - if (!illegal) - return PN; - } - - if (auto GV = dyn_cast(todiff)) { - if (GV->getName() == "_ZSt4cerr") - return GV; - if (GV->getName() == "_ZSt4cout") - return GV; - if (GV->getName() == "_ZNSt3__u5wcoutE") - return GV; - } - - if (context.ip) { - if (auto LI = dyn_cast(todiff)) { - if (auto smpl = simplifyLoad(LI)) - return CreateNoFree(context, smpl); - auto prev = CreateNoFree(context, LI->getPointerOperand()); - if (prev == LI->getPointerOperand()) - return todiff; - auto res = cast(context.ip->CreateLoad(LI->getType(), prev)); - res->copyMetadata(*LI); - return res; - } - if (auto CI = dyn_cast(todiff)) { - auto prev = CreateNoFree(context, CI->getOperand(0)); - if (prev == CI->getOperand(0)) - return todiff; - auto res = cast( - context.ip->CreateCast(CI->getOpcode(), prev, CI->getType())); - res->copyMetadata(*CI); - return res; - } - if (auto gep = dyn_cast(todiff)) { - if (gep->hasAllConstantIndices() || gep->isInBounds()) { - auto prev = CreateNoFree(context, gep->getPointerOperand()); - if (prev == gep->getPointerOperand()) - return todiff; - SmallVector idxs; - for (auto 
&ind : gep->indices()) - idxs.push_back(ind); - auto res = cast( - context.ip->CreateGEP(gep->getSourceElementType(), prev, idxs)); - res->setIsInBounds(gep->isInBounds()); - res->copyMetadata(*gep); - return res; - } - } - } - - if (EnzymeAssumeUnknownNoFree) { - return todiff; - } - - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of unknown value\n"; - ss << *todiff << "\n"; - if (auto PN = dyn_cast(todiff)) { - for (auto &op : PN->incoming_values()) { - ss << " - " << *op << "\n"; - } - } - if (demangledCall.size()) { - ss << " demangled (" << demangledCall << ")\n"; - } - if (context.req) { - ss << " at context: " << *context.req; - } - if (auto I = dyn_cast(todiff)) { - auto fname = I->getParent()->getParent()->getName(); - if (startsWith(fname, "nofree_")) - fname = fname.substr(7); - std::string demangledName = llvm::demangle(fname.str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledName.find("> >", start)) != std::string::npos) { - demangledName.replace(start, 3, ">>"); - } - ss << " within func " << fname << " (" << demangledName << ")\n"; - } - if (EmitNoDerivativeError(ss.str(), todiff, context)) { - return todiff; - } - - llvm::errs() << s; - llvm_unreachable("unhandled, create no free"); -} - -llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { - if (NoFreeCachedFunctions.find(F) != NoFreeCachedFunctions.end()) { - return NoFreeCachedFunctions.find(F)->second; - } - bool hasNoFree = false; - hasNoFree |= F->hasFnAttribute(Attribute::NoFree); - if (hasNoFree) - return F; - - TargetLibraryInfo &TLI = PPC.FAM.getResult(*F); - - if (isAllocationFunction(F->getName(), TLI)) - return F; - - // clang-format off - StringSet<> NoFreeDemangles = { - "std::__u::basic_istream>::~basic_istream()", - "std::__u::basic_filebuf>::~basic_filebuf()", - "std::__u::basic_ostream>::~basic_ostream()", - "std::__u::basic_streambuf>::pubsync()", - "std::__u::basic_ostream>::write(char const*, long)", - "std::__u::basic_filebuf>::close()", - "std::__u::basic_ios>::imbue(std::__u::locale const&)", - "std::__u::basic_filebuf>::basic_filebuf()", - "std::__u::basic_filebuf>::open(char const*, unsigned int)", - "std::__u::basic_streambuf>::basic_streambuf()", - "std::__u::basic_string, std::__u::allocator>::~basic_string()", - "std::__u::basic_stringstream, std::__u::allocator>::~basic_stringstream()", - "std::__u::basic_streambuf>::~basic_streambuf()", - "std::__u::basic_iostream>::~basic_iostream()", - "std::__u::basic_ios>::~basic_ios()", - "std::__u::ios_base::init(void*)", - "std::__u::basic_ostream>::put(wchar_t)", - "std::__u::basic_ostream>::put(char)", - "std::__u::basic_ostream>& std::__u::__put_character_sequence>(std::__u::basic_ostream>&, char const*, unsigned long)", - "std::__u::basic_ostream>& std::__u::operator<<>(std::__u::basic_ostream>&, char const*)", - "std::__u::basic_ostream>& std::__u::operator<<>(std::__u::basic_ostream>&, char)", - "std::__u::basic_ostream>::sentry::sentry(std::__u::basic_ostream>&)", - "std::__u::basic_ostream>::flush()", - "std::__u::basic_ostream>::sentry::sentry(std::__u::basic_ostream>&)", - - "std::__u::locale::~locale()", - "std::__u::locale::operator=(std::__u::locale const&)", - "std::__u::locale::locale(std::__u::locale const&)", - "std::__u::locale::locale()", - "std::__u::locale::global(std::__u::locale const&)", - "std::__u::locale::locale(char const*)", - "std::__u::ios_base::imbue(std::__u::locale const&)", - 
"std::__u::locale::use_facet(std::__u::locale::id&) const", - "std::__u::ios_base::getloc() const", - "std::__u::ios_base::clear(unsigned int)", - - "std::basic_ostream>::basic_ostream(std::basic_streambuf>*)", - "std::basic_ostream>::flush()", - "std::basic_ostream>& std::__ostream_insert >(std::basic_ostream >&)", - "std::basic_ostream>::put(char)", - "std::basic_ostream>::~basic_ostream()", - - "std::basic_filebuf>::basic_filebuf()", - "std::basic_filebuf>::open(char const*, std::_Ios_Openmode)", - "std::basic_filebuf>::close()", - "std::basic_filebuf>::~basic_filebuf()", - - "std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const", - - "std::basic_streambuf >::xsputn(char const*, long)", - - "std::__cxx11::basic_ostringstream, std::allocator>::basic_ostringstream()", - "std::__cxx11::basic_ostringstream, std::allocator>::str() const", - "std::__cxx11::basic_ostringstream, std::allocator>::~basic_ostringstream()", - - "std::basic_ios >::init(std::basic_streambuf >*)", - "std::basic_ios>::clear(std::_Ios_Iostate)", - "std::basic_ios>::operator bool() const", - "std::basic_ios>::operator!() const", - "std::basic_ios>::imbue(std::locale const&)", - - "std::_Hash_bytes(void const*, unsigned long, unsigned long)", - "unsigned long std::__1::__do_string_hash(char const*, char const*)", - "std::__1::hash::operator()(char const*) const", - - "std::allocator::allocator()", - "std::allocator::~allocator()", - - "std::basic_ifstream>::is_open()", - - "std::basic_ofstream>::basic_ofstream(char const*, std::_Ios_Openmode)", - "std::basic_ofstream>::is_open()", - "std::basic_ofstream>::close()", - "std::basic_ofstream>::~basic_ofstream()", - - "std::__cxx11::basic_stringstream, std::allocator>::basic_stringstream(std::__cxx11::basic_string, std::allocator> const&, std::_Ios_Openmode)", - "std::__cxx11::basic_stringstream, std::allocator>::~basic_stringstream()", - "std::basic_ostream>::put(wchar_t)", - - "std::__cxx11::basic_string, std::allocator>::basic_string(char const*, std::allocator const&)", - "std::__cxx11::basic_string, std::allocator>::basic_string(std::__cxx11::basic_string, std::allocator>&&)", - "std::__cxx11::basic_string, std::allocator>::_M_construct(unsigned long, char)", - "std::__cxx11::basic_string, std::allocator>::_M_append(char const*, unsigned long)", - "std::__cxx11::basic_string, std::allocator>::_M_assign(std::__cxx11::basic_string, std::allocator> const&)", - "std::__cxx11::basic_string, std::allocator>::_M_replace(unsigned long, unsigned long, char const*, unsigned long)", - "std::__cxx11::basic_string, std::allocator>::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)", - "std::__cxx11::basic_string, std::allocator>::length() const", - "std::__cxx11::basic_string, std::allocator>::data() const", - "std::__cxx11::basic_string, std::allocator>::size() const", - "std::__cxx11::basic_string, std::allocator>::c_str() const", - "std::__cxx11::basic_string, std::allocator>::~basic_string()", - "std::__cxx11::basic_string, std::allocator>::compare(char const*) const", - "std::__cxx11::basic_string, std::allocator>::compare(std::__cxx11::basic_string, std::allocator> const&) const", - "std::__cxx11::basic_string, std::allocator>::reserve(unsigned long)", - - "std::__cxx11::basic_string, std::allocator>::~basic_string()", - "std::__cxx11::basic_stringbuf, std::allocator>::overflow(int)", - "std::__cxx11::basic_stringbuf, std::allocator>::pbackfail(int)", - "std::__cxx11::basic_stringbuf, 
std::allocator>::underflow()", - "std::__cxx11::basic_stringbuf, std::allocator>::_M_sync(char*, unsigned long, unsigned long)", - "std::__cxx11::basic_stringbuf, std::allocator>::basic_stringbuf(std::__cxx11::basic_string, std::allocator> const&, std::_Ios_Openmode)", - - "std::basic_streambuf>::pubsync()", - "std::basic_ifstream>::close()", - "std::istream::ignore()", - "std::basic_ifstream>::basic_ifstream()", - "std::basic_ifstream>::basic_ifstream(char const*, std::_Ios_Openmode)", - "std::basic_ifstream>::~basic_ifstream()", - "std::basic_ifstream>::rdbuf() const", - "std::__basic_file::is_open() const", - "std::__basic_file::~__basic_file()", - - "std::ostream::flush()", - "std::basic_streambuf>::xsgetn(char*, long)", - - "std::locale::locale(char const*)", - "std::locale::global(std::locale const&)", - "std::locale::~locale()", - "std::ios_base::ios_base()", - "std::ios_base::~ios_base()", - - // libc++ - "std::__1::basic_string, std::__1::allocator>::basic_string(std::__1::basic_string, std::__1::allocator> const&)", - "std::__1::basic_string, std::__1::allocator>::~basic_string()", - "std::__1::basic_string, std::__1::allocator>::__init(char const*, unsigned long)", - "std::__1::basic_string, std::__1::allocator>::append(char const*, unsigned long)", - "std::__1::basic_string, std::__1::allocator>::data() const", - "std::__1::basic_ostream>::sentry::sentry(std::__1::basic_ostream>&)", - "std::__1::basic_ostream>::sentry::~sentry()", - "std::__1::basic_ostream>::flush()", - "std::__1::ios_base::__set_badbit_and_consider_rethrow()", - "char* std::__1::addressof(char&)", - "char const* std::__1::addressof(char const&)", - "std::__1::random_device::operator()()", - - "std::__1::locale::~locale()", - "std::__1::locale::use_facet(std::__1::locale::id&) const", - "std::__1::ios_base::ios_base()", - "std::__1::ios_base::getloc() const", - "std::__1::ios_base::clear(unsigned int)", - "std::__1::basic_iostream>::~basic_iostream()", - "std::__1::basic_ios>::~basic_ios()", - "std::__1::basic_streambuf>::basic_streambuf()", - "std::__1::basic_streambuf>::~basic_streambuf()", - "std::__1::basic_streambuf>::imbue(std::__1::locale const&)", - "std::__1::basic_streambuf>::setbuf(char*, long)", - "std::__1::basic_streambuf>::sync()", - "std::__1::basic_streambuf>::showmanyc()", - "std::__1::basic_streambuf>::xsgetn(char*, long)", - "std::__1::basic_streambuf>::uflow()", - "std::__1::basic_filebuf>::basic_filebuf()", - "std::__1::basic_filebuf>::~basic_filebuf()", - "std::__1::basic_filebuf>::open(char const*, unsigned int)", - "std::__1::basic_filebuf>::close()", - "std::__1::basic_filebuf>::sync()", - "std::__1::basic_istream>::~basic_istream()", - "virtual thunk to std::__1::basic_istream>::~basic_istream()", - "virtual thunk to std::__1::basic_ostream>::~basic_ostream()", - "std::__1::basic_ifstream>::~basic_ifstream()", - "std::__1::ios_base::init(void*)", - "std::__1::basic_istream>::read(char*, long)", - "std::__1::basic_ostream>::~basic_ostream()", - "std::__1::basic_string, std::__1::allocator>::__init(unsigned long, char)", - "std::__1::basic_ostream>::write(char const*, long)", - }; - const char* NoFreeDemanglesStartsWith[] = { - "std::__u::basic_streambuf>::sputn", - "std::__u::basic_streambuf>::pubsetbuf", - "std::__u::basic_istream>::read", - "std::__u::basic_string, std::__u::allocator>::resize", - "std::__u::basic_string, std::__u::allocator>& std::__u::basic_string, std::__u::allocator>::__assign_no_alias", - "std::__u::basic_string, std::__u::allocator>::__init", - 
"std::__u::basic_stringbuf, std::__u::allocator>::str", - "std::__u::basic_istream>::operator>>", - "std::__u::basic_istream>::ignore", - "std::__u::basic_istream>::get", - "std::__u::basic_ostream>::operator<<", - "std::__u::basic_ostream>::operator<<", - "std::__u::basic_ostream>& std::__u::operator<<", - "std::__1::basic_ostream>::operator<<", - "std::__1::ios_base::imbue", - "std::__1::basic_streambuf>::pubimbue", - "std::__1::basic_stringbuf, std::__1::allocator>::__init_buf_ptrs", - "std::__1::basic_stringbuf, std::__1::allocator>::basic_stringbuf", - "std::__1::basic_string, std::__1::allocator>::operator=", - "std::__1::ctype::widen", - "std::__1::basic_streambuf>::sputn", - "std::basic_ostream>& std::flush", - "std::basic_ostream>& std::operator<<", - "std::basic_ostream>& std::basic_ostream>::_M_insert", - "std::basic_ostream>& std::__ostream_insert>", - "std::basic_ostream>& std::operator<<", - "std::basic_ostream>::operator<<", - "std::basic_ostream>& std::basic_ostream>::_M_insert", - "std::istream::get", - "std::ostream::put", - "std::ostream::write", - "std::ostream& std::ostream::_M_insert", - "std::istream::read", - "std::istream::operator>>", - "std::basic_streambuf>::pubsetbuf", - "std::basic_streambuf>::sputn", - "std::istream& std::istream::_M_extract", - "std::ctype::widen", - //Rust - "std::io::stdio::_eprint", - }; - - StringSet<> NoFrees = {"mpfr_greater_p", - "vprintf", - "fprintf", - "fputc", - "memchr", - "time", - "strlen", - "__cxa_begin_catch", - "__cxa_guard_acquire", - "__cxa_guard_release", - "__cxa_end_catch", - "compress2", - "malloc_usable_size", - "MPI_Allreduce", - "lgamma", - "lgamma_r", - "__assertfail", - "__kmpc_global_thread_num", - "nlopt_force_stop", - "cudaRuntimeGetVersion" - }; - // clang-format on - - if (startsWith(F->getName(), "_ZNSolsE") || NoFrees.count(F->getName())) - return F; - - std::string demangledName = llvm::demangle(F->getName().str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledName.find("> >", start)) != std::string::npos) { - demangledName.replace(start, 3, ">>"); - } - if (NoFreeDemangles.count(demangledName)) - return F; - - for (auto Name : NoFreeDemanglesStartsWith) - if (startsWith(demangledName, Name)) - return F; - - switch (F->getIntrinsicID()) { - case Intrinsic::lifetime_start: - case Intrinsic::lifetime_end: - case Intrinsic::memcpy: - case Intrinsic::memmove: - case Intrinsic::memset: - case Intrinsic::cttz: - case Intrinsic::ctlz: - return F; - default:; - } - - { - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if (isMemFreeLibMFunction(getFuncName(F), &ID)) - return F; - } - - if (F->empty()) { - if (EnzymeAssumeUnknownNoFree) { - return F; - } - if (EnzymeEmptyFnInactive) { - return F; - } - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of empty function (" << demangledName << ") " - << F->getName() << ")\n"; - if (context.req) { - ss << " at context: " << *context.req; - if (auto CB = dyn_cast(context.req)) { - if (auto F = CB->getCalledFunction()) { - std::string demangleF = llvm::demangle(F->getName().str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangleF.find("> >", start)) != std::string::npos) { - demangleF.replace(start, 3, ">>"); - } - ss << " (" << demangleF << ")"; - } - } - } else { - ss << *F << "\n"; - } - if (EmitNoDerivativeError(ss.str(), F, context)) { - return F; - } - llvm::errs() << " unhandled, create no free of empty function: " << *F - << "\n"; - llvm_unreachable("unhandled, create no 
free"); - } - - Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), - "nofree_" + F->getName(), F->getParent()); - NewF->setAttributes(F->getAttributes()); - NewF->addAttribute(AttributeList::FunctionIndex, - Attribute::get(NewF->getContext(), Attribute::NoFree)); - - NoFreeCachedFunctions[F] = NewF; - - ValueToValueMapTy VMap; - - for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { - VMap[i] = j; - j->setName(i->getName()); - ++j; - ++i; - } - - SmallVector Returns; - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - - NewF->setVisibility(llvm::GlobalValue::DefaultVisibility); - NewF->setLinkage(llvm::GlobalValue::InternalLinkage); - - const SmallPtrSet guaranteedUnreachable = - getGuaranteedUnreachable(NewF); - - SmallVector toErase; - for (BasicBlock &BB : *NewF) { - if (guaranteedUnreachable.count(&BB)) - continue; - for (Instruction &I : BB) { - StringRef funcName = ""; - if (auto CI = dyn_cast(&I)) { - if (CI->hasFnAttr(Attribute::NoFree)) - continue; - funcName = getFuncNameFromCall(CI); - } - if (auto CI = dyn_cast(&I)) { - if (CI->hasFnAttr(Attribute::NoFree)) - continue; - funcName = getFuncNameFromCall(CI); - } - if (isDeallocationFunction(funcName, TLI)) - toErase.push_back(&I); - else { - if (auto CI = dyn_cast(&I)) { - auto callval = CI->getCalledOperand(); - CI->setCalledOperand(CreateNoFree(context, callval)); - } - if (auto CI = dyn_cast(&I)) { - auto callval = CI->getCalledOperand(); - CI->setCalledOperand(CreateNoFree(context, callval)); - } - } - } - } - NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - - if (llvm::verifyFunction(*NewF, &llvm::errs())) { - llvm::errs() << *F << "\n"; - llvm::errs() << *NewF << "\n"; - report_fatal_error("function failed verification (4)"); - } - - for (auto E : toErase) { - E->eraseFromParent(); - } - - return NewF; -} - -void EnzymeLogic::clear() { - PPC.clear(); - AugmentedCachedFunctions.clear(); - ReverseCachedFunctions.clear(); - NoFreeCachedFunctions.clear(); - ForwardCachedFunctions.clear(); - BatchCachedFunctions.clear(); -} diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h deleted file mode 100644 index f9610ba49661..000000000000 --- a/enzyme/Enzyme/EnzymeLogic.h +++ /dev/null @@ -1,781 +0,0 @@ -//===- EnzymeLogic.h - Implementation of forward and reverse pass generation==// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares two functions CreatePrimalAndGradient and -// CreateAugmentedPrimal. CreatePrimalAndGradient takes a function, known -// TypeResults of the calling context, known activity analysis of the -// arguments and a bool `topLevel`. It creates a corresponding gradient -// function, computing the forward pass as well if at `topLevel`. 
-// CreateAugmentedPrimal takes similar arguments and creates an augmented -// forward pass. -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_LOGIC_H -#define ENZYME_LOGIC_H - -#include -#include -#include - -#include "llvm/IR/Function.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/CommandLine.h" - -#include "llvm/Analysis/AliasAnalysis.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/ErrorHandling.h" - -#include "ActivityAnalysis.h" -#include "FunctionUtils.h" -#include "TraceUtils.h" -#include "TypeAnalysis/TypeAnalysis.h" -#include "Utils.h" - -extern "C" { -extern llvm::cl::opt EnzymePrint; -} - -constexpr char EnzymeFPRTPrefix[] = "__enzyme_fprt_"; -constexpr char EnzymeFPRTOriginalPrefix[] = "__enzyme_fprt_original_"; - -enum class AugmentedStruct { Tape, Return, DifferentialReturn }; - -static inline std::string str(AugmentedStruct c) { - switch (c) { - case AugmentedStruct::Tape: - return "tape"; - case AugmentedStruct::Return: - return "return"; - case AugmentedStruct::DifferentialReturn: - return "DifferentialReturn"; - default: - llvm_unreachable("unknown cache type"); - } -} - -static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &o, - AugmentedStruct c) { - return o << str(c); -} - -enum class CacheType { Self, Shadow, Tape }; - -static inline std::string str(CacheType c) { - switch (c) { - case CacheType::Self: - return "self"; - case CacheType::Shadow: - return "shadow"; - case CacheType::Tape: - return "tape"; - default: - llvm_unreachable("unknown cache type"); - } -} - -static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &o, CacheType c) { - return o << str(c); -} - -//! return structtype if recursive function -class AugmentedReturn { -public: - llvm::Function *fn; - //! return structtype if recursive function - llvm::Type *tapeType; - - std::map, int> tapeIndices; - - //! Map from original call to sub augmentation data - std::map subaugmentations; - - //! Map from information desired from a augmented return to its index in the - //! returned struct - std::map returns; - - std::map>> - overwritten_args_map; - - std::map can_modref_map; - - std::set tapeIndiciesToFree; - - const std::vector constant_args; - - bool shadowReturnUsed; - - bool isComplete; - - AugmentedReturn( - llvm::Function *fn, llvm::Type *tapeType, - std::map, int> tapeIndices, - std::map returns, - std::map>> - overwritten_args_map, - std::map can_modref_map, - const std::vector &constant_args, bool shadowReturnUsed) - : fn(fn), tapeType(tapeType), tapeIndices(tapeIndices), returns(returns), - overwritten_args_map(overwritten_args_map), - can_modref_map(can_modref_map), constant_args(constant_args), - shadowReturnUsed(shadowReturnUsed), isComplete(false) {} -}; - -/// \p todiff is the function to differentiate -/// \p retType is the activity info of the return. -/// Only allowed to be DUP_ARG or CONSTANT. DUP_NONEED is not allowed, -/// set returnValue to false instead. -/// \p constant_args is the activity info of the arguments -/// \p subsequent_calls_may_write denotes whether some followup call may -/// write to accessible memory (and thus can potentially overwrite a load -/// made in this function). -/// \p overwritten_args marks whether an argument may be overwritten -/// before loads in the generated function (and thus cannot be cached). -/// \p returnValue is whether the primal's return should also be returned. 
-/// \p dretUsed is whether the shadow return value should also be returned. -/// Only allowed to be true if retType is CDIFFE_TYPE::DUP_ARG. -/// \p additionalArg is the type (or null) of an additional type in the -/// signature to hold the tape. -/// \p typeInfo is the type info information about the calling context -/// \p AtomicAdd is whether to perform all adjoint -/// updates to memory in an atomic way -struct ReverseCacheKey { - llvm::Function *todiff; - DIFFE_TYPE retType; - const std::vector constant_args; - bool subsequent_calls_may_write; - std::vector overwritten_args; - bool returnUsed; - bool shadowReturnUsed; - DerivativeMode mode; - unsigned width; - bool freeMemory; - bool AtomicAdd; - llvm::Type *additionalType; - bool forceAnonymousTape; - const FnTypeInfo typeInfo; - bool runtimeActivity; - - ReverseCacheKey replaceTypeInfo(const FnTypeInfo &newTypeInfo) const { - return {todiff, - retType, - constant_args, - subsequent_calls_may_write, - overwritten_args, - returnUsed, - shadowReturnUsed, - mode, - width, - freeMemory, - AtomicAdd, - additionalType, - forceAnonymousTape, - newTypeInfo, - runtimeActivity}; - } - /* - inline bool operator==(const ReverseCacheKey& rhs) const { - return todiff == rhs.todiff && - retType == rhs.retType && - constant_args == rhs.constant_args && - overwritten_args == rhs.overwritten_args && - returnUsed == rhs.returnUsed && - shadowReturnUsed == rhs.shadowReturnUsed && - mode == rhs.mode && - freeMemory == rhs.freeMemory && - AtomicAdd == rhs.AtomicAdd && - additionalType == rhs.additionalType && - typeInfo == rhs.typeInfo; - } - */ - - inline bool operator<(const ReverseCacheKey &rhs) const { - if (todiff < rhs.todiff) - return true; - if (rhs.todiff < todiff) - return false; - - if (retType < rhs.retType) - return true; - if (rhs.retType < retType) - return false; - - if (std::lexicographical_compare(constant_args.begin(), constant_args.end(), - rhs.constant_args.begin(), - rhs.constant_args.end())) - return true; - if (std::lexicographical_compare( - rhs.constant_args.begin(), rhs.constant_args.end(), - constant_args.begin(), constant_args.end())) - return false; - - if (subsequent_calls_may_write < rhs.subsequent_calls_may_write) - return true; - if (rhs.subsequent_calls_may_write < subsequent_calls_may_write) - return false; - - if (std::lexicographical_compare( - overwritten_args.begin(), overwritten_args.end(), - rhs.overwritten_args.begin(), rhs.overwritten_args.end())) - return true; - if (std::lexicographical_compare( - rhs.overwritten_args.begin(), rhs.overwritten_args.end(), - overwritten_args.begin(), overwritten_args.end())) - return false; - - if (returnUsed < rhs.returnUsed) - return true; - if (rhs.returnUsed < returnUsed) - return false; - - if (shadowReturnUsed < rhs.shadowReturnUsed) - return true; - if (rhs.shadowReturnUsed < shadowReturnUsed) - return false; - - if (mode < rhs.mode) - return true; - if (rhs.mode < mode) - return false; - - if (width < rhs.width) - return true; - if (rhs.width < width) - return false; - - if (freeMemory < rhs.freeMemory) - return true; - if (rhs.freeMemory < freeMemory) - return false; - - if (AtomicAdd < rhs.AtomicAdd) - return true; - if (rhs.AtomicAdd < AtomicAdd) - return false; - - if (additionalType < rhs.additionalType) - return true; - if (rhs.additionalType < additionalType) - return false; - - if (forceAnonymousTape < rhs.forceAnonymousTape) - return true; - if (rhs.forceAnonymousTape < forceAnonymousTape) - return false; - - if (typeInfo < rhs.typeInfo) - return true; - if 
(rhs.typeInfo < typeInfo) - return false; - - if (runtimeActivity < rhs.runtimeActivity) - return true; - if (rhs.runtimeActivity < runtimeActivity) - return false; - - // equal - return false; - } -}; - -// Holder class to represent a context in which a derivative -// or batch is being requested. This contains the instruction -// (or null) that led to the request, and a builder (or null) -// of the insertion point for code. -struct RequestContext { - llvm::Instruction *req; - llvm::IRBuilder<> *ip; - RequestContext(llvm::Instruction *req = nullptr, - llvm::IRBuilder<> *ip = nullptr) - : req(req), ip(ip) {} -}; - -[[maybe_unused]] static llvm::Type * -getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { - switch (width) { - default: - if (builtinFloat) - llvm::report_fatal_error("Invalid float width requested"); - else - llvm::report_fatal_error( - "Truncation to non builtin float width unsupported"); - case 64: - return llvm::Type::getDoubleTy(ctx); - case 32: - return llvm::Type::getFloatTy(ctx); - case 16: - return llvm::Type::getHalfTy(ctx); - } -} - -enum TruncateMode { - TruncMemMode = 0b0001, - TruncOpMode = 0b0010, - TruncOpFullModuleMode = 0b0110, -}; -[[maybe_unused]] static const char *truncateModeStr(TruncateMode mode) { - switch (mode) { - case TruncMemMode: - return "mem"; - case TruncOpMode: - return "op"; - case TruncOpFullModuleMode: - return "op_full_module"; - } - llvm_unreachable("Invalid truncation mode"); -} - -struct FloatRepresentation { - // |_|__________|_________________| - // ^ ^ ^ - // sign bit exponent significand - // - // value = (sign) * significand * 2 ^ exponent - unsigned exponentWidth; - unsigned significandWidth; - - FloatRepresentation(unsigned e, unsigned s) - : exponentWidth(e), significandWidth(s) {} - - unsigned getTypeWidth() const { return 1 + exponentWidth + significandWidth; } - - bool canBeBuiltin() const { - unsigned w = getTypeWidth(); - return (w == 16 && significandWidth == 10) || - (w == 32 && significandWidth == 23) || - (w == 64 && significandWidth == 52); - } - - llvm::Type *getBuiltinType(llvm::LLVMContext &ctx) const { - if (!canBeBuiltin()) - return nullptr; - return getTypeForWidth(ctx, getTypeWidth(), /*builtinFloat=*/true); - } - - llvm::Type *getType(llvm::LLVMContext &ctx) const { - llvm::Type *builtinType = getBuiltinType(ctx); - if (builtinType) - return builtinType; - llvm_unreachable("TODO MPFR"); - } - - bool operator==(const FloatRepresentation &other) const { - return other.exponentWidth == exponentWidth && - other.significandWidth == significandWidth; - } - bool operator<(const FloatRepresentation &other) const { - return std::tuple(exponentWidth, significandWidth) < - std::tuple(other.exponentWidth, other.significandWidth); - } - std::string to_string() const { - return std::to_string(getTypeWidth()) + "_" + - std::to_string(significandWidth); - } -}; - -struct FloatTruncation { -private: - FloatRepresentation from, to; - TruncateMode mode; - -public: - FloatTruncation(FloatRepresentation From, FloatRepresentation To, - TruncateMode mode) - : from(From), to(To), mode(mode) { - if (!From.canBeBuiltin()) - llvm::report_fatal_error("Float truncation `from` type is not builtin."); - if (From.exponentWidth < To.exponentWidth && - (mode == TruncOpMode || mode == TruncOpFullModuleMode)) - llvm::report_fatal_error("Float truncation `from` type must have " - "a wider exponent than `to`."); - if (From.significandWidth < To.significandWidth && - (mode == TruncOpMode || mode == TruncOpFullModuleMode)) - 
llvm::report_fatal_error("Float truncation `from` type must have " - "a wider significand than `to`."); - if (From == To) - llvm::report_fatal_error( - "Float truncation `from` and `to` type must not be the same."); - } - TruncateMode getMode() { return mode; } - FloatRepresentation getTo() { return to; } - unsigned getFromTypeWidth() { return from.getTypeWidth(); } - unsigned getToTypeWidth() { return to.getTypeWidth(); } - llvm::Type *getFromType(llvm::LLVMContext &ctx) { - return from.getBuiltinType(ctx); - } - bool isToFPRT() { - // TODO maybe add new mode in which we directly truncate to native fp ops, - // for now everything goes through the runtime - return true; - } - llvm::Type *getToType(llvm::LLVMContext &ctx) { return getFromType(ctx); } - auto getTuple() const { return std::tuple(from, to, mode); } - bool operator==(const FloatTruncation &other) const { - return getTuple() == other.getTuple(); - } - bool operator<(const FloatTruncation &other) const { - return getTuple() < other.getTuple(); - } - std::string mangleTruncation() const { - return from.to_string() + "to" + to.to_string(); - } - std::string mangleFrom() const { return from.to_string(); } -}; - -class EnzymeLogic { -public: - PreProcessCache PPC; - - /// \p PostOpt is whether to perform basic - /// optimization of the function after synthesis - bool PostOpt; - - EnzymeLogic(bool PostOpt) : PostOpt(PostOpt) {} - - struct AugmentedCacheKey { - llvm::Function *fn; - DIFFE_TYPE retType; - const std::vector constant_args; - bool subsequent_calls_may_write; - std::vector overwritten_args; - bool returnUsed; - bool shadowReturnUsed; - const FnTypeInfo typeInfo; - bool freeMemory; - bool AtomicAdd; - bool omp; - unsigned width; - bool runtimeActivity; - - inline bool operator<(const AugmentedCacheKey &rhs) const { - if (fn < rhs.fn) - return true; - if (rhs.fn < fn) - return false; - - if (retType < rhs.retType) - return true; - if (rhs.retType < retType) - return false; - - if (std::lexicographical_compare( - constant_args.begin(), constant_args.end(), - rhs.constant_args.begin(), rhs.constant_args.end())) - return true; - if (std::lexicographical_compare( - rhs.constant_args.begin(), rhs.constant_args.end(), - constant_args.begin(), constant_args.end())) - return false; - - if (subsequent_calls_may_write < rhs.subsequent_calls_may_write) - return true; - if (rhs.subsequent_calls_may_write < subsequent_calls_may_write) - return false; - - if (std::lexicographical_compare( - overwritten_args.begin(), overwritten_args.end(), - rhs.overwritten_args.begin(), rhs.overwritten_args.end())) - return true; - if (std::lexicographical_compare( - rhs.overwritten_args.begin(), rhs.overwritten_args.end(), - overwritten_args.begin(), overwritten_args.end())) - return false; - - if (returnUsed < rhs.returnUsed) - return true; - if (rhs.returnUsed < returnUsed) - return false; - - if (shadowReturnUsed < rhs.shadowReturnUsed) - return true; - if (rhs.shadowReturnUsed < shadowReturnUsed) - return false; - - if (freeMemory < rhs.freeMemory) - return true; - if (rhs.freeMemory < freeMemory) - return false; - - if (AtomicAdd < rhs.AtomicAdd) - return true; - if (rhs.AtomicAdd < AtomicAdd) - return false; - - if (omp < rhs.omp) - return true; - if (rhs.omp < omp) - return false; - - if (typeInfo < rhs.typeInfo) - return true; - if (rhs.typeInfo < typeInfo) - return false; - - if (width < rhs.width) - return true; - if (rhs.width < width) - return false; - - if (runtimeActivity < rhs.runtimeActivity) - return true; - if (rhs.runtimeActivity < 
runtimeActivity) - return false; - - // equal - return false; - } - }; - - std::map NoFreeCachedFunctions; - llvm::Function *CreateNoFree(RequestContext context, llvm::Function *todiff); - llvm::Value *CreateNoFree(RequestContext context, llvm::Value *todiff); - - std::map AugmentedCachedFunctions; - - /// Create an augmented forward pass. - /// \p context the instruction which requested this derivative (or null). - /// \p todiff is the function to differentiate - /// \p retType is the activity info of the return - /// \p constant_args is the activity info of the arguments - /// \p returnUsed is whether the primal's return should also be returned - /// \p typeInfo is the type info information about the calling context - /// \p subsequent_calls_may_write denotes whether an instruction between - /// forward and reverse - /// may write to memory potentially read by this function. - /// \p _overwritten_args marks whether an argument may be rewritten before - /// loads in the generated function (and thus cannot be cached). - /// \p forceAnonymousTape forces the tape to be an i8* rather than the true - /// tape structure - /// \p AtomicAdd is whether to perform all adjoint updates to - /// memory in an atomic way - const AugmentedReturn &CreateAugmentedPrimal( - RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, - llvm::ArrayRef constant_args, TypeAnalysis &TA, - bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &typeInfo, - bool subsequent_calls_may_write, - const std::vector _overwritten_args, bool forceAnonymousTape, - bool runtimeActivity, unsigned width, bool AtomicAdd, bool omp = false); - - std::map ReverseCachedFunctions; - - struct ForwardCacheKey { - llvm::Function *todiff; - DIFFE_TYPE retType; - const std::vector constant_args; - bool subsequent_calls_may_write; - std::vector overwritten_args; - bool returnUsed; - DerivativeMode mode; - unsigned width; - llvm::Type *additionalType; - const FnTypeInfo typeInfo; - bool runtimeActivity; - - inline bool operator<(const ForwardCacheKey &rhs) const { - if (todiff < rhs.todiff) - return true; - if (rhs.todiff < todiff) - return false; - - if (retType < rhs.retType) - return true; - if (rhs.retType < retType) - return false; - - if (std::lexicographical_compare( - constant_args.begin(), constant_args.end(), - rhs.constant_args.begin(), rhs.constant_args.end())) - return true; - if (std::lexicographical_compare( - rhs.constant_args.begin(), rhs.constant_args.end(), - constant_args.begin(), constant_args.end())) - return false; - - if (subsequent_calls_may_write < rhs.subsequent_calls_may_write) - return true; - if (rhs.subsequent_calls_may_write < subsequent_calls_may_write) - return false; - - if (std::lexicographical_compare( - overwritten_args.begin(), overwritten_args.end(), - rhs.overwritten_args.begin(), rhs.overwritten_args.end())) - return true; - if (std::lexicographical_compare( - rhs.overwritten_args.begin(), rhs.overwritten_args.end(), - overwritten_args.begin(), overwritten_args.end())) - return false; - - if (returnUsed < rhs.returnUsed) - return true; - if (rhs.returnUsed < returnUsed) - return false; - - if (mode < rhs.mode) - return true; - if (rhs.mode < mode) - return false; - - if (width < rhs.width) - return true; - if (rhs.width < width) - return false; - - if (additionalType < rhs.additionalType) - return true; - if (rhs.additionalType < additionalType) - return false; - - if (typeInfo < rhs.typeInfo) - return true; - if (rhs.typeInfo < typeInfo) - return false; - - if (runtimeActivity < 
rhs.runtimeActivity) - return true; - if (rhs.runtimeActivity < runtimeActivity) - return false; - - // equal - return false; - } - }; - - std::map ForwardCachedFunctions; - - using BatchCacheKey = std::tuple, BATCH_TYPE>; - std::map BatchCachedFunctions; - - using TraceCacheKey = - std::tuple; - std::map TraceCachedFunctions; - - /// Create the reverse pass, or combined forward+reverse derivative function. - /// \p context the instruction which requested this derivative (or null). - /// \p augmented is the data structure created by prior call to an - /// augmented forward pass - llvm::Function *CreatePrimalAndGradient(RequestContext context, - const ReverseCacheKey &&key, - TypeAnalysis &TA, - const AugmentedReturn *augmented, - bool omp = false); - - /// Create the forward (or forward split) mode derivative function. - /// \p context the instruction which requested this derivative (or null). - /// \p todiff is the function to differentiate - /// \p retType is the activity info of the return - /// \p constant_args is the activity info of the arguments - /// \p TA is the type analysis results - /// \p returnValue is whether the primal's return should also be returned - /// \p mode is the requested derivative mode - /// \p is whether we should free memory allocated here (and could be - /// accessed externally). - /// \p width is the vector width requested. - /// \p additionalArg is the type (or null) of an additional type in the - /// signature to hold the tape. - /// \p FnTypeInfo is the known types of the argument and returns - /// \p subsequent_calls_may_write denotes whether an instruction between - /// forward and reverse - /// may write to memory potentially read by this function. - /// \p _overwritten_args marks whether an argument may be rewritten - /// before loads in the generated function (and thus cannot be cached). - /// \p augmented is the data structure created by prior call to an - /// augmented forward pass - /// \p omp is whether this function is an OpenMP closure body. - llvm::Function *CreateForwardDiff( - RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, - llvm::ArrayRef constant_args, TypeAnalysis &TA, - bool returnValue, DerivativeMode mode, bool freeMemory, - bool runtimeActivity, unsigned width, llvm::Type *additionalArg, - const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, - const std::vector _overwritten_args, - const AugmentedReturn *augmented, bool omp = false); - - /// Create a function batched in its inputs. - /// \p context the instruction which requested this batch (or null). - /// \p tobatch is the function to batch - /// \p width is the vector width requested. - /// \p arg_types denotes which arguments are batched. - /// \p ret_type denotes whether to batch the return. - llvm::Function *CreateBatch(RequestContext context, llvm::Function *tobatch, - unsigned width, - llvm::ArrayRef arg_types, - BATCH_TYPE ret_type); - - using TruncateCacheKey = - std::tuple; - std::map TruncateCachedFunctions; - llvm::Function *CreateTruncateFunc(RequestContext context, - llvm::Function *tobatch, - FloatTruncation truncation, - TruncateMode mode); - bool CreateTruncateValue(RequestContext context, llvm::Value *addr, - FloatRepresentation from, FloatRepresentation to, - bool isTruncate); - - /// Create a traced version of a function - /// \p context the instruction which requested this trace (or null). 
- /// \p totrace is the function to trace - /// \p sampleFunctions is a set of the functions to sample - /// \p observeFunctions is a set of the functions to observe - /// \p ActiveRandomVariables is a set of which variables are active - /// \p mode is the mode to use - /// \p autodiff is whether to also differentiate - /// \p interface specifies the ABI to use. - llvm::Function * - CreateTrace(RequestContext context, llvm::Function *totrace, - const llvm::SmallPtrSetImpl &sampleFunctions, - const llvm::SmallPtrSetImpl &observeFunctions, - const llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, - bool autodiff, TraceInterface *interface); - - void clear(); -}; - -extern "C" { -extern llvm::cl::opt looseTypeAnalysis; -extern llvm::cl::opt nonmarkedglobals_inactiveloads; -}; - -class GradientUtils; -bool shouldAugmentCall(llvm::CallInst *op, const GradientUtils *gutils); - -bool legalCombinedForwardReverse( - llvm::CallInst *origop, - const std::map &replacedReturns, - llvm::SmallVectorImpl &postCreate, - llvm::SmallVectorImpl &userReplace, - const GradientUtils *gutils, - const llvm::SmallPtrSetImpl - &unnecessaryInstructions, - const llvm::SmallPtrSetImpl &oldUnreachable, - const bool subretused); - -std::pair, - llvm::SmallVector> -getDefaultFunctionTypeForAugmentation(llvm::FunctionType *called, - bool returnUsed, DIFFE_TYPE retType); - -std::pair, - llvm::SmallVector> -getDefaultFunctionTypeForGradient(llvm::FunctionType *called, - DIFFE_TYPE retType); -#endif diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp deleted file mode 100644 index bc407ca09ba6..000000000000 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ /dev/null @@ -1,8167 +0,0 @@ -//===- FunctionUtils.cpp - Implementation of function utilities -----------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file defines utilities on LLVM Functions that are used as part of the AD -// process. 
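[Editor's note, sketch only] One representative utility defined further down in this file is IsFunctionRecursive, which memoizes a three-state answer (MaybeRecursive, NotRecursive, DefinitelyRecursive) per function to decide whether a function eventually calls itself. The standalone sketch below mirrors that memoization pattern over a plain call-graph map; the string-based types and the example graph are assumptions, not the LLVM API.

#include <map>
#include <string>
#include <vector>

enum RecurType { MaybeRecursive = 1, NotRecursive = 2, DefinitelyRecursive = 3 };

// callees[f] lists the (defined) functions that f calls directly.
using CalleeMap = std::map<std::string, std::vector<std::string>>;

// Returns whether F eventually calls itself, caching the answer. While F is
// being explored it is marked MaybeRecursive; reaching it again in that state
// upgrades it to DefinitelyRecursive, otherwise it settles to NotRecursive.
static bool isFunctionRecursive(const std::string &F, const CalleeMap &callees,
                                std::map<std::string, RecurType> &results) {
  auto found = results.find(F);
  if (found == results.end()) {
    results[F] = MaybeRecursive; // exploration in progress
    auto it = callees.find(F);
    if (it != callees.end())
      for (const std::string &callee : it->second)
        isFunctionRecursive(callee, callees, results);
    if (results[F] == MaybeRecursive)
      results[F] = NotRecursive; // never re-entered during exploration
  } else if (found->second == MaybeRecursive) {
    results[F] = DefinitelyRecursive; // re-entered while still in progress
  }
  return results[F] == DefinitelyRecursive;
}

int main() {
  CalleeMap callees = {{"f", {"g"}}, {"g", {"f"}}};
  std::map<std::string, RecurType> results;
  return isFunctionRecursive("f", callees, results) ? 0 : 1; // f is recursive
}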
-// -//===----------------------------------------------------------------------===// -#include "FunctionUtils.h" - -#include "DiffeGradientUtils.h" -#include "EnzymeLogic.h" -#include "GradientUtils.h" -#include "LibraryFuncs.h" - -#include "llvm/IR/Attributes.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Verifier.h" -#include "llvm/Passes/PassBuilder.h" - -#include "llvm/ADT/APSInt.h" -#include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/SetOperations.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/LazyValueInfo.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Analysis/MemorySSA.h" -#include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include - -#if LLVM_VERSION_MAJOR < 16 -#include "llvm/Analysis/CFLSteensAliasAnalysis.h" -#endif -#include "llvm/Analysis/DependenceAnalysis.h" -#include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/CodeGen/UnreachableBlockElim.h" - -#include "llvm/Analysis/PhiValues.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScopedNoAliasAA.h" -#include "llvm/Analysis/TargetTransformInfo.h" - -#include "llvm/Support/TimeProfiler.h" - -#include "llvm/Transforms/IPO/FunctionAttrs.h" -#include "llvm/Transforms/Utils/Mem2Reg.h" - -#include "llvm/Transforms/Utils.h" - -#include "llvm/Transforms/InstCombine/InstCombine.h" -#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" -#include "llvm/Transforms/Scalar/DCE.h" -#include "llvm/Transforms/Scalar/DeadStoreElimination.h" -#include "llvm/Transforms/Scalar/EarlyCSE.h" -#include "llvm/Transforms/Scalar/GVN.h" -#include "llvm/Transforms/Scalar/IndVarSimplify.h" -#include "llvm/Transforms/Scalar/InstSimplifyPass.h" -#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" -#include "llvm/Transforms/Scalar/MemCpyOptimizer.h" -#include "llvm/Transforms/Scalar/SROA.h" -#include "llvm/Transforms/Scalar/SimplifyCFG.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/LCSSA.h" -#include "llvm/Transforms/Utils/LowerInvoke.h" - -#include "llvm/Transforms/IPO/FunctionAttrs.h" -#include "llvm/Transforms/Scalar/DCE.h" -#include "llvm/Transforms/Scalar/LoopDeletion.h" -#include "llvm/Transforms/Scalar/LoopRotation.h" - -#include "llvm/Transforms/Utils/CodeExtractor.h" - -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" - -#include "llvm/IR/LegacyPassManager.h" -#if LLVM_VERSION_MAJOR <= 16 -#include "llvm/Transforms/IPO/PassManagerBuilder.h" -#endif -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" - -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" - -#include - -#include "CacheUtility.h" - -#define addAttribute addAttributeAtIndex -#define removeAttribute removeAttributeAtIndex -#define getAttribute getAttributeAtIndex -#define hasAttribute hasAttributeAtIndex - -#define DEBUG_TYPE "enzyme" -using namespace llvm; - -extern "C" { -cl::opt EnzymePreopt("enzyme-preopt", cl::init(true), cl::Hidden, - cl::desc("Run enzyme preprocessing optimizations")); - -cl::opt 
EnzymeInline("enzyme-inline", cl::init(false), cl::Hidden, - cl::desc("Force inlining of autodiff")); - -cl::opt EnzymeNoAlias("enzyme-noalias", cl::init(false), cl::Hidden, - cl::desc("Force noalias of autodiff")); -#if LLVM_VERSION_MAJOR < 16 -cl::opt - EnzymeAggressiveAA("enzyme-aggressive-aa", cl::init(false), cl::Hidden, - cl::desc("Use more unstable but aggressive LLVM AA")); -#endif -cl::opt EnzymeLowerGlobals( - "enzyme-lower-globals", cl::init(false), cl::Hidden, - cl::desc("Lower globals to locals assuming the global values are not " - "needed outside of this gradient")); - -cl::opt - EnzymeInlineCount("enzyme-inline-count", cl::init(10000), cl::Hidden, - cl::desc("Limit of number of functions to inline")); - -cl::opt EnzymeCoalese("enzyme-coalese", cl::init(false), cl::Hidden, - cl::desc("Whether to coalese memory allocations")); - -static cl::opt EnzymePHIRestructure( - "enzyme-phi-restructure", cl::init(false), cl::Hidden, - cl::desc("Whether to restructure phi's to have better unwrap behavior")); - -cl::opt - EnzymeNameInstructions("enzyme-name-instructions", cl::init(false), - cl::Hidden, - cl::desc("Have enzyme name all instructions")); - -cl::opt EnzymeSelectOpt("enzyme-select-opt", cl::init(true), cl::Hidden, - cl::desc("Run Enzyme select optimization")); - -cl::opt EnzymeAutoSparsity("enzyme-auto-sparsity", cl::init(false), - cl::Hidden, - cl::desc("Run Enzyme auto sparsity")); - -cl::opt EnzymePostOptLevel( - "enzyme-post-opt-level", cl::init(0), cl::Hidden, - cl::desc("Post optimization level within Enzyme differentiated function")); - -cl::opt EnzymeAlwaysInlineDiff( - "enzyme-always-inline", cl::init(false), cl::Hidden, - cl::desc("Mark generated functions as always-inline")); -} - -/// Is the use of value val as an argument of call CI potentially captured -bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val) { - Function *F = CI->getCalledFunction(); - - if (auto castinst = dyn_cast(CI->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) { - F = fn; - } - } - - if (F == nullptr) - return true; - - if (F->getIntrinsicID() == Intrinsic::memset) - return false; - if (F->getIntrinsicID() == Intrinsic::memcpy) - return false; - if (F->getIntrinsicID() == Intrinsic::memmove) - return false; - - auto arg = F->arg_begin(); - for (size_t i = 0, size = CI->arg_size(); i < size; i++) { - if (val == CI->getArgOperand(i)) { - // This is a vararg, assume captured - if (arg == F->arg_end()) { - return true; - } else { - if (!arg->hasNoCaptureAttr()) { - return true; - } - } - } - if (arg != F->arg_end()) - arg++; - } - // No argument captured - return false; -} - -enum RecurType { - MaybeRecursive = 1, - NotRecursive = 2, - DefinitelyRecursive = 3, -}; -/// Return whether this function eventually calls itself -static bool -IsFunctionRecursive(Function *F, - std::map &Results) { - - // If we haven't seen this function before, look at all callers - // and mark this as potentially recursive. If we see this function - // still as marked as MaybeRecursive, we will definitionally have - // found an eventual caller of the original function. 
If not, - // the function does not eventually call itself (in a static way) - if (Results.find(F) == Results.end()) { - Results[F] = MaybeRecursive; // staging - for (auto &BB : *F) { - for (auto &I : BB) { - if (auto call = dyn_cast(&I)) { - if (call->getCalledFunction() == nullptr) - continue; - if (call->getCalledFunction()->empty()) - continue; - IsFunctionRecursive(call->getCalledFunction(), Results); - } - if (auto call = dyn_cast(&I)) { - if (call->getCalledFunction() == nullptr) - continue; - if (call->getCalledFunction()->empty()) - continue; - IsFunctionRecursive(call->getCalledFunction(), Results); - } - } - } - if (Results[F] == MaybeRecursive) { - Results[F] = NotRecursive; // not recursive - } - } else if (Results[F] == MaybeRecursive) { - Results[F] = DefinitelyRecursive; // definitely recursive - } - assert(Results[F] != MaybeRecursive); - return Results[F] == DefinitelyRecursive; -} - -static inline bool OnlyUsedInOMP(AllocaInst *AI) { - bool ompUse = false; - for (auto U : AI->users()) { - if (auto SI = dyn_cast(U)) - if (SI->getPointerOperand() == AI) - continue; - if (auto CI = dyn_cast(U)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "__kmpc_for_static_init_4" || - F->getName() == "__kmpc_for_static_init_4u" || - F->getName() == "__kmpc_for_static_init_8" || - F->getName() == "__kmpc_for_static_init_8u") { - ompUse = true; - } - } - } - } - - if (!ompUse) - return false; - return true; -} - -void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) { - SmallVector, 1> Todo; - for (auto U : AI->users()) { - Todo.push_back( - std::make_tuple((Value *)rep, (Value *)AI, cast(U))); - } - SmallVector toErase; - if (auto I = dyn_cast(AI)) - toErase.push_back(I); - SmallVector toPostCache; - while (Todo.size()) { - auto cur = Todo.back(); - Todo.pop_back(); - Value *rep = std::get<0>(cur); - Value *prev = std::get<1>(cur); - Value *inst = std::get<2>(cur); - if (auto ASC = dyn_cast(inst)) { - auto AS = cast(rep->getType())->getAddressSpace(); - if (AS == ASC->getDestAddressSpace()) { - ASC->replaceAllUsesWith(rep); - toErase.push_back(ASC); - continue; - } - ASC->setOperand(0, rep); - continue; - } - if (auto CI = dyn_cast(inst)) { - if (!CI->getType()->isPointerTy()) { - CI->setOperand(0, rep); - continue; - } - IRBuilder<> B(CI); - auto nCI0 = B.CreateCast( - CI->getOpcode(), rep, -#if LLVM_VERSION_MAJOR < 17 - PointerType::get(CI->getType()->getPointerElementType(), - cast(rep->getType())->getAddressSpace()) -#else - rep->getType() -#endif - ); - if (auto nCI = dyn_cast(nCI0)) - nCI->takeName(CI); - for (auto U : CI->users()) { - Todo.push_back( - std::make_tuple((Value *)nCI0, (Value *)CI, cast(U))); - } - toErase.push_back(CI); - continue; - } - if (auto GEP = dyn_cast(inst)) { - IRBuilder<> B(GEP); - SmallVector ind(GEP->indices()); - auto nGEP = cast( - B.CreateGEP(GEP->getSourceElementType(), rep, ind)); - nGEP->takeName(GEP); - for (auto U : GEP->users()) { - Todo.push_back( - std::make_tuple((Value *)nGEP, (Value *)GEP, cast(U))); - } - toErase.push_back(GEP); - continue; - } - if (auto P = dyn_cast(inst)) { - auto NumOperands = P->getNumIncomingValues(); - SmallVector replacedOperands(NumOperands, nullptr); - for (size_t i = 0; i < NumOperands; i++) - if (P->getOperand(i) == prev) - replacedOperands[i] = rep; - for (auto tval : Todo) { - if (std::get<2>(tval) != P) - continue; - for (size_t i = 0; i < NumOperands; i++) - if (P->getOperand(i) == std::get<1>(tval)) { - replacedOperands[i] = std::get<0>(tval); - } - } - bool 
allReplaced = true; - for (size_t i = 0; i < NumOperands; i++) { - if (!replacedOperands[i]) { - allReplaced = false; - } - } - if (!allReplaced) { - bool remainingArePHIs = true; - for (auto v : Todo) { - if (isa(std::get<2>(v))) { - } else { - remainingArePHIs = false; - } - } - if (!remainingArePHIs) { - Todo.insert(Todo.begin(), cur); - llvm::errs() << " continuing\n"; - continue; - } - } else { - IRBuilder<> B(&(*P->getParent()->getFirstNonPHIOrDbgOrLifetime())); - auto nP = B.CreatePHI(rep->getType(), P->getNumOperands()); - for (size_t i = 0; i < NumOperands; i++) { - nP->addIncoming(replacedOperands[i], P->getIncomingBlock(i)); - } - nP->takeName(P); - for (auto U : P->users()) { - Todo.push_back( - std::make_tuple((Value *)nP, (Value *)P, cast(U))); - } - toErase.push_back(P); - for (int i = Todo.size() - 1; i >= 0; i--) { - if (std::get<2>(Todo[i]) != P) - continue; - Todo.erase(Todo.begin() + i); - } - continue; - } - } - if (auto II = dyn_cast(inst)) { - if (isIntelSubscriptIntrinsic(*II)) { - - const std::array idxArgsIndices{{0, 1, 2, 4}}; - const size_t ptrArgIndex = 3; - - SmallVector args(5); - for (auto i : idxArgsIndices) { - Value *idx = II->getOperand(i); - args[i] = idx; - } - args[ptrArgIndex] = rep; - - IRBuilder<> B(II); - auto nII = cast(B.CreateCall(II->getCalledFunction(), args)); - // Must copy the elementtype attribute as it is needed by the intrinsic - nII->addParamAttr( - ptrArgIndex, - II->getParamAttr(ptrArgIndex, Attribute::AttrKind::ElementType)); - nII->takeName(II); - for (auto U : II->users()) { - Todo.push_back( - std::make_tuple((Value *)nII, (Value *)II, cast(U))); - } - toErase.push_back(II); - continue; - } - } - if (auto LI = dyn_cast(inst)) { - LI->setOperand(0, rep); - continue; - } - if (auto SI = dyn_cast(inst)) { - if (SI->getPointerOperand() == prev) { - SI->setOperand(1, rep); - toPostCache.push_back(SI); - continue; - } - } - if (auto MS = dyn_cast(inst)) { - IRBuilder<> B(MS); - - Value *nargs[] = {rep, MS->getArgOperand(1), MS->getArgOperand(2), - MS->getArgOperand(3)}; - Type *tys[] = {nargs[0]->getType(), nargs[2]->getType()}; - auto nMS = cast(B.CreateCall( - getIntrinsicDeclaration(MS->getParent()->getParent()->getParent(), - Intrinsic::memset, tys), - nargs)); - nMS->copyMetadata(*MS); - nMS->setAttributes(MS->getAttributes()); - toErase.push_back(MS); - continue; - } - if (auto MTI = dyn_cast(inst)) { - IRBuilder<> B(MTI); - - Value *nargs[4] = {MTI->getArgOperand(0), MTI->getArgOperand(1), - MTI->getArgOperand(2), MTI->getArgOperand(3)}; - - if (nargs[0] == prev) - nargs[0] = rep; - - if (nargs[1] == prev) - nargs[1] = rep; - - Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(), - nargs[2]->getType()}; - - auto nMTI = cast(B.CreateCall( - getIntrinsicDeclaration(MTI->getParent()->getParent()->getParent(), - MTI->getIntrinsicID(), tys), - nargs)); - nMTI->copyMetadata(*MTI); - nMTI->setAttributes(MTI->getAttributes()); - toErase.push_back(MTI); - continue; - } - if (auto CI = dyn_cast(inst)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "julia.write_barrier" && legal) { - toErase.push_back(CI); - continue; - } - if (F->getName() == "julia.write_barrier_binding" && legal) { - toErase.push_back(CI); - continue; - } - } - IRBuilder<> B(CI); - auto Addr = B.CreateAddrSpaceCast(rep, prev->getType()); - for (size_t i = 0; i < CI->arg_size(); i++) { - if (CI->getArgOperand(i) == prev) { - CI->setArgOperand(i, Addr); - } - } - continue; - } - if (auto I = dyn_cast(inst)) - llvm::errs() << 
*I->getParent()->getParent() << "\n"; - llvm_unreachable("Illegal address space propagation"); - } - - for (auto I : llvm::reverse(toErase)) { - I->eraseFromParent(); - } - for (auto SI : toPostCache) { - IRBuilder<> B(SI->getNextNode()); - PostCacheStore(SI, B); - } -} - -/// Convert necessary stack allocations into mallocs for use in the reverse -/// pass. Specifically if we're not topLevel all allocations must be upgraded -/// Even if topLevel any allocations that aren't in the entry block (and -/// therefore may not be reachable in the reverse pass) must be upgraded. -static inline void -UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode, - SmallPtrSetImpl &Unreachable) { - SmallVector ToConvert; - - for (auto &BB : *NewF) { - if (Unreachable.count(&BB)) - continue; - for (auto &I : BB) { - if (auto AI = dyn_cast(&I)) { - bool UsableEverywhere = AI->getParent() == &NewF->getEntryBlock(); - // TODO use is_value_needed_in_reverse (requiring GradientUtils) - if (OnlyUsedInOMP(AI)) - continue; - if (!UsableEverywhere || mode != DerivativeMode::ReverseModeCombined) { - ToConvert.push_back(AI); - } - } - } - } - - for (auto AI : ToConvert) { - std::string nam = AI->getName().str(); - AI->setName(""); - - // Ensure we insert the malloc after the allocas - Instruction *insertBefore = AI; - while (isa(insertBefore->getNextNode())) { - insertBefore = insertBefore->getNextNode(); - assert(insertBefore); - } - - auto i64 = Type::getInt64Ty(NewF->getContext()); - IRBuilder<> B(insertBefore); - CallInst *CI = nullptr; - Instruction *ZeroInst = nullptr; - auto rep = CreateAllocation( - B, AI->getAllocatedType(), B.CreateZExtOrTrunc(AI->getArraySize(), i64), - nam, &CI, /*ZeroMem*/ EnzymeZeroCache ? &ZeroInst : nullptr); - auto align = AI->getAlign().value(); - CI->setMetadata( - "enzyme_fromstack", - MDNode::get(CI->getContext(), - {ConstantAsMetadata::get(ConstantInt::get( - IntegerType::get(AI->getContext(), 64), align))})); - - for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", - "enzymejl_allocart"}) - if (auto M = AI->getMetadata(MD)) - CI->setMetadata(MD, M); - - if (rep != CI) { - cast(rep)->setMetadata("enzyme_caststack", - MDNode::get(CI->getContext(), {})); - } - if (ZeroInst) { - ZeroInst->setMetadata("enzyme_zerostack", - MDNode::get(CI->getContext(), {})); - } - - auto PT0 = cast(rep->getType()); - auto PT1 = cast(AI->getType()); - if (PT0->getAddressSpace() != PT1->getAddressSpace()) { - RecursivelyReplaceAddressSpace(AI, rep, /*legal*/ false); - } else { - assert(rep->getType() == AI->getType()); - AI->replaceAllUsesWith(rep); - AI->eraseFromParent(); - } - } -} - -// Create a stack variable containing the size of the allocation -// error if not possible (e.g. 
not local) -static inline AllocaInst * -OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T, - const std::map &reallocSizes) { - IRBuilder<> B(&*NewF->getEntryBlock().begin()); - AllocaInst *AI = B.CreateAlloca(T); - - std::set> seen; - std::deque> todo = {{Ptr, Loc}}; - - while (todo.size()) { - auto next = todo.front(); - todo.pop_front(); - if (seen.count(next)) - continue; - seen.insert(next); - - if (auto CI = dyn_cast(next.first)) { - todo.push_back({CI->getOperand(0), CI}); - continue; - } - - // Assume zero size if realloc of undef pointer - if (isa(next.first)) { - B.SetInsertPoint(next.second); - B.CreateStore(ConstantInt::get(T, 0), AI); - continue; - } - - if (auto CE = dyn_cast(next.first)) { - if (CE->isCast()) { - todo.push_back({CE->getOperand(0), next.second}); - continue; - } - } - - if (auto C = dyn_cast(next.first)) { - if (C->isNullValue()) { - B.SetInsertPoint(next.second); - B.CreateStore(ConstantInt::get(T, 0), AI); - continue; - } - } - if (auto CI = dyn_cast(next.first)) { - // if negative or below 0xFFF this cannot possibly represent - // a real pointer, so ignore this case by setting to 0 - if (CI->isNegative() || CI->getLimitedValue() <= 0xFFF) { - B.SetInsertPoint(next.second); - B.CreateStore(ConstantInt::get(T, 0), AI); - continue; - } - } - - // Todo consider more general method for selects - if (auto SI = dyn_cast(next.first)) { - if (auto C1 = dyn_cast(SI->getTrueValue())) { - // if negative or below 0xFFF this cannot possibly represent - // a real pointer, so ignore this case by setting to 0 - if (C1->isNegative() || C1->getLimitedValue() <= 0xFFF) { - if (auto C2 = dyn_cast(SI->getFalseValue())) { - if (C2->isNegative() || C2->getLimitedValue() <= 0xFFF) { - B.SetInsertPoint(next.second); - B.CreateStore(ConstantInt::get(T, 0), AI); - continue; - } - } - } - } - } - - if (auto PN = dyn_cast(next.first)) { - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - todo.push_back({PN->getIncomingValue(i), - PN->getIncomingBlock(i)->getTerminator()}); - } - continue; - } - - if (auto CI = dyn_cast(next.first)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "malloc") { - B.SetInsertPoint(next.second); - B.CreateStore(CI->getArgOperand(0), AI); - continue; - } - if (F->getName() == "calloc") { - B.SetInsertPoint(next.second); - B.CreateStore(B.CreateMul(CI->getArgOperand(0), CI->getArgOperand(1)), - AI); - continue; - } - if (F->getName() == "realloc") { - assert(reallocSizes.find(CI) != reallocSizes.end()); - B.SetInsertPoint(next.second); - B.CreateStore(reallocSizes.find(CI)->second, AI); - continue; - } - } - } - - if (auto LI = dyn_cast(next.first)) { - bool success = false; - for (Instruction *prev = LI->getPrevNode(); prev != nullptr; - prev = prev->getPrevNode()) { - if (auto CI = dyn_cast(prev)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "posix_memalign" && - CI->getArgOperand(0) == LI->getOperand(0)) { - B.SetInsertPoint(next.second); - B.CreateStore(CI->getArgOperand(2), AI); - success = true; - break; - } - } - } - if (prev->mayWriteToMemory()) { - break; - } - } - if (success) - continue; - - auto v2 = simplifyLoad(LI); - if (v2) { - todo.push_back({v2, next.second}); - continue; - } - } - - EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc, - "could not statically determine size of realloc ", *Loc, - " - because of - ", *next.first); - return AI; - - std::string allocName; - switch (llvm::Triple(NewF->getParent()->getTargetTriple()).getOS()) { - case 
llvm::Triple::Linux: - case llvm::Triple::FreeBSD: - case llvm::Triple::NetBSD: - case llvm::Triple::OpenBSD: - case llvm::Triple::Fuchsia: - allocName = "malloc_usable_size"; - break; - - case llvm::Triple::Darwin: - case llvm::Triple::IOS: - case llvm::Triple::MacOSX: - case llvm::Triple::WatchOS: - case llvm::Triple::TvOS: - allocName = "malloc_size"; - break; - - case llvm::Triple::Win32: - allocName = "_msize"; - break; - - default: - llvm_unreachable("unknown reallocation for OS"); - } - - AttributeList list; - list = list.addFnAttribute(NewF->getContext(), Attribute::ReadOnly); - list = list.addParamAttribute(NewF->getContext(), 0, Attribute::ReadNone); - list = addFunctionNoCapture(NewF->getContext(), list, 0); - auto allocSize = NewF->getParent()->getOrInsertFunction( - allocName, - FunctionType::get( - IntegerType::get(NewF->getContext(), 8 * sizeof(size_t)), - {getInt8PtrTy(NewF->getContext())}, /*isVarArg*/ false), - list); - - B.SetInsertPoint(Loc); - Value *sz = B.CreateZExtOrTrunc(B.CreateCall(allocSize, {Ptr}), T); - B.CreateStore(sz, AI); - return AI; - - llvm_unreachable("DynamicReallocSize"); - } - return AI; -} - -void PreProcessCache::AlwaysInline(Function *NewF) { - - PreservedAnalyses PA; - PA.preserve(); - PA.preserve(); - FAM.invalidate(*NewF, PA); - SmallVector ToInline; - // TODO this logic should be combined with the dynamic loop emission - // to minimize the number of branches if the realloc is used for multiple - // values with the same bound. - for (auto &BB : *NewF) { - for (auto &I : make_early_inc_range(BB)) { - if (hasMetadata(&I, "enzyme_zerostack")) { - if (isa(getBaseObject(I.getOperand(0)))) { - I.eraseFromParent(); - continue; - } - } - if (auto CI = dyn_cast(&I)) { - if (!CI->getCalledFunction()) - continue; - if (CI->getCalledFunction()->hasFnAttribute(Attribute::AlwaysInline)) - ToInline.push_back(CI); - } - } - } - - for (auto CI : ToInline) { - InlineFunctionInfo IFI; -#if LLVM_VERSION_MAJOR >= 18 - auto F = CI->getCalledFunction(); - if (CI->getParent()->IsNewDbgInfoFormat != F->IsNewDbgInfoFormat) { - if (CI->getParent()->IsNewDbgInfoFormat) { - F->convertToNewDbgValues(); - } else { - F->convertFromNewDbgValues(); - } - } -#endif - InlineFunction(*CI, IFI); - } -} - -// Simplify all extractions to use inserted values, if possible. -void simplifyExtractions(Function *NewF) { - // First rewrite/remove any extractions - for (auto &BB : *NewF) { - IRBuilder<> B(&BB); - auto first = BB.begin(); - auto last = BB.empty() ? BB.end() : std::prev(BB.end()); - for (auto it = first; it != last;) { - auto inst = &*it; - // We iterate first here, since we may delete the instruction - // in the body - ++it; - if (auto E = dyn_cast(inst)) { - auto rep = GradientUtils::extractMeta(B, E->getAggregateOperand(), - E->getIndices(), E->getName(), - /*fallback*/ false); - if (rep) { - E->replaceAllUsesWith(rep); - E->eraseFromParent(); - } - } - } - } - // Now that there may be unused insertions, delete them. 
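// The switch over the target OS above selects the platform's query for the
// usable size of a heap block (malloc_usable_size, malloc_size or _msize).
// Below is a minimal host-side sketch of the same dispatch, assuming the
// usual system headers; the preprocessor test is an illustrative stand-in
// for the target-triple check used in the real code.

#include <cstddef>
#include <cstdlib>
#if defined(__APPLE__)
#include <malloc/malloc.h> // malloc_size
#elif defined(_WIN32)
#include <malloc.h>        // _msize
#else
#include <malloc.h>        // malloc_usable_size (glibc)
#endif

static std::size_t usableSize(void *p) {
#if defined(__APPLE__)
  return malloc_size(p);
#elif defined(_WIN32)
  return _msize(p);
#else
  return malloc_usable_size(p);
#endif
}

int main() {
  void *p = std::malloc(100);
  std::size_t got = usableSize(p); // may exceed 100: allocators round sizes up
  std::free(p);
  return got >= 100 ? 0 : 1;
}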
We keep a list of - // todo's since deleting an insertvalue may cause a different insertvalue to - // have no uses - SmallVector todo; - for (auto &BB : *NewF) { - for (auto &inst : BB) - if (auto I = dyn_cast(&inst)) { - if (I->getNumUses() == 0) - todo.push_back(I); - } - } - while (todo.size()) { - auto I = todo.pop_back_val(); - auto op = I->getAggregateOperand(); - I->eraseFromParent(); - if (auto I2 = dyn_cast(op)) - if (I2->getNumUses() == 0) - todo.push_back(I2); - } -} - -void PreProcessCache::LowerAllocAddr(Function *NewF) { - simplifyExtractions(NewF); - SmallVector Todo; - for (auto &BB : *NewF) { - for (auto &I : BB) { - if (hasMetadata(&I, "enzyme_backstack")) { - Todo.push_back(&I); - // TODO - // I.eraseMetadata("enzyme_backstack"); - } - } - } - for (auto T : Todo) { - auto T0 = T->getOperand(0); - if (auto CI = dyn_cast(T0)) - T0 = CI->getOperand(0); - auto AI = cast(T0); - llvm::Value *AIV = AI; -#if LLVM_VERSION_MAJOR < 17 - if (AIV->getType()->getPointerElementType() != - T->getType()->getPointerElementType()) { - IRBuilder<> B(AI->getNextNode()); - AIV = B.CreateBitCast( - AIV, PointerType::get( - T->getType()->getPointerElementType(), - cast(AI->getType())->getAddressSpace())); - } -#endif - RecursivelyReplaceAddressSpace(T, AIV, /*legal*/ true); - } -} - -/// Calls to realloc with an appropriate implementation -void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) { - if (mem2reg) { - auto PA = PromotePass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); -#if !defined(FLANG) - PA = GVNPass().run(*NewF, FAM); -#else - PA = GVN().run(*NewF, FAM); -#endif - FAM.invalidate(*NewF, PA); - } - - SmallVector ToConvert; - std::map reallocSizes; - IntegerType *T = nullptr; - - for (auto &BB : *NewF) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - if (auto F = CI->getCalledFunction()) { - if (F->getName() == "realloc") { - ToConvert.push_back(CI); - IRBuilder<> B(CI->getNextNode()); - T = cast(CI->getArgOperand(1)->getType()); - reallocSizes[CI] = B.CreatePHI(T, 0); - } - } - } - } - } - - SmallVector memoryLocations; - - for (auto CI : ToConvert) { - assert(T); - AllocaInst *AI = - OldAllocationSize(CI->getArgOperand(0), CI, NewF, T, reallocSizes); - - BasicBlock *resize = - BasicBlock::Create(CI->getContext(), "resize" + CI->getName(), NewF); - assert(resize->getParent() == NewF); - - BasicBlock *splitParent = CI->getParent(); - BasicBlock *nextBlock = splitParent->splitBasicBlock(CI); - - splitParent->getTerminator()->eraseFromParent(); - IRBuilder<> B(splitParent); - - Value *p = CI->getArgOperand(0); - Value *req = CI->getArgOperand(1); - Value *old = B.CreateLoad(AI->getAllocatedType(), AI); - Value *cmp = B.CreateICmpULE(req, old); - // if (req < old) - B.CreateCondBr(cmp, nextBlock, resize); - - B.SetInsertPoint(resize); - // size_t newsize = nextPowerOfTwo(req); - // void* next = malloc(newsize); - // memcpy(next, p, newsize); - // free(p); - // return { next, newsize }; - - Value *newsize = nextPowerOfTwo(B, req); - - Module *M = NewF->getParent(); - Type *BPTy = getInt8PtrTy(NewF->getContext()); - auto MallocFunc = - M->getOrInsertFunction("malloc", BPTy, newsize->getType()); - auto next = B.CreateCall(MallocFunc, newsize); - B.SetInsertPoint(resize); - - auto volatile_arg = ConstantInt::getFalse(CI->getContext()); - - Value *nargs[] = {next, p, old, volatile_arg}; - - Type *tys[] = {next->getType(), p->getType(), old->getType()}; - - auto memcpyF = - getIntrinsicDeclaration(NewF->getParent(), Intrinsic::memcpy, tys); - - auto mem = 
cast(B.CreateCall(memcpyF, nargs)); - mem->setCallingConv(memcpyF->getCallingConv()); - - Type *VoidTy = Type::getVoidTy(M->getContext()); - auto FreeFunc = M->getOrInsertFunction("free", VoidTy, BPTy); - B.CreateCall(FreeFunc, p); - B.SetInsertPoint(resize); - - B.CreateBr(nextBlock); - - // else - // return { p, old } - B.SetInsertPoint(&*nextBlock->begin()); - - PHINode *retPtr = B.CreatePHI(CI->getType(), 2); - retPtr->addIncoming(p, splitParent); - retPtr->addIncoming(next, resize); - CI->replaceAllUsesWith(retPtr); - std::string nam = CI->getName().str(); - CI->setName(""); - retPtr->setName(nam); - Value *nextSize = B.CreateSelect(cmp, old, req); - reallocSizes[CI]->replaceAllUsesWith(nextSize); - cast(reallocSizes[CI])->eraseFromParent(); - reallocSizes[CI] = nextSize; - } - - for (auto CI : ToConvert) { - CI->eraseFromParent(); - } - - PreservedAnalyses PA; - FAM.invalidate(*NewF, PA); - - PA = PromotePass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); -} - -Function *CreateMPIWrapper(Function *F) { - std::string name = ("enzyme_wrapmpi$$" + F->getName() + "#").str(); - if (auto W = F->getParent()->getFunction(name)) - return W; - Type *types = {F->getFunctionType()->getParamType(0)}; - auto FT = FunctionType::get(F->getReturnType(), types, false); - Function *W = Function::Create(FT, GlobalVariable::InternalLinkage, name, - F->getParent()); - llvm::Attribute::AttrKind attrs[] = { - Attribute::WillReturn, - Attribute::MustProgress, -#if LLVM_VERSION_MAJOR < 16 - Attribute::ReadOnly, -#endif - Attribute::Speculatable, - Attribute::NoUnwind, - Attribute::AlwaysInline, - Attribute::NoFree, - Attribute::NoSync, -#if LLVM_VERSION_MAJOR < 16 - Attribute::InaccessibleMemOnly -#endif - }; - for (auto attr : attrs) { - W->addFnAttr(attr); - } -#if LLVM_VERSION_MAJOR >= 16 - W->setOnlyAccessesInaccessibleMemory(); - W->setOnlyReadsMemory(); -#endif - W->addFnAttr(Attribute::get(F->getContext(), "enzyme_inactive")); - BasicBlock *entry = BasicBlock::Create(W->getContext(), "entry", W); - IRBuilder<> B(entry); - auto alloc = B.CreateAlloca(F->getReturnType()); - Value *args[] = {W->arg_begin(), alloc}; - - auto T = F->getFunctionType()->getParamType(1); - if (!isa(T)) { - assert(isa(T)); - args[1] = B.CreatePtrToInt(args[1], T); - } - B.CreateCall(F, args); - B.CreateRet(B.CreateLoad(F->getReturnType(), alloc)); - return W; -} - -static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) { - DominatorTree &DT = FAM.getResult(NewF); - SmallVector Todo; - SmallVector OMPBounds; - for (auto &BB : NewF) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - Function *Fn = CI->getCalledFunction(); - if (Fn == nullptr) - continue; - if (Fn->getName() == "MPI_Comm_rank" || - Fn->getName() == "PMPI_Comm_rank" || - Fn->getName() == "MPI_Comm_size" || - Fn->getName() == "PMPI_Comm_size") { - Todo.push_back(CI); - } - if (Fn->getName() == "__kmpc_for_static_init_4" || - Fn->getName() == "__kmpc_for_static_init_4u" || - Fn->getName() == "__kmpc_for_static_init_8" || - Fn->getName() == "__kmpc_for_static_init_8u") { - OMPBounds.push_back(CI); - } - } - } - } - if (Todo.size() == 0 && OMPBounds.size() == 0) - return; - for (auto CI : Todo) { - IRBuilder<> B(CI); - Value *arg[] = {CI->getArgOperand(0)}; - SmallVector Defs; - CI->getOperandBundlesAsDefs(Defs); - CallBase *res = nullptr; - if (auto II = dyn_cast(CI)) - res = B.CreateInvoke(CreateMPIWrapper(CI->getCalledFunction()), - II->getNormalDest(), II->getUnwindDest(), arg, Defs); - else - res = 
B.CreateCall(CreateMPIWrapper(CI->getCalledFunction()), arg, Defs); - Value *storePointer = CI->getArgOperand(1); - - // Comm_rank and Comm_size return Err, assume 0 is success - CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0)); - CI->eraseFromParent(); - - while (auto Cast = dyn_cast(storePointer)) { - storePointer = Cast->getOperand(0); - if (Cast->use_empty()) - Cast->eraseFromParent(); - } - - B.SetInsertPoint(res); - - if (auto PT = dyn_cast(storePointer->getType())) { - (void)PT; -#if LLVM_VERSION_MAJOR < 17 - if (PT->getContext().supportsTypedPointers()) { - if (PT->getPointerElementType() != res->getType()) - storePointer = B.CreateBitCast( - storePointer, - PointerType::get(res->getType(), PT->getAddressSpace())); - } -#endif - } else { - assert(isa(storePointer->getType())); - storePointer = B.CreateIntToPtr(storePointer, - PointerType::getUnqual(res->getType())); - } - if (isa(storePointer)) { - // If this is only loaded from, immedaitely replace - // Immediately replace all dominated stores. - SmallVector LI; - bool nonload = false; - for (auto &U : storePointer->uses()) { - if (auto L = dyn_cast(U.getUser())) { - LI.push_back(L); - } else - nonload = true; - } - if (!nonload) { - for (auto L : LI) { - if (DT.dominates(res, L)) { - L->replaceAllUsesWith(res); - L->eraseFromParent(); - } - } - } - } - if (auto II = dyn_cast(res)) { - B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI()); - } else { - B.SetInsertPoint(res->getNextNode()); - } - B.CreateStore(res, storePointer); - } - for (auto Bound : OMPBounds) { - for (int i = 4; i <= 6; i++) { - auto AI = cast(Bound->getArgOperand(i)); - IRBuilder<> B(AI); - auto AI2 = B.CreateAlloca(AI->getAllocatedType(), nullptr, - AI->getName() + "_smpl"); - B.SetInsertPoint(Bound); - B.CreateStore(B.CreateLoad(AI->getAllocatedType(), AI), AI2); - Bound->setArgOperand(i, AI2); - if (auto II = dyn_cast(Bound)) { - B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI()); - } else { - B.SetInsertPoint(Bound->getNextNode()); - } - B.CreateStore(B.CreateLoad(AI2->getAllocatedType(), AI2), AI); - addCallSiteNoCapture(Bound, i); - } - } - PreservedAnalyses PA; - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - FAM.invalidate(NewF, PA); -} - -/// Perform recursive inlinining on NewF up to the given limit -static void ForceRecursiveInlining(Function *NewF, size_t Limit) { - std::map RecurResults; - for (size_t count = 0; count < Limit; count++) { - for (auto &BB : *NewF) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - if (CI->getCalledFunction() == nullptr) - continue; - if (CI->getCalledFunction()->empty()) - continue; - if (startsWith(CI->getCalledFunction()->getName(), - "_ZN3std2io5stdio6_print")) - continue; - if (startsWith(CI->getCalledFunction()->getName(), "_ZN4core3fmt")) - continue; - if (startsWith(CI->getCalledFunction()->getName(), - "enzyme_wrapmpi$$")) - continue; - if (CI->getCalledFunction()->hasFnAttribute( - Attribute::ReturnsTwice) || - CI->getCalledFunction()->hasFnAttribute(Attribute::NoInline)) - continue; - if (IsFunctionRecursive(CI->getCalledFunction(), RecurResults)) { - LLVM_DEBUG(llvm::dbgs() - << "not inlining recursive " - << CI->getCalledFunction()->getName() << "\n"); - continue; - } - InlineFunctionInfo IFI; - InlineFunction(*CI, IFI); - goto outermostContinue; - } - } - } - - // No functions were inlined, break - break; - - outermostContinue:; - } -} - -void CanonicalizeLoops(Function *F, FunctionAnalysisManager &FAM) { - LoopSimplifyPass().run(*F, FAM); 
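// A minimal sketch, assuming a recent LLVM development install, of the
// new-pass-manager boilerplate needed to run a single pass such as
// LoopSimplifyPass on a function in isolation. It is illustrative only:
// the code here instead reuses the analysis managers registered in the
// PreProcessCache constructor further below.

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"

static void runLoopSimplify(llvm::Function &F) {
  llvm::LoopAnalysisManager LAM;
  llvm::FunctionAnalysisManager FAM;
  llvm::CGSCCAnalysisManager CGAM;
  llvm::ModuleAnalysisManager MAM;

  llvm::PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  llvm::FunctionPassManager FPM;
  FPM.addPass(llvm::LoopSimplifyPass());
  FPM.run(F, FAM); // analyses are computed lazily through FAM
}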
- DominatorTree &DT = FAM.getResult(*F); - LoopInfo &LI = FAM.getResult(*F); - AssumptionCache &AC = FAM.getResult(*F); - TargetLibraryInfo &TLI = FAM.getResult(*F); - MustExitScalarEvolution SE(*F, TLI, AC, DT, LI); - for (Loop *L : LI.getLoopsInPreorder()) { - auto pair = - InsertNewCanonicalIV(L, Type::getInt64Ty(F->getContext()), "iv"); - PHINode *CanonicalIV = pair.first; - assert(CanonicalIV); - RemoveRedundantIVs( - L->getHeader(), CanonicalIV, pair.second, SE, - [&](Instruction *I, Value *V) { I->replaceAllUsesWith(V); }, - [&](Instruction *I) { I->eraseFromParent(); }); - } - PreservedAnalyses PA; - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - FAM.invalidate(*F, PA); -} - -void RemoveRedundantPHI(Function *F, FunctionAnalysisManager &FAM) { - DominatorTree &DT = FAM.getResult(*F); - for (BasicBlock &BB : *F) { - for (BasicBlock::iterator II = BB.begin(); isa(II);) { - PHINode *PN = cast(II); - ++II; - SmallPtrSet vals; - SmallPtrSet done; - SmallVector todo = {PN}; - while (todo.size() > 0) { - PHINode *N = todo.back(); - todo.pop_back(); - if (done.count(N)) - continue; - done.insert(N); - if (vals.size() == 0 && todo.size() == 0 && PN != N && - DT.dominates(N, PN)) { - vals.insert(N); - break; - } - for (auto &v : N->incoming_values()) { - if (isa(v)) - continue; - if (auto NN = dyn_cast(v)) { - todo.push_back(NN); - continue; - } - vals.insert(v); - if (vals.size() > 1) - break; - } - if (vals.size() > 1) - break; - } - if (vals.size() == 1) { - auto V = *vals.begin(); - if (!isa(V) || DT.dominates(cast(V), PN)) { - PN->replaceAllUsesWith(V); - PN->eraseFromParent(); - } - } - } - } -} - -PreProcessCache::PreProcessCache() { - // Explicitly chose AA passes that are stateless - // and will not be invalidated - FAM.registerPass([] { return TypeBasedAA(); }); - FAM.registerPass([] { return BasicAA(); }); - MAM.registerPass([] { return GlobalsAA(); }); - // CallGraphAnalysis required for GlobalsAA - MAM.registerPass([] { return CallGraphAnalysis(); }); - - FAM.registerPass([] { return ScopedNoAliasAA(); }); - - // SCEVAA causes some breakage/segfaults - // disable for now, consider enabling in future - // FAM.registerPass([] { return SCEVAA(); }); - -#if LLVM_VERSION_MAJOR < 16 - if (EnzymeAggressiveAA) - FAM.registerPass([] { return CFLSteensAA(); }); -#endif - - MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); - FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); }); - - LAM.registerPass([&] { return FunctionAnalysisManagerLoopProxy(FAM); }); - FAM.registerPass([&] { return LoopAnalysisManagerFunctionProxy(LAM); }); - - FAM.registerPass([] { - auto AM = AAManager(); - AM.registerFunctionAnalysis(); - AM.registerFunctionAnalysis(); - AM.registerModuleAnalysis(); - AM.registerFunctionAnalysis(); - - // broken for different reasons - // AM.registerFunctionAnalysis(); - -#if LLVM_VERSION_MAJOR < 16 - if (EnzymeAggressiveAA) - AM.registerFunctionAnalysis(); -#endif - - return AM; - }); - - PassBuilder PB; - PB.registerModuleAnalyses(MAM); - PB.registerFunctionAnalyses(FAM); - PB.registerLoopAnalyses(LAM); -} - -llvm::AAResults & -PreProcessCache::getAAResultsFromFunction(llvm::Function *NewF) { - return FAM.getResult(*NewF); -} - -void setFullWillReturn(Function *NewF) { - for (auto &BB : *NewF) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - CI->addFnAttr(Attribute::WillReturn); - CI->addFnAttr(Attribute::MustProgress); - } - if (auto CI = 
dyn_cast(&I)) { - CI->addFnAttr(Attribute::WillReturn); - CI->addFnAttr(Attribute::MustProgress); - } - } - } -} - -void SplitPHIs(llvm::Function &F) { - SetVector todo; - for (auto &BB : F) { - for (auto &I : BB) { - if (isa(&I)) { - todo.insert(&I); - } else if (isa(&I)) { - todo.insert(&I); - } - } - } - while (todo.size()) { - auto cur = todo.pop_back_val(); - IRBuilder<> B(cur); - auto ST = dyn_cast(cur->getType()); - if (!ST) - continue; - bool justExtract = true; - for (auto U : cur->users()) { - if (!isa(U)) { - justExtract = false; - break; - } - if (cast(U)->getIndices().size() == 0) { - justExtract = false; - break; - } - } - if (!justExtract) - continue; - - SmallVector replacements; - for (size_t i = 0, e = ST->getNumElements(); i < e; i++) { - if (auto cur2 = dyn_cast(cur)) { - auto nPhi = - B.CreatePHI(ST->getElementType(i), cur2->getNumIncomingValues(), - cur->getName() + ".extract." + std::to_string(i)); - for (auto &&[blk, val] : - llvm::zip(cur2->blocks(), cur2->incoming_values())) { - IRBuilder B2(blk->getTerminator()); - nPhi->addIncoming(GradientUtils::extractMeta(B2, val, i), blk); - } - replacements.push_back(nPhi); - todo.insert(nPhi); - } else { - auto cur3 = cast(cur); - auto rep = B.CreateSelect( - cur3->getCondition(), - GradientUtils::extractMeta(B, cur3->getTrueValue(), i), - GradientUtils::extractMeta(B, cur3->getFalseValue(), i), - cur->getName() + ".extract." + std::to_string(i)); - replacements.push_back(rep); - if (auto sel = dyn_cast(rep)) - todo.insert(sel); - } - } - for (auto &U : make_early_inc_range(cur->uses())) { - auto user = cast(U.getUser()); - Value *rep = replacements[user->getIndices()[0]]; - IRBuilder<> B(user); - if (user->getIndices().size() > 1) - rep = B.CreateExtractValue(rep, user->getIndices().slice(1)); - assert(rep->getType() == user->getType()); - user->replaceAllUsesWith(rep); - user->eraseFromParent(); - } - cur->eraseFromParent(); - } -} - -Function *PreProcessCache::preprocessForClone(Function *F, - DerivativeMode mode) { - - TimeTraceScope timeScope("preprocessForClone", F->getName()); - - if (mode == DerivativeMode::ReverseModeGradient) - mode = DerivativeMode::ReverseModePrimal; - if (mode == DerivativeMode::ForwardModeSplit) - mode = DerivativeMode::ReverseModePrimal; - - // If we've already processed this, return the previous version - // and derive aliasing information - if (cache.find(std::make_pair(F, mode)) != cache.end()) { - Function *NewF = cache[std::make_pair(F, mode)]; - return NewF; - } - - Function *NewF = - Function::Create(F->getFunctionType(), F->getLinkage(), - "preprocess_" + F->getName(), F->getParent()); - - ValueToValueMapTy VMap; - for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { - VMap[i] = j; - j->setName(i->getName()); - if (EnzymeNoAlias && j->getType()->isPointerTy()) { - j->addAttr(Attribute::NoAlias); - } - ++i; - ++j; - } - - SmallVector Returns; - - if (!F->empty()) { - CloneFunctionInto( - NewF, F, VMap, - /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - } - CloneOrigin[NewF] = F; - NewF->setAttributes(F->getAttributes()); - if (EnzymeNoAlias) - for (auto j = NewF->arg_begin(); j != NewF->arg_end(); j++) { - if (j->getType()->isPointerTy()) { - j->addAttr(Attribute::NoAlias); - } - } - NewF->addFnAttr(Attribute::WillReturn); - NewF->addFnAttr(Attribute::MustProgress); - setFullWillReturn(NewF); - - if (EnzymePreopt) { - if (EnzymeInline) { - ForceRecursiveInlining(NewF, /*Limit*/ EnzymeInlineCount); - 
setFullWillReturn(NewF); - PreservedAnalyses PA; - FAM.invalidate(*NewF, PA); - } - } - - { - SmallVector ItersToErase; - for (auto &BB : *NewF) { - for (auto &I : BB) { - - if (auto CI = dyn_cast(&I)) { - - Function *called = CI->getCalledFunction(); - if (auto castinst = dyn_cast(CI->getCalledOperand())) { - if (castinst->isCast()) { - if (auto fn = dyn_cast(castinst->getOperand(0))) - called = fn; - } - } - - if (called && called->getName() == "__enzyme_iter") { - ItersToErase.push_back(CI); - } - } - } - } - for (auto Call : ItersToErase) { - IRBuilder<> B(Call); - Call->setArgOperand( - 0, B.CreateAdd(Call->getArgOperand(0), Call->getArgOperand(1))); - } - } - - // Assume allocations do not return null - { - TargetLibraryInfo &TLI = FAM.getResult(*F); - SmallVector CmpsToErase; - SmallVector BranchesToErase; - for (auto &BB : *NewF) { - for (auto &I : BB) { - if (auto IC = dyn_cast(&I)) { - if (!IC->isEquality()) - continue; - for (int i = 0; i < 2; i++) { - if (isa(IC->getOperand(1 - i))) - if (isAllocationCall(IC->getOperand(i), TLI)) { - for (auto U : IC->users()) { - if (auto BI = dyn_cast(U)) - BranchesToErase.push_back(BI->getParent()); - } - IC->replaceAllUsesWith( - IC->getPredicate() == ICmpInst::ICMP_NE - ? ConstantInt::getTrue(I.getContext()) - : ConstantInt::getFalse(I.getContext())); - CmpsToErase.push_back(&I); - break; - } - } - } - } - } - for (auto I : CmpsToErase) - I->eraseFromParent(); - for (auto BE : BranchesToErase) - ConstantFoldTerminator(BE); - } - - SimplifyMPIQueries(*NewF, FAM); - { - auto PA = PromotePass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); - } - - if (EnzymeLowerGlobals) { - SmallVector Calls; - SmallVector Returns; - for (BasicBlock &BB : *NewF) { - for (Instruction &I : BB) { - if (auto CI = dyn_cast(&I)) { - Calls.push_back(CI); - } - if (auto RI = dyn_cast(&I)) { - Returns.push_back(RI); - } - } - } - - // TODO consider using TBAA and globals as well - // instead of just BasicAA - AAResults AA2(FAM.getResult(*NewF)); - AA2.addAAResult(FAM.getResult(*NewF)); - AA2.addAAResult(FAM.getResult(*NewF)); - AA2.addAAResult(FAM.getResult(*NewF)); - - for (auto &g : NewF->getParent()->globals()) { - bool inF = false; - { - std::set seen; - std::deque todo = {(Constant *)&g}; - while (todo.size()) { - auto GV = todo.front(); - todo.pop_front(); - if (!seen.insert(GV).second) - continue; - for (auto u : GV->users()) { - if (auto C = dyn_cast(u)) { - todo.push_back(C); - } else if (auto I = dyn_cast(u)) { - if (I->getParent()->getParent() == NewF) { - inF = true; - goto doneF; - } - } - } - } - } - doneF:; - if (inF) { - bool activeCall = false; - bool hasWrite = false; - MemoryLocation Loc = - MemoryLocation(&g, LocationSize::beforeOrAfterPointer()); - - for (CallInst *CI : Calls) { - if (isa(CI)) - continue; - Function *F = CI->getCalledFunction(); - if (auto castinst = dyn_cast(CI->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) { - F = fn; - } - } - if (F && isMemFreeLibMFunction(F->getName())) { - continue; - } - if (F && F->getName().contains("__enzyme_integer")) { - continue; - } - if (F && F->getName().contains("__enzyme_pointer")) { - continue; - } - if (F && F->getName().contains("__enzyme_float")) { - continue; - } - if (F && F->getName().contains("__enzyme_double")) { - continue; - } - if (F && (startsWith(F->getName(), "f90io") || - F->getName() == "ftnio_fmt_write64" || - F->getName() == "__mth_i_ipowi" || - F->getName() == "f90_pausea")) { - continue; - } - if 
(llvm::isModOrRefSet(AA2.getModRefInfo(CI, Loc))) { - llvm::errs() << " failed to inline global: " << g << " due to " - << *CI << "\n"; - activeCall = true; - break; - } - } - - if (!activeCall) { - std::set seen; - std::deque todo = {(Value *)&g}; - while (todo.size()) { - auto GV = todo.front(); - todo.pop_front(); - if (!seen.insert(GV).second) - continue; - for (auto u : GV->users()) { - if (isa(u) || isa(u) || - isa(u) || isa(u)) { - todo.push_back(u); - continue; - } - - if (auto II = dyn_cast(u)) { - if (isIntelSubscriptIntrinsic(*II)) { - todo.push_back(u); - continue; - } - } - - if (auto CI = dyn_cast(u)) { - Function *F = CI->getCalledFunction(); - if (auto castinst = - dyn_cast(CI->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) { - F = fn; - } - } - if (F && isMemFreeLibMFunction(F->getName())) { - continue; - } - if (F && F->getName().contains("__enzyme_integer")) { - continue; - } - if (F && F->getName().contains("__enzyme_pointer")) { - continue; - } - if (F && F->getName().contains("__enzyme_float")) { - continue; - } - if (F && F->getName().contains("__enzyme_double")) { - continue; - } - if (F && (startsWith(F->getName(), "f90io") || - F->getName() == "ftnio_fmt_write64" || - F->getName() == "__mth_i_ipowi" || - F->getName() == "f90_pausea")) { - continue; - } - - if (couldFunctionArgumentCapture(CI, GV)) { - hasWrite = true; - goto endCheck; - } - - if (llvm::isModSet(AA2.getModRefInfo(CI, Loc))) { - hasWrite = true; - goto endCheck; - } - } - - else if (auto I = dyn_cast(u)) { - if (llvm::isModSet(AA2.getModRefInfo(I, Loc))) { - hasWrite = true; - goto endCheck; - } - } - } - } - } - - endCheck:; - if (!activeCall && hasWrite) { - IRBuilder<> bb(&NewF->getEntryBlock(), NewF->getEntryBlock().begin()); - AllocaInst *antialloca = bb.CreateAlloca( - g.getValueType(), g.getType()->getPointerAddressSpace(), nullptr, - g.getName() + "_local"); - - if (g.getAlignment()) { - antialloca->setAlignment(Align(g.getAlignment())); - } - - std::map remap; - remap[&g] = antialloca; - - std::deque todo = {&g}; - while (todo.size()) { - auto GV = todo.front(); - todo.pop_front(); - if (&g != GV && remap.find(GV) != remap.end()) - continue; - Value *replaced = nullptr; - if (remap.find(GV) != remap.end()) { - replaced = remap[GV]; - } else if (auto CE = dyn_cast(GV)) { - auto I = CE->getAsInstruction(); - bb.Insert(I); - assert(isa(I->getOperand(0))); - assert(remap[cast(I->getOperand(0))]); - I->setOperand(0, remap[cast(I->getOperand(0))]); - replaced = remap[GV] = I; - } - assert(replaced && "unhandled constantexpr"); - - SmallVector, 4> uses; - for (Use &U : GV->uses()) { - if (auto I = dyn_cast(U.getUser())) { - if (I->getParent()->getParent() == NewF) { - uses.emplace_back(I, U.getOperandNo()); - } - } - if (auto C = dyn_cast(U.getUser())) { - assert(C != &g); - todo.push_back(C); - } - } - for (auto &U : uses) { - U.first->setOperand(U.second, replaced); - } - } - - Value *args[] = { - bb.CreateBitCast(antialloca, getInt8PtrTy(g.getContext())), - bb.CreateBitCast(&g, getInt8PtrTy(g.getContext())), - ConstantInt::get( - Type::getInt64Ty(g.getContext()), - g.getParent()->getDataLayout().getTypeAllocSizeInBits( - g.getValueType()) / - 8), - ConstantInt::getFalse(g.getContext())}; - - Type *tys[] = {args[0]->getType(), args[1]->getType(), - args[2]->getType()}; - auto intr = - getIntrinsicDeclaration(g.getParent(), Intrinsic::memcpy, tys); - { - - auto cal = bb.CreateCall(intr, args); - if (g.getAlignment()) { - cal->addParamAttr( - 0, 
Attribute::getWithAlignment(g.getContext(), - Align(g.getAlignment()))); - cal->addParamAttr( - 1, Attribute::getWithAlignment(g.getContext(), - Align(g.getAlignment()))); - } - } - - std::swap(args[0], args[1]); - - for (ReturnInst *RI : Returns) { - IRBuilder<> IB(RI); - auto cal = IB.CreateCall(intr, args); - if (g.getAlignment()) { - cal->addParamAttr( - 0, Attribute::getWithAlignment(g.getContext(), - Align(g.getAlignment()))); - cal->addParamAttr( - 1, Attribute::getWithAlignment(g.getContext(), - Align(g.getAlignment()))); - } - } - } - } - } - - auto Level = OptimizationLevel::O2; - - PassBuilder PB; - FunctionPassManager FPM = - PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None); - auto PA = FPM.run(*F, FAM); - FAM.invalidate(*F, PA); - } - - if (EnzymePreopt) { - { - auto PA = LowerInvokePass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); - } - { - auto PA = UnreachableBlockElimPass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); - } - - { - auto PA = PromotePass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); - } - - { -#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG) - auto PA = SROAPass(llvm::SROAOptions::ModifyCFG).run(*NewF, FAM); -#elif !defined(FLANG) - auto PA = SROAPass().run(*NewF, FAM); -#else - auto PA = SROA().run(*NewF, FAM); -#endif - FAM.invalidate(*NewF, PA); - } - - if (mode != DerivativeMode::ForwardMode) - ReplaceReallocs(NewF); - - { -#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG) - auto PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*NewF, FAM); -#elif !defined(FLANG) - auto PA = SROAPass().run(*NewF, FAM); -#else - auto PA = SROA().run(*NewF, FAM); -#endif - FAM.invalidate(*NewF, PA); - } - - SimplifyCFGOptions scfgo; - { - auto PA = SimplifyCFGPass(scfgo).run(*NewF, FAM); - FAM.invalidate(*NewF, PA); - } - } - - { - SplitPHIs(*NewF); - PreservedAnalyses PA; - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - } - - if (mode != DerivativeMode::ForwardMode) - ReplaceReallocs(NewF); - - if (mode == DerivativeMode::ReverseModePrimal || - mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined) { - // For subfunction calls upgrade stack allocations to mallocs - // to ensure availability in the reverse pass - auto unreachable = getGuaranteedUnreachable(NewF); - UpgradeAllocasToMallocs(NewF, mode, unreachable); - } - - CanonicalizeLoops(NewF, FAM); - RemoveRedundantPHI(NewF, FAM); - - // Run LoopSimplifyPass to ensure preheaders exist on all loops - { - auto PA = LoopSimplifyPass().run(*NewF, FAM); - FAM.invalidate(*NewF, PA); - } - - { - for (auto &BB : *NewF) { - for (auto &I : make_early_inc_range(BB)) { - if (auto MTI = dyn_cast(&I)) { - - if (auto CI = dyn_cast(MTI->getOperand(2))) { - if (CI->getValue() == 0) { - MTI->eraseFromParent(); - } - } - } - } - } - - PreservedAnalyses PA; - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - PA.preserve(); - - FAM.invalidate(*NewF, PA); - - if (EnzymeNameInstructions) { - for (auto &Arg : NewF->args()) { - if (!Arg.hasName()) - Arg.setName("arg"); - } - for (BasicBlock &BB : *NewF) { - if (!BB.hasName()) - BB.setName("bb"); - - for (Instruction &I : BB) { - if (!I.hasName() && !I.getType()->isVoidTy()) - I.setName("i"); - } - } - } - } - - if (EnzymePHIRestructure) { - if (false) { - reset:; - PreservedAnalyses PA; - 
FAM.invalidate(*NewF, PA); - } - - SmallVector MultiBlocks; - for (auto &B : *NewF) { - if (B.hasNPredecessorsOrMore(3)) - MultiBlocks.push_back(&B); - } - - LoopInfo &LI = FAM.getResult(*NewF); - for (BasicBlock *B : MultiBlocks) { - - // Map of function edges to list of values possible - std::map, - std::set> - done; - { - std::deque, - BasicBlock *>> - Q; // newblock, target - - for (auto P : predecessors(B)) { - Q.emplace_back(std::make_pair(P, B), P); - } - - for (std::tuple< - std::pair, - BasicBlock *> - trace; - Q.size() > 0;) { - trace = Q.front(); - Q.pop_front(); - auto edge = std::get<0>(trace); - auto block = edge.first; - auto target = std::get<1>(trace); - - if (done[edge].count(target)) - continue; - done[edge].insert(target); - - Loop *blockLoop = LI.getLoopFor(block); - - for (BasicBlock *Pred : predecessors(block)) { - // Don't go up the backedge as we can use the last value if desired - // via lcssa - if (blockLoop && blockLoop->getHeader() == block && - blockLoop == LI.getLoopFor(Pred)) - continue; - - Q.push_back( - std::tuple, BasicBlock *>( - std::make_pair(Pred, block), target)); - } - } - } - - SmallPtrSet Preds; - for (auto &pair : done) { - Preds.insert(pair.first.first); - } - - for (auto BB : Preds) { - bool illegal = false; - SmallPtrSet UnionSet; - size_t numSuc = 0; - for (BasicBlock *sucI : successors(BB)) { - numSuc++; - const auto &SI = done[std::make_pair(BB, sucI)]; - if (SI.size() == 0) { - // sucI->getName(); - illegal = true; - break; - } - for (auto si : SI) { - UnionSet.insert(si); - - for (BasicBlock *sucJ : successors(BB)) { - if (sucI == sucJ) - continue; - if (done[std::make_pair(BB, sucJ)].count(si)) { - illegal = true; - goto endIllegal; - } - } - } - } - endIllegal:; - - if (!illegal && numSuc > 1 && !B->hasNPredecessors(UnionSet.size())) { - BasicBlock *Ins = - BasicBlock::Create(BB->getContext(), "tmpblk", BB->getParent()); - IRBuilder<> Builder(Ins); - for (auto &phi : B->phis()) { - auto nphi = Builder.CreatePHI(phi.getType(), 2); - SmallVector Blocks; - - for (auto blk : UnionSet) { - nphi->addIncoming(phi.getIncomingValueForBlock(blk), blk); - phi.removeIncomingValue(blk, /*deleteifempty*/ false); - } - - phi.addIncoming(nphi, Ins); - } - Builder.CreateBr(B); - for (auto blk : UnionSet) { - auto term = blk->getTerminator(); - for (unsigned Idx = 0, NumSuccessors = term->getNumSuccessors(); - Idx != NumSuccessors; ++Idx) - if (term->getSuccessor(Idx) == B) - term->setSuccessor(Idx, Ins); - } - goto reset; - } - } - } - } - - if (EnzymePrint) - llvm::errs() << "after simplification :\n" << *NewF << "\n"; - - if (llvm::verifyFunction(*NewF, &llvm::errs())) { - llvm::errs() << *NewF << "\n"; - report_fatal_error("function failed verification (1)"); - } - cache[std::make_pair(F, mode)] = NewF; - return NewF; -} - -FunctionType *getFunctionTypeForClone( - llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, - llvm::Type *additionalArg, llvm::ArrayRef constant_args, - bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType) { - SmallVector RetTypes; - if (returnValue == ReturnType::ArgsWithReturn || - returnValue == ReturnType::Return) { - if (returnType != DIFFE_TYPE::CONSTANT && - returnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back( - GradientUtils::getShadowType(FTy->getReturnType(), width)); - } else { - RetTypes.push_back(FTy->getReturnType()); - } - } else if (returnValue == ReturnType::ArgsWithTwoReturns || - returnValue == ReturnType::TwoReturns) { - RetTypes.push_back(FTy->getReturnType()); - if 
(returnType != DIFFE_TYPE::CONSTANT && - returnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back( - GradientUtils::getShadowType(FTy->getReturnType(), width)); - } else { - RetTypes.push_back(FTy->getReturnType()); - } - } - SmallVector ArgTypes; - - // The user might be deleting arguments to the function by specifying them in - // the VMap. If so, we need to not add the arguments to the arg ty vector - unsigned argno = 0; - - for (auto &I : FTy->params()) { - ArgTypes.push_back(I); - if (constant_args[argno] == DIFFE_TYPE::DUP_ARG || - constant_args[argno] == DIFFE_TYPE::DUP_NONEED) { - ArgTypes.push_back(GradientUtils::getShadowType(I, width)); - } else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(GradientUtils::getShadowType(I, width)); - } - ++argno; - } - - if (diffeReturnArg) { - assert(!FTy->getReturnType()->isVoidTy()); - ArgTypes.push_back( - GradientUtils::getShadowType(FTy->getReturnType(), width)); - } - if (additionalArg) { - ArgTypes.push_back(additionalArg); - } - Type *RetType = StructType::get(FTy->getContext(), RetTypes); - if (returnValue == ReturnType::TapeAndTwoReturns || - returnValue == ReturnType::TapeAndReturn || - returnValue == ReturnType::Tape) { - RetTypes.clear(); - RetTypes.push_back(getDefaultAnonymousTapeType(FTy->getContext())); - if (returnValue == ReturnType::TapeAndTwoReturns) { - RetTypes.push_back(FTy->getReturnType()); - RetTypes.push_back( - GradientUtils::getShadowType(FTy->getReturnType(), width)); - } else if (returnValue == ReturnType::TapeAndReturn) { - if (returnType != DIFFE_TYPE::CONSTANT && - returnType != DIFFE_TYPE::OUT_DIFF) - RetTypes.push_back( - GradientUtils::getShadowType(FTy->getReturnType(), width)); - else - RetTypes.push_back(FTy->getReturnType()); - } - RetType = StructType::get(FTy->getContext(), RetTypes); - } else if (returnValue == ReturnType::Return) { - assert(RetTypes.size() == 1); - RetType = RetTypes[0]; - } else if (returnValue == ReturnType::TwoReturns) { - assert(RetTypes.size() == 2); - } - - bool noReturn = RetTypes.size() == 0; - if (noReturn) - RetType = Type::getVoidTy(RetType->getContext()); - - // Create a new function type... - return FunctionType::get(RetType, ArgTypes, FTy->isVarArg()); -} - -Function *PreProcessCache::CloneFunctionWithReturns( - DerivativeMode mode, unsigned width, Function *&F, - ValueToValueMapTy &ptrInputs, ArrayRef constant_args, - SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstant, - SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE returnType, const Twine &name, - llvm::ValueMap *VMapO, - bool diffeReturnArg, llvm::Type *additionalArg) { - if (!F->empty()) - F = preprocessForClone(F, mode); - llvm::ValueToValueMapTy VMap; - llvm::FunctionType *FTy = getFunctionTypeForClone( - F->getFunctionType(), mode, width, additionalArg, constant_args, - diffeReturnArg, returnValue, returnType); - - for (BasicBlock &BB : *F) { - if (auto ri = dyn_cast(BB.getTerminator())) { - if (auto rv = ri->getReturnValue()) { - returnvals.insert(rv); - } - } - } - - // Create the new function... 
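// Worked example (tracing getFunctionTypeForClone above with width = 1;
// purely illustrative, not an Enzyme API): for double f(double *x, double y)
// with constant_args = {DUP_ARG, OUT_DIFF}, returnValue = ArgsWithReturn,
// returnType = OUT_DIFF, and diffeReturnArg = true, the clone's type is
//
//   { double, double } clone(double *x, double *dx, double y,
//                            double differeturn);
//
// i.e. the shadow of each DUP_ARG is appended directly after it as a
// parameter, each OUT_DIFF argument contributes one extra entry to the
// aggregate return, and the trailing parameter carries the incoming return
// adjoint.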
- Function *NewF = Function::Create(FTy, F->getLinkage(), name, F->getParent()); - if (diffeReturnArg) { - auto I = NewF->arg_end(); - I--; - if (additionalArg) - I--; - I->setName("differeturn"); - } - if (additionalArg) { - auto I = NewF->arg_end(); - I--; - I->setName("tapeArg"); - } - - { - unsigned ii = 0; - for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { - VMap[i] = j; - ++j; - ++i; - if (constant_args[ii] == DIFFE_TYPE::DUP_ARG || - constant_args[ii] == DIFFE_TYPE::DUP_NONEED) { - ++j; - } - ++ii; - } - } - - // Loop over the arguments, copying the names of the mapped arguments over... - Function::arg_iterator DestI = NewF->arg_begin(); - - for (const Argument &I : F->args()) - if (VMap.count(&I) == 0) { // Is this argument preserved? - DestI->setName(I.getName()); // Copy the name over... - VMap[&I] = &*DestI++; // Add mapping to VMap - } - SmallVector Returns; - if (!F->empty()) { - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); - } - if (NewF->empty()) { - auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> B(entry); - B.CreateUnreachable(); - } - CloneOrigin[NewF] = F; - if (VMapO) { - for (const auto &data : VMap) - VMapO->insert(std::pair( - data.first, (llvm::Value *)data.second)); - VMapO->getMDMap() = VMap.getMDMap(); - } - - for (auto attr : {"enzyme_ta_norecur", "frame-pointer"}) - if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { - NewF->addAttribute( - AttributeList::FunctionIndex, - F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); - } - - for (auto attr : - {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) - if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { - NewF->addAttribute( - AttributeList::ReturnIndex, - F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); - } - - bool hasPtrInput = false; - unsigned ii = 0, jj = 0; - - for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { - if (F->hasParamAttribute(ii, Attribute::StructRet)) { - NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret")); - // TODO - // NewF->addParamAttr( - // jj, - // Attribute::get( - // F->getContext(), Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, - // Attribute::StructRet).getValueAsType())); - } - if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) { - NewF->addParamAttr( - jj, F->getAttributes().getParamAttr(ii, "enzymejl_returnRoots")); - // TODO - // NewF->addParamAttr(jj, F->getParamAttribute(ii, - // Attribute::ElementType)); - } - for (auto attr : - {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) - if (F->getAttributes().hasParamAttr(ii, attr)) { - NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr)); - for (auto ty : PrimalParamAttrsToPreserve) - if (F->getAttributes().hasParamAttr(ii, ty)) { - auto attr = F->getAttributes().getParamAttr(ii, ty); - NewF->addParamAttr(jj, attr); - } - } - if (constant_args[ii] == DIFFE_TYPE::CONSTANT) { - if (!i->hasByValAttr()) - constants.insert(i); - if (EnzymePrintActivity) - llvm::errs() << "in new function " << NewF->getName() - << " constant arg " << *j << "\n"; - } else { - nonconstant.insert(i); - if (EnzymePrintActivity) - llvm::errs() << "in new function " << NewF->getName() - << " nonconstant arg " << *j << "\n"; - } - - // Always remove nonnull/noundef since the caller may choose to pass - // undef as an arg if provably it will not be used in 
the reverse pass - if (constant_args[ii] == DIFFE_TYPE::DUP_NONEED || - mode == DerivativeMode::ReverseModeGradient) { - if (F->hasParamAttribute(ii, Attribute::NonNull)) { - NewF->removeParamAttr(jj, Attribute::NonNull); - } - if (F->hasParamAttribute(ii, Attribute::NoUndef)) { - NewF->removeParamAttr(jj, Attribute::NoUndef); - } - } - - if (constant_args[ii] == DIFFE_TYPE::DUP_ARG || - constant_args[ii] == DIFFE_TYPE::DUP_NONEED) { - hasPtrInput = true; - ptrInputs[i] = (j + 1); - // TODO: find a way to keep the attributes in vector mode. - if (width == 1) - for (auto ty : ShadowParamAttrsToPreserve) - if (F->getAttributes().hasParamAttr(ii, ty)) { - auto attr = F->getAttributes().getParamAttr(ii, ty); - NewF->addParamAttr(jj + 1, attr); - } - - for (auto attr : - {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) - if (F->getAttributes().hasParamAttr(ii, attr)) { - if (width == 1) - NewF->addParamAttr(jj + 1, - F->getAttributes().getParamAttr(ii, attr)); - } - - if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) { - if (width == 1) { - NewF->addParamAttr(jj + 1, F->getAttributes().getParamAttr( - ii, "enzymejl_returnRoots")); - } else { - NewF->addParamAttr(jj + 1, Attribute::get(F->getContext(), - "enzymejl_returnRoots_v")); - } - // TODO - // NewF->addParamAttr(jj + 1, - // F->getParamAttribute(ii, - // Attribute::ElementType)); - } - - if (F->hasParamAttribute(ii, Attribute::StructRet)) { - if (width == 1) { - NewF->addParamAttr(jj + 1, - Attribute::get(F->getContext(), "enzyme_sret")); - // TODO - // NewF->addParamAttr( - // jj + 1, - // Attribute::get(F->getContext(), - // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, - // Attribute::StructRet) - // .getValueAsType())); - } else { - NewF->addParamAttr(jj + 1, - Attribute::get(F->getContext(), "enzyme_sret_v")); - // TODO - // NewF->addParamAttr( - // jj + 1, - // Attribute::get(F->getContext(), - // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, - // Attribute::StructRet) - // .getValueAsType())); - } - } - - j->setName(i->getName()); - ++j; - j->setName(i->getName() + "'"); - nonconstant.insert(j); - ++j; - jj += 2; - - ++i; - - } else { - j->setName(i->getName()); - ++j; - ++jj; - ++i; - } - ++ii; - } - - if (hasPtrInput && (mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ReverseModeGradient)) { - if (NewF->hasFnAttribute(Attribute::ReadOnly)) { - NewF->removeFnAttr(Attribute::ReadOnly); - } -#if LLVM_VERSION_MAJOR >= 16 - auto eff = NewF->getMemoryEffects(); - for (auto loc : MemoryEffects::locations()) { - if (loc == MemoryEffects::Location::InaccessibleMem) - continue; - auto mr = eff.getModRef(loc); - if (isModSet(mr)) - eff |= MemoryEffects(loc, ModRefInfo::Ref); - if (isRefSet(mr)) - eff |= MemoryEffects(loc, ModRefInfo::Mod); - } - NewF->setMemoryEffects(eff); -#endif - } - NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - if (EnzymeAlwaysInlineDiff) - NewF->addFnAttr(Attribute::AlwaysInline); - assert(NewF->hasLocalLinkage()); - - return NewF; -} - -void CoaleseTrivialMallocs(Function &F, DominatorTree &DT) { - std::map>> - LegalMallocs; - - std::map> frees; - for (BasicBlock &BB : F) { - for (Instruction &I : BB) { - if (auto CI = dyn_cast(&I)) { - if (auto F2 = CI->getCalledFunction()) { - if (F2->getName() == "free") { - if (auto MD = hasMetadata(CI, "enzyme_cache_free")) { - Metadata *op = MD->getOperand(0); - frees[op].push_back(CI); - } - } - } - } - } - } - - for (BasicBlock &BB : F) { - for (Instruction &I : BB) { - if 
(auto CI = dyn_cast<CallInst>(&I)) {
-        if (auto F = CI->getCalledFunction()) {
-          if (F->getName() == "malloc") {
-            CallInst *freeCall = nullptr;
-            for (auto U : CI->users()) {
-              if (auto CI2 = dyn_cast<CallInst>(U)) {
-                if (auto F2 = CI2->getCalledFunction()) {
-                  if (F2->getName() == "free") {
-                    if (DT.dominates(CI, CI2)) {
-                      freeCall = CI2;
-                      break;
-                    }
-                  }
-                }
-              }
-            }
-            if (!freeCall) {
-              if (auto MD = hasMetadata(CI, "enzyme_cache_alloc")) {
-                Metadata *op = MD->getOperand(0);
-                if (frees[op].size() == 1)
-                  freeCall = frees[op][0];
-              }
-            }
-            if (freeCall)
-              LegalMallocs[&BB].emplace_back(CI, freeCall);
-          }
-        }
-      }
-    }
-  }
-  for (auto &pair : LegalMallocs) {
-    if (pair.second.size() < 2)
-      continue;
-    CallInst *First = pair.second[0].first;
-    for (auto &z : pair.second) {
-      if (!DT.dominates(First, z.first))
-        First = z.first;
-    }
-    bool legal = true;
-    for (auto &z : pair.second) {
-      if (auto inst = dyn_cast<Instruction>(z.first->getArgOperand(0)))
-        if (!DT.dominates(inst, First))
-          legal = false;
-    }
-    if (!legal)
-      continue;
-    IRBuilder<> B(First);
-    Value *Size = First->getArgOperand(0);
-    for (auto &z : pair.second) {
-      if (z.first == First)
-        continue;
-      Size = B.CreateAdd(
-          B.CreateOr(B.CreateSub(Size, ConstantInt::get(Size->getType(), 1)),
-                     ConstantInt::get(Size->getType(), 15)),
-          ConstantInt::get(Size->getType(), 1));
-      z.second->eraseFromParent();
-      IRBuilder<> B2(z.first);
-      Value *gepPtr = B2.CreateInBoundsGEP(Type::getInt8Ty(First->getContext()),
-                                           First, Size);
-      z.first->replaceAllUsesWith(gepPtr);
-      Size = B.CreateAdd(Size, z.first->getArgOperand(0));
-      z.first->eraseFromParent();
-    }
-    auto NewMalloc =
-        cast<CallInst>(B.CreateCall(First->getCalledFunction(), Size));
-    NewMalloc->copyIRFlags(First);
-    NewMalloc->setMetadata("enzyme_cache_alloc",
-                           hasMetadata(First, "enzyme_cache_alloc"));
-    First->replaceAllUsesWith(NewMalloc);
-    First->eraseFromParent();
-  }
-}
-
-void SelectOptimization(Function *F) {
-  DominatorTree DT(*F);
-  for (auto &BB : *F) {
-    if (auto BI = dyn_cast<BranchInst>(BB.getTerminator())) {
-      if (BI->isConditional()) {
-        for (auto &I : BB) {
-          if (auto SI = dyn_cast<SelectInst>(&I)) {
-            if (SI->getCondition() == BI->getCondition()) {
-              for (Value::use_iterator UI = SI->use_begin(), E = SI->use_end();
-                   UI != E;) {
-                Use &U = *UI;
-                ++UI;
-                if (DT.dominates(BasicBlockEdge(&BB, BI->getSuccessor(0)), U))
-                  U.set(SI->getTrueValue());
-                else if (DT.dominates(BasicBlockEdge(&BB, BI->getSuccessor(1)),
-                                      U))
-                  U.set(SI->getFalseValue());
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-}
-
-void ReplaceFunctionImplementation(Module &M) {
-  for (Function &Impl : M) {
-    for (auto attr : {"implements", "implements2"}) {
-      if (!Impl.hasFnAttribute(attr))
-        continue;
-      const Attribute &A = Impl.getFnAttribute(attr);
-
-      const StringRef SpecificationName = A.getValueAsString();
-      Function *Specification = M.getFunction(SpecificationName);
-      if (!Specification) {
-        LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
-                          << "' but no matching specification with name '"
-                          << SpecificationName
-                          << "', potentially inlined and/or eliminated.\n");
-        continue;
-      }
-      LLVM_DEBUG(dbgs() << "Replace specification '" << Specification->getName()
-                        << "' with implementation '" << Impl.getName()
-                        << "'\n");
-
-      for (auto I = Specification->use_begin(), UE = Specification->use_end();
-           I != UE;) {
-        auto &use = *I;
-        ++I;
-        auto cext = ConstantExpr::getBitCast(&Impl, Specification->getType());
-        if (cast<Instruction>(use.getUser())->getParent()->getParent() == &Impl)
-          continue;
-        use.set(cext);
-        if (auto CI = dyn_cast(use.getUser())) {
-          if
(CI->getCalledOperand() == cext || - CI->getCalledFunction() == &Impl) { - CI->setCallingConv(Impl.getCallingConv()); - } - } - } - } - } -} - -void PreProcessCache::optimizeIntermediate(Function *F) { - PreservedAnalyses PA; - PA = PromotePass().run(*F, FAM); - FAM.invalidate(*F, PA); -#if !defined(FLANG) - PA = GVNPass().run(*F, FAM); -#else - PA = GVN().run(*F, FAM); -#endif - FAM.invalidate(*F, PA); -#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG) - PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); -#elif !defined(FLANG) - PA = SROAPass().run(*F, FAM); -#else - PA = SROA().run(*F, FAM); -#endif - FAM.invalidate(*F, PA); - - if (EnzymeSelectOpt) { - SimplifyCFGOptions scfgo; - PA = SimplifyCFGPass(scfgo).run(*F, FAM); - FAM.invalidate(*F, PA); - PA = CorrelatedValuePropagationPass().run(*F, FAM); - FAM.invalidate(*F, PA); - SelectOptimization(F); - } - // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM); - - if (EnzymeCoalese) - CoaleseTrivialMallocs(*F, FAM.getResult(*F)); - - ReplaceFunctionImplementation(*F->getParent()); - - { - PreservedAnalyses PA; - FAM.invalidate(*F, PA); - } - - OptimizationLevel Level = OptimizationLevel::O0; - - switch (EnzymePostOptLevel) { - default: - case 0: - Level = OptimizationLevel::O0; - break; - case 1: - Level = OptimizationLevel::O1; - break; - case 2: - Level = OptimizationLevel::O2; - break; - case 3: - Level = OptimizationLevel::O3; - break; - } - if (Level != OptimizationLevel::O0) { - PassBuilder PB; - FunctionPassManager FPM = - PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None); - PA = FPM.run(*F, FAM); - FAM.invalidate(*F, PA); - } - - // TODO actually run post optimizations. -} - -void PreProcessCache::clear() { - LAM.clear(); - FAM.clear(); - MAM.clear(); - cache.clear(); -} - -// Returns if a is guaranteed to be equivalent to not b -static bool isNot(Value *a, Value *b) { - // cmp pred, a, b and cmp inverse(pred), a, b - if (auto I1 = dyn_cast(a)) - if (auto I2 = dyn_cast(b)) - if (I1->getOperand(0) == I2->getOperand(0) && - I1->getOperand(1) == I2->getOperand(1) && - I1->getPredicate() == I2->getInversePredicate()) - return true; - // a := xor true, b - if (auto I = dyn_cast(a)) - if (I->getOpcode() == Instruction::Xor) - for (int i = 0; i < 2; i++) { - if (I->getOperand(i) == b) - if (auto CI = dyn_cast(I->getOperand(1 - i))) -#if LLVM_VERSION_MAJOR > 16 - if (CI->getValue().isAllOnes()) -#else - if (CI->getValue().isAllOnesValue()) -#endif - return true; - } - // b := xor true, a - if (auto I = dyn_cast(b)) - if (I->getOpcode() == Instruction::Xor) - for (int i = 0; i < 2; i++) { - if (I->getOperand(i) == a) - if (auto CI = dyn_cast(I->getOperand(1 - i))) -#if LLVM_VERSION_MAJOR > 16 - if (CI->getValue().isAllOnes()) -#else - if (CI->getValue().isAllOnesValue()) -#endif - return true; - } - return false; -} - -struct compare_insts { -public: - DominatorTree &DT; - LoopInfo &LI; - compare_insts(DominatorTree &DT, LoopInfo &LI) : DT(DT), LI(LI) {} - - // return true if A appears later than B. 
-  bool operator()(Instruction *A, Instruction *B) const {
-    if (A == B) {
-      return false;
-    }
-    if (A->getParent() == B->getParent()) {
-      return !A->comesBefore(B);
-    }
-    auto AB = A->getParent();
-    auto BB = B->getParent();
-    assert(AB->getParent() == BB->getParent());
-
-    for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) {
-      if (prev == AB)
-        return false;
-    }
-    return true;
-  }
-};
-
-class DominatorOrderSet : public std::set<Instruction *, compare_insts> {
-public:
-  DominatorOrderSet(DominatorTree &DT, LoopInfo &LI)
-      : std::set<Instruction *, compare_insts>(compare_insts(DT, LI)) {}
-  bool contains(Instruction *I) const {
-    auto __i = find(I);
-    return __i != end();
-  }
-  void remove(Instruction *I) {
-    auto __i = find(I);
-    assert(__i != end());
-    erase(__i);
-  }
-  Instruction *pop_back_val() {
-    auto back = end();
-    back--;
-    auto v = *back;
-    erase(back);
-    return v;
-  }
-};
-
-bool directlySparse(Value *z) {
-  if (isa(z))
-    return true;
-  if (isa(z))
-    return true;
-  if (isa(z))
-    return true;
-  if (isa(z))
-    return true;
-  if (auto SI = dyn_cast<SelectInst>(z)) {
-    if (auto CI = dyn_cast(SI->getTrueValue()))
-      if (CI->isZero())
-        return true;
-    if (auto CI = dyn_cast(SI->getFalseValue()))
-      if (CI->isZero())
-        return true;
-  }
-  return false;
-}
-
-typedef DominatorOrderSet QueueType;
-
-Function *getProductIntrinsic(llvm::Module &M, llvm::Type *T) {
-  std::string name = "__enzyme_product.";
-  if (T->isFloatTy())
-    name += "f32";
-  else if (T->isDoubleTy())
-    name += "f64";
-  else if (T->isIntegerTy())
-    name += "i" + std::to_string(cast<IntegerType>(T)->getBitWidth());
-  else
-    assert(0);
-  auto FT = llvm::FunctionType::get(T, {}, true);
-  AttributeList AL;
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::ReadNone);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::NoUnwind);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::NoFree);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::NoSync);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::WillReturn);
-  return cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
-}
-
-Function *getSumIntrinsic(llvm::Module &M, llvm::Type *T) {
-  std::string name = "__enzyme_sum.";
-  if (T->isFloatTy())
-    name += "f32";
-  else if (T->isDoubleTy())
-    name += "f64";
-  else if (T->isIntegerTy())
-    name += "i" + std::to_string(cast<IntegerType>(T)->getBitWidth());
-  else
-    assert(0);
-  auto FT = llvm::FunctionType::get(T, {}, true);
-  AttributeList AL;
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::ReadNone);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::NoUnwind);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::NoFree);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::NoSync);
-  AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
-                       Attribute::WillReturn);
-  return cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
-}
-
-CallInst *isProduct(llvm::Value *v) {
-  if (auto prod = dyn_cast<CallInst>(v))
-    if (auto F = getFunctionFromCall(prod))
-      if (startsWith(F->getName(), "__enzyme_product"))
-        return prod;
-  return nullptr;
-}
-
-CallInst *isSum(llvm::Value *v) {
-  if (auto prod = dyn_cast<CallInst>(v))
-    if (auto F = getFunctionFromCall(prod))
-      if (startsWith(F->getName(), "__enzyme_sum"))
-        return prod;
-  return nullptr;
-}
-
-SmallVector callOperands(llvm::CallBase *CB) {
-  return SmallVector(CB->args().begin(),
CB->args().end()); -} - -bool guaranteedDataDependent(Value *z) { - if (isa(z)) - return true; - if (isa(z)) - return true; - if (auto BO = dyn_cast(z)) - return guaranteedDataDependent(BO->getOperand(0)) && - guaranteedDataDependent(BO->getOperand(1)); - if (auto C = dyn_cast(z)) - return guaranteedDataDependent(C->getOperand(0)); - if (auto S = isSum(z)) { - for (auto op : callOperands(S)) - if (guaranteedDataDependent(op)) - return true; - return false; - } - if (auto S = isProduct(z)) { - for (auto op : callOperands(S)) - if (!guaranteedDataDependent(op)) - return false; - return true; - } - if (auto II = dyn_cast(z)) { - switch (II->getIntrinsicID()) { - case Intrinsic::sqrt: - case Intrinsic::sin: - case Intrinsic::cos: -#if LLVM_VERSION_MAJOR >= 19 - case Intrinsic::sinh: - case Intrinsic::cosh: - case Intrinsic::tanh: -#endif - return guaranteedDataDependent(II->getArgOperand(0)); - default: - break; - } - } - return false; -} - -std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, - QueueType &Q, DominatorTree &DT, - ScalarEvolution &SE, LoopInfo &LI, - const DataLayout &DL) { - auto push = [&](llvm::Value *V) { - if (V == cur) - return V; - assert(V); - if (auto I = dyn_cast(V)) { - Q.insert(I); - for (auto U : I->users()) { - if (auto I2 = dyn_cast(U)) { - if (I2 == cur) - continue; - Q.insert(I2); - } - } - } - return V; - }; - auto pushcse = [&](llvm::Value *V) -> llvm::Value * { - if (auto I = dyn_cast(V)) { - for (size_t i = 0; i < I->getNumOperands(); i++) { - if (auto I2 = dyn_cast(I->getOperand(i))) { - Instruction *candidate = nullptr; - for (auto U : I2->users()) { - candidate = dyn_cast(U); - if (!candidate) - continue; - if (candidate == I && candidate->getType() != I->getType()) { - candidate = nullptr; - continue; - } - bool isSame = candidate->isIdenticalTo(I); - if (!isSame) { - if (auto P1 = isProduct(I)) - if (auto P2 = isProduct(I2)) { - std::multiset s1; - std::multiset s2; - for (auto &v : callOperands(P1)) - s1.insert(v); - for (auto &v : callOperands(P2)) - s2.insert(v); - isSame = s1 == s2; - } - if (auto P1 = isSum(I)) - if (auto P2 = isSum(I2)) { - std::multiset s1; - std::multiset s2; - for (auto &v : callOperands(P1)) - s1.insert(v); - for (auto &v : callOperands(P2)) - s2.insert(v); - isSame = s1 == s2; - } - } - if (!isSame) { - candidate = nullptr; - continue; - } - - if (DT.dominates(candidate, I)) { - break; - } - candidate = nullptr; - } - if (candidate) { - I->eraseFromParent(); - return candidate; - } - break; - } - } - return push(I); - } - return V; - }; - auto replaceAndErase = [&](llvm::Instruction *I, llvm::Value *candidate) { - for (auto U : I->users()) - push(U); - I->replaceAllUsesWith(candidate); - push(candidate); - - SetVector operands; - for (size_t i = 0; i < I->getNumOperands(); i++) { - if (auto I2 = dyn_cast(I->getOperand(i))) { - if ((!I2->mayWriteToMemory() || - (isa(I2) && isReadOnly(cast(I2))))) - operands.insert(I2); - } - } - if (Q.contains(I)) { - Q.remove(I); - } - assert(!Q.contains(I)); - I->eraseFromParent(); - for (auto op : operands) - if (op->getNumUses() == 0) { - if (Q.contains(op)) - Q.remove(op); - op->eraseFromParent(); - } - }; - if (!cur->getType()->isVoidTy() && - (!cur->mayWriteToMemory() || - (isa(cur) && isReadOnly(cast(cur))))) { - // DCE - if (cur->getNumUses() == 0) { - for (size_t i = 0; i < cur->getNumOperands(); i++) - push(cur->getOperand(i)); - assert(!Q.contains(cur)); - cur->eraseFromParent(); - return "DCE"; - } - // CSE - { - for (size_t i = 0; i < cur->getNumOperands(); i++) { 
- if (auto I = dyn_cast(cur->getOperand(i))) { - Instruction *candidate = nullptr; - bool reverse = false; - for (auto U : I->users()) { - candidate = dyn_cast(U); - if (!candidate) - continue; - if (candidate == cur && candidate->getType() != cur->getType()) { - candidate = nullptr; - continue; - } - bool isSame = candidate->isIdenticalTo(cur); - if (!isSame) { - if (auto P1 = isProduct(candidate)) - if (auto P2 = isProduct(cur)) { - std::multiset s1; - std::multiset s2; - for (auto &v : callOperands(P1)) - s1.insert(v); - for (auto &v : callOperands(P2)) - s2.insert(v); - isSame = s1 == s2; - } - if (auto P1 = isSum(candidate)) - if (auto P2 = isSum(cur)) { - std::multiset s1; - std::multiset s2; - for (auto &v : callOperands(P1)) - s1.insert(v); - for (auto &v : callOperands(P2)) - s2.insert(v); - isSame = s1 == s2; - } - } - - if (!isSame) { - candidate = nullptr; - continue; - } - - if (DT.dominates(candidate, cur)) { - break; - } else if (DT.dominates(cur, candidate)) { - reverse = true; - break; - } - candidate = nullptr; - } - if (candidate) { - if (reverse) { - if (Q.contains(candidate)) - Q.remove(candidate); - auto tmp = candidate; - candidate = cur; - cur = tmp; - } - replaceAndErase(cur, candidate); - return "CSE"; - } - break; - } - } - } - } - - if (auto SI = dyn_cast(cur)) - if (auto CI = dyn_cast(SI->getCondition())) { - if (CI->isOne()) { - replaceAndErase(cur, SI->getTrueValue()); - return "SelectToTrue"; - } else { - replaceAndErase(cur, SI->getFalseValue()); - return "SelectToFalse"; - } - } - if (cur->getOpcode() == Instruction::Or) { - for (int i = 0; i < 2; i++) { - if (auto C = dyn_cast(cur->getOperand(i))) { - // or a, 0 -> a - if (C->isZero()) { - replaceAndErase(cur, cur->getOperand(1 - i)); - return "OrZero"; - } - // or a, 1 -> 1 - if (C->isOne() && cur->getType()->isIntegerTy(1)) { - replaceAndErase(cur, C); - return "OrOne"; - } - } - } - } - if (cur->getOpcode() == Instruction::And) { - for (int i = 0; i < 2; i++) { - if (auto C = dyn_cast(cur->getOperand(i))) { - // and a, 1 -> a - if (C->isOne() && cur->getType()->isIntegerTy(1)) { - replaceAndErase(cur, cur->getOperand(1 - i)); - return "AndOne"; - } - // and a, 0 -> 0 - if (C->isZero()) { - replaceAndErase(cur, C); - return "AndZero"; - } - } - } - } - - IRBuilder<> B(cur); - if (auto CI = dyn_cast(cur)) - if (auto C = dyn_cast(CI->getOperand(0))) { - replaceAndErase( - cur, cast(B.CreateCast(CI->getOpcode(), C, CI->getType()))); - return "CastConstProp"; - } - std::function replace = [&](Value *val, - Value *orig, - Value *with) { - if (val == orig) { - return with; - } - if (isNot(val, orig)) { - return pushcse(B.CreateNot(with)); - } - if (isa(val)) - return val; - - if (auto I = dyn_cast(val)) { - if (I->mayWriteToMemory() && - !(isa(I) && isReadOnly(cast(I)))) - return val; - - if (I->getOpcode() == Instruction::Add) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateAdd(lhs, rhs, "sel." + I->getName(), - I->hasNoUnsignedWrap(), - I->hasNoSignedWrap())); - } - - if (I->getOpcode() == Instruction::Sub) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateSub(lhs, rhs, "sel." 
+ I->getName(), - I->hasNoUnsignedWrap(), - I->hasNoSignedWrap())); - } - - if (I->getOpcode() == Instruction::Mul) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateMul(lhs, rhs, "sel." + I->getName(), - I->hasNoUnsignedWrap(), - I->hasNoSignedWrap())); - } - - if (I->getOpcode() == Instruction::And) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateAnd(lhs, rhs, "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::Or) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateOr(lhs, rhs, "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::Xor) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateXor(lhs, rhs, "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::FAdd) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateFAddFMF(lhs, rhs, I, "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::FSub) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateFSubFMF(lhs, rhs, I, "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::FMul) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse(B.CreateFMulFMF(lhs, rhs, I, "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::ZExt) { - Value *op = replace(I->getOperand(0), orig, with); - if (op == I->getOperand(0)) - return val; - push(I); - return pushcse(B.CreateZExt(op, I->getType(), "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::SExt) { - Value *op = replace(I->getOperand(0), orig, with); - if (op == I->getOperand(0)) - return val; - push(I); - return pushcse(B.CreateSExt(op, I->getType(), "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::UIToFP) { - Value *op = replace(I->getOperand(0), orig, with); - if (op == I->getOperand(0)) - return val; - push(I); - return pushcse(B.CreateUIToFP(op, I->getType(), "sel." + I->getName())); - } - - if (I->getOpcode() == Instruction::SIToFP) { - Value *op = replace(I->getOperand(0), orig, with); - if (op == I->getOperand(0)) - return val; - push(I); - return pushcse(B.CreateSIToFP(op, I->getType(), "sel." + I->getName())); - } - - if (auto CI = dyn_cast(I)) { - Value *lhs = replace(I->getOperand(0), orig, with); - Value *rhs = replace(I->getOperand(1), orig, with); - if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) - return val; - push(I); - return pushcse( - B.CreateCmp(CI->getPredicate(), lhs, rhs, "sel." 
+ I->getName())); - } - - if (auto SI = dyn_cast(I)) { - Value *cond = replace(SI->getCondition(), orig, with); - Value *tval = replace(SI->getTrueValue(), orig, with); - Value *fval = replace(SI->getFalseValue(), orig, with); - if (cond == SI->getCondition() && tval == SI->getTrueValue() && - fval == SI->getFalseValue()) - return val; - push(I); - if (auto CI = dyn_cast(cond)) { - if (CI->isOne()) - return tval; - else - return fval; - } - return pushcse(B.CreateSelect(cond, tval, fval, "sel." + I->getName())); - } - - if (isProduct(I) || isSum(I)) { - auto C = cast(I); - auto ops = callOperands(C); - bool changed = false; - for (auto &op : ops) { - auto next = replace(op, orig, with); - if (next != op) { - changed = true; - op = next; - } - } - if (!changed) - return (Value *)I; - push(I); - pushcse( - B.CreateCall(getFunctionFromCall(C), ops, "sel." + I->getName())); - } - } - return val; - }; - - if (auto II = dyn_cast(cur)) - if (II->getIntrinsicID() == Intrinsic::fmuladd || - II->getIntrinsicID() == Intrinsic::fma) { - B.setFastMathFlags(getFast()); - auto mul = pushcse(B.CreateFMul(II->getOperand(0), II->getOperand(1))); - auto add = pushcse(B.CreateFAdd(mul, II->getOperand(2))); - replaceAndErase(cur, add); - return "FMulAddExpand"; - } - - if (auto BO = dyn_cast(cur)) { - if (BO->getOpcode() == Instruction::FMul && BO->isFast()) { - Value *args[2] = {BO->getOperand(0), BO->getOperand(1)}; - auto mul = pushcse( - B.CreateCall(getProductIntrinsic(*F.getParent(), BO->getType()), args, - cur->getName())); - replaceAndErase(cur, mul); - return "FMulToProduct"; - } - if (BO->getOpcode() == Instruction::FDiv && BO->isFast()) { - auto c0 = dyn_cast(BO->getOperand(0)); - if (!c0 || !c0->isExactlyValue(1.0)) { - B.setFastMathFlags(getFast()); - auto div = pushcse(B.CreateFDivFMF(ConstantFP::get(BO->getType(), 1.0), - BO->getOperand(1), BO)); - auto mul = pushcse( - B.CreateFMulFMF(BO->getOperand(0), div, BO, cur->getName())); - replaceAndErase(cur, mul); - return "FDivToFMul"; - } - } - if (BO->getOpcode() == Instruction::FAdd && BO->isFast()) { - Value *args[2] = {BO->getOperand(0), BO->getOperand(1)}; - auto mul = pushcse( - B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()), args)); - replaceAndErase(cur, mul); - return "FAddToSum"; - } - if (BO->getOpcode() == Instruction::FSub && BO->isFast()) { - B.setFastMathFlags(getFast()); - Value *args[2] = {BO->getOperand(0), - pushcse(B.CreateFNeg(BO->getOperand(1)))}; - auto mul = - pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()), - args, cur->getName())); - replaceAndErase(cur, mul); - return "FAddToSum"; - } - } - if (cur->getOpcode() == Instruction::FNeg) { - B.setFastMathFlags(getFast()); - auto mul = - pushcse(B.CreateFMulFMF(ConstantFP::get(cur->getType(), -1.0), - cur->getOperand(0), cur, cur->getName())); - replaceAndErase(cur, mul); - return "FNegToMul"; - } - - if (auto SI = dyn_cast(cur)) { - if (auto tc = dyn_cast(SI->getTrueValue())) - if (auto fc = dyn_cast(SI->getFalseValue())) - if (fc->isZero()) { - if (tc->isExactlyValue(1.0)) { - auto res = - pushcse(B.CreateUIToFP(SI->getCondition(), tc->getType())); - replaceAndErase(cur, res); - return "SelToUIFP"; - } - if (tc->isExactlyValue(-1.0)) { - auto res = - pushcse(B.CreateSIToFP(SI->getCondition(), tc->getType())); - replaceAndErase(cur, res); - return "SelToSIFP"; - } - } - } - - if (auto P = isProduct(cur)) { - SmallVector operands; - std::optional constval; - bool changed = false; - for (auto &v : callOperands(P)) - - { - if (auto P2 = 
isProduct(v)) { - for (auto &v2 : callOperands(P2)) { - push(v2); - operands.push_back(v2); - } - push(P2); - changed = true; - continue; - } - if (auto C = dyn_cast(v)) { - if (C->isExactlyValue(1.0)) { - changed = true; - continue; - } - if (C->isZero()) { - replaceAndErase(cur, C); - return "ZeroProduct"; - } - if (!constval) { - constval = C->getValue(); - continue; - } - constval = (*constval) * C->getValue(); - changed = true; - continue; - } - if (auto op = dyn_cast(v)) { - if (auto tc = dyn_cast(op->getTrueValue())) - if (tc->isZero()) { - operands.push_back(pushcse(B.CreateUIToFP( - pushcse(B.CreateNot(op->getCondition())), op->getType()))); - operands.push_back(op->getFalseValue()); - changed = true; - continue; - } - if (auto tc = dyn_cast(op->getFalseValue())) - if (tc->isZero()) { - operands.push_back( - pushcse(B.CreateUIToFP(op->getCondition(), op->getType()))); - operands.push_back(op->getTrueValue()); - changed = true; - continue; - } - } - operands.push_back(v); - } - if (constval) - operands.push_back(ConstantFP::get(cur->getType(), *constval)); - - if (operands.size() == 0) { - replaceAndErase(cur, ConstantFP::get(cur->getType(), 1.0)); - return "EmptyProduct"; - } - if (operands.size() == 1) { - replaceAndErase(cur, operands[0]); - return "SingleProduct"; - } - if (changed) { - auto mul = pushcse( - B.CreateCall(getProductIntrinsic(*F.getParent(), cur->getType()), - operands, cur->getName())); - replaceAndErase(cur, mul); - return "ProductSimplification"; - } - } - - if (auto P = isSum(cur)) { - // map from operand, to number of counts - std::map operands; - std::optional constval; - bool changed = false; - for (auto &v : callOperands(P)) { - if (auto P2 = isSum(v)) { - for (auto &v2 : callOperands(P2)) { - push(v2); - operands[v2]++; - } - push(P2); - changed = true; - continue; - } - if (auto C = dyn_cast(v)) { - if (C->isExactlyValue(0.0)) { - changed = true; - continue; - } - if (!constval) { - constval = C->getValue(); - continue; - } - constval = (*constval) + C->getValue(); - changed = true; - continue; - } - operands[v]++; - } - if (constval) - operands[ConstantFP::get(cur->getType(), *constval)]++; - - if (operands.size() == 0) { - replaceAndErase(cur, ConstantFP::get(cur->getType(), 0.0)); - return "EmptySum"; - } - SmallVector args; - for (auto &pair : operands) { - if (pair.second == 1) { - args.push_back(pair.first); - continue; - } - changed = true; - Value *sargs[] = {pair.first, - ConstantFP::get(cur->getType(), (double)pair.second)}; - args.push_back(pushcse(B.CreateCall( - getProductIntrinsic(*F.getParent(), cur->getType()), sargs))); - } - if (args.size() == 1) { - replaceAndErase(cur, args[0]); - return "SingleSum"; - } - if (changed) { - auto sum = - pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), cur->getType()), - args, cur->getName())); - replaceAndErase(cur, sum); - return "SumSimplification"; - } - } - - if (auto P = isProduct(cur)) { - SmallVector operands; - SmallVector conditions; - for (auto &v : callOperands(P)) { - // z = uitofp i1 c to float -> select c, (prod withot z), 0 - if (auto op = dyn_cast(v)) { - if (op->getOperand(0)->getType()->isIntegerTy(1)) { - conditions.push_back(op->getOperand(0)); - continue; - } - } - // z = sitofp i1 c to float -> select c, (-prod withot z), 0 - if (auto op = dyn_cast(v)) { - if (op->getOperand(0)->getType()->isIntegerTy(1)) { - conditions.push_back(op->getOperand(0)); - operands.push_back(ConstantFP::get(cur->getType(), -1.0)); - continue; - } - } - if (auto op = dyn_cast(v)) { - if (auto tc = 
dyn_cast(op->getTrueValue())) - if (tc->isZero()) { - conditions.push_back(pushcse(B.CreateNot(op->getCondition()))); - operands.push_back(op->getFalseValue()); - continue; - } - if (auto tc = dyn_cast(op->getFalseValue())) - if (tc->isZero()) { - conditions.push_back(op->getCondition()); - operands.push_back(op->getTrueValue()); - continue; - } - } - operands.push_back(v); - } - - if (conditions.size()) { - auto mul = pushcse(B.CreateCall( - getProductIntrinsic(*F.getParent(), cur->getType()), operands)); - Value *condition = nullptr; - for (auto v : conditions) { - assert(v->getType()->isIntegerTy(1)); - if (condition == nullptr) { - condition = v; - continue; - } - condition = pushcse(B.CreateAnd(condition, v)); - } - auto zero = ConstantFP::get(cur->getType(), 0.0); - auto sel = pushcse(B.CreateSelect(condition, mul, zero, cur->getName())); - replaceAndErase(cur, sel); - return "ProductSelect"; - } - } - - // TODO - if (auto P = isSum(cur)) { - // whether negated - SmallVector, 1> conditions; - bool legal = true; - for (auto &v : callOperands(P)) { - // z = uitofp i1 c to float -> select c, (prod withot z), 0 - if (auto op = dyn_cast(v)) { - if (op->getOperand(0)->getType()->isIntegerTy(1)) { - conditions.emplace_back(op->getOperand(0), false); - continue; - } - } - // z = sitofp i1 c to float -> select c, (-prod withot z), 0 - if (auto op = dyn_cast(v)) { - if (op->getOperand(0)->getType()->isIntegerTy(1)) { - conditions.emplace_back(op->getOperand(0), false); - continue; - } - } - if (auto op = dyn_cast(v)) { - if (auto tc = dyn_cast(op->getTrueValue())) - if (tc->isZero()) { - conditions.emplace_back(op->getCondition(), true); - continue; - } - if (auto tc = dyn_cast(op->getFalseValue())) - if (tc->isZero()) { - conditions.emplace_back(op->getCondition(), false); - continue; - } - } - legal = false; - break; - } - Value *condition = nullptr; - if (legal) - for (size_t i = 0; i < conditions.size(); i++) { - size_t count = 0; - for (size_t j = 0; j < conditions.size(); j++) { - if (((conditions[i].first == conditions[j].first) && - (conditions[i].second == conditions[i].second)) || - ((isNot(conditions[i].first, conditions[j].first) && - (conditions[i].second != conditions[i].second)))) - count++; - } - if (count == conditions.size() && count > 1) { - condition = conditions[i].first; - if (conditions[i].second) - condition = pushcse(B.CreateNot(condition, "sumpnot")); - break; - } - } - - if (condition) { - - SmallVector operands; - for (auto &v : callOperands(P)) { - // z = uitofp i1 c to float -> select c, (prod withot z), 0 - if (auto op = dyn_cast(v)) { - if (op->getOperand(0)->getType()->isIntegerTy(1)) { - operands.push_back(ConstantFP::get(cur->getType(), 1.0)); - continue; - } - } - // z = sitofp i1 c to float -> select c, (-prod withot z), 0 - if (auto op = dyn_cast(v)) { - if (op->getOperand(0)->getType()->isIntegerTy(1)) { - operands.push_back(ConstantFP::get(cur->getType(), -1.0)); - continue; - } - } - if (auto op = dyn_cast(v)) { - if (auto tc = dyn_cast(op->getTrueValue())) - if (tc->isZero()) { - operands.push_back(op->getFalseValue()); - continue; - } - if (auto tc = dyn_cast(op->getFalseValue())) - if (tc->isZero()) { - operands.push_back(op->getTrueValue()); - continue; - } - } - llvm::errs() << " unhandled call op sumselect: " << *v << "\n"; - assert(0); - } - - if (conditions.size()) { - auto sum = pushcse(B.CreateCall( - getSumIntrinsic(*F.getParent(), cur->getType()), operands)); - auto zero = ConstantFP::get(cur->getType(), 0.0); - auto sel = - 
pushcse(B.CreateSelect(condition, sum, zero, cur->getName())); - replaceAndErase(cur, sel); - return "SumSelect"; - } - } - } - // (a1*b1) + (a1*c1) + (a1*d1 ) + ... -> a1 * (b1 + c1 + d1 + ...) - if (auto S = isSum(cur)) { - SmallVector allOps; - auto combine = [](const SmallVector &lhs, - SmallVector rhs) { - SmallVector out; - for (auto v : lhs) { - bool seen = false; - for (auto &v2 : rhs) { - if (v == v2) { - v2 = nullptr; - seen = true; - break; - } - } - if (seen) { - out.push_back(v); - } - } - return out; - }; - auto subtract = [](SmallVector lhs, - const SmallVector &rhs) { - for (auto v : rhs) { - auto found = find(lhs, v); - assert(found != lhs.end()); - lhs.erase(found); - } - return lhs; - }; - bool seen = false; - bool legal = true; - for (auto op : callOperands(S)) { - auto P = isProduct(op); - if (!P) { - legal = false; - break; - } - if (!seen) { - allOps = callOperands(P); - seen = true; - continue; - } - allOps = combine(allOps, callOperands(P)); - } - - if (legal && allOps.size() > 0) { - SmallVector operands; - for (auto op : callOperands(S)) { - auto P = isProduct(op); - push(op); - auto sub = subtract(callOperands(P), allOps); - auto newprod = pushcse(B.CreateCall( - getProductIntrinsic(*F.getParent(), S->getType()), sub)); - operands.push_back(newprod); - } - auto newsum = pushcse(B.CreateCall( - getSumIntrinsic(*F.getParent(), S->getType()), operands)); - allOps.push_back(newsum); - auto fprod = pushcse(B.CreateCall( - getProductIntrinsic(*F.getParent(), S->getType()), allOps)); - replaceAndErase(cur, fprod); - return "SumFactor"; - } - } - - /* - // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 - != c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; - j++) if (auto c0 = dyn_cast(cur->getOperand(j))) if (auto cmp0 = - dyn_cast(c0->getOperand(0))) if (auto c1 = - dyn_cast(cur->getOperand(1-j))) if (auto cmp1 = - dyn_cast(c0->getOperand(0))) if (cmp0->getPredicate() == - ICmpInst::ICMP_EQ && cmp1->getPredicate() == ICmpInst::ICMP_EQ) - { - for (size_t i0 = 0; i0 < 2; i0++) - for (size_t i1 = 0; i1 < 2; i1++) - if (cmp0->getOperand(1 - i0) == cmp1->getOperand(1 - i1)) - auto e0 = SE.getSCEV(cmp0->getOperand(i0)); - auto e1 = SE.getSCEV(cmp1->getOperand(i1)); - auto m = SE.getMinusSCEV(e0, e1, SCEV::NoWrapMask); - if (auto C = dyn_cast(m)) { - // if c1 == c2 don't need the and they are equivalent - if (C->getValue()->isZero()) { - } else { - auto sel0 = pushcse(B.CreateSelect(cmp0, - ConstantInt::get(cur->getType(), isa(cmp0) ? 1 : -1), - ConstantInt::get(cur->getType(), 0)); - // if non one constant they must be distinct. - replaceAndErase(cur, - ConstantInt::getFalse(cur->getContext())); - return "AndNEExpr"; - } - } - } - } - */ - - if (auto fcmp = dyn_cast(cur)) { - auto predicate = fcmp->getPredicate(); - if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ || - predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) { - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(fcmp->getOperand(i))) { - if (C->isZero()) { - // (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (a2 == 0) - // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... 
(a2 == - // 0) - // ] - if (auto P = isProduct(fcmp->getOperand(1 - i))) { - Value *res = nullptr; - - auto eq_predicate = predicate; - if (predicate == FCmpInst::FCMP_UNE || - predicate == FCmpInst::FCMP_ONE) - eq_predicate = fcmp->getInversePredicate(); - - for (auto &v : callOperands(P)) { - auto ncmp1 = pushcse(B.CreateFCmp(eq_predicate, v, C)); - if (!res) - res = ncmp1; - else - res = pushcse(B.CreateOr(res, ncmp1)); - } - - if (predicate == FCmpInst::FCMP_UNE || - predicate == FCmpInst::FCMP_ONE) { - res = pushcse(B.CreateNot(res)); - } - - replaceAndErase(cur, res); - return "CmpProductSplit"; - } - - // (a1*b1) + (a1*c1) + (a1*d1 ) + ... ?= 0 -> a1 * (b1 + c1 + d1 + - // ...) ?= 0 - if (auto S = isSum(fcmp->getOperand(1 - i))) { - SmallVector allOps; - auto combine = [](const SmallVector &lhs, - SmallVector rhs) { - SmallVector out; - for (auto v : lhs) { - bool seen = false; - for (auto &v2 : rhs) { - if (v == v2) { - v2 = nullptr; - seen = true; - break; - } - } - if (seen) { - out.push_back(v); - } - } - return out; - }; - auto subtract = [](SmallVector lhs, - const SmallVector &rhs) { - for (auto v : rhs) { - auto found = find(lhs, v); - assert(found != lhs.end()); - lhs.erase(found); - } - return lhs; - }; - bool seen = false; - bool legal = true; - for (auto op : callOperands(S)) { - auto P = isProduct(op); - if (!P) { - legal = false; - break; - } - if (!seen) { - allOps = callOperands(P); - seen = true; - continue; - } - allOps = combine(allOps, callOperands(P)); - } - - if (legal && allOps.size() > 0) { - SmallVector operands; - for (auto op : callOperands(S)) { - auto P = isProduct(op); - push(op); - auto sub = subtract(callOperands(P), allOps); - auto newprod = pushcse(B.CreateCall( - getProductIntrinsic(*F.getParent(), C->getType()), sub)); - operands.push_back(newprod); - } - auto newsum = pushcse(B.CreateCall( - getSumIntrinsic(*F.getParent(), C->getType()), operands)); - allOps.push_back(newsum); - auto fprod = pushcse(B.CreateCall( - getProductIntrinsic(*F.getParent(), C->getType()), allOps)); - auto fcmp = pushcse(B.CreateCmp(predicate, fprod, C)); - replaceAndErase(cur, fcmp); - return "CmpSumFactor"; - } - } - } - } - } - } - - if (auto fcmp = dyn_cast(cur)) { - auto predicate = fcmp->getPredicate(); - if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ || - predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) { - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(fcmp->getOperand(i))) { - if (C->isZero()) { - // a + b == 0 -> ( (a == 0 & b == 0) || a == -b) - if (auto S = isSum(fcmp->getOperand(1 - i))) { - auto allOps = callOperands(S); - if (!llvm::any_of(allOps, guaranteedDataDependent)) { - auto eq_predicate = predicate; - if (predicate == FCmpInst::FCMP_UNE || - predicate == FCmpInst::FCMP_ONE) - eq_predicate = fcmp->getInversePredicate(); - - Value *op_checks = nullptr; - for (auto a : allOps) { - auto a_e0 = pushcse(B.CreateFCmp(eq_predicate, a, C)); - if (op_checks == nullptr) - op_checks = a_e0; - else - op_checks = pushcse(B.CreateAnd(op_checks, a_e0)); - } - SmallVector slice; - for (size_t i = 1; i < allOps.size(); i++) - slice.push_back(allOps[i]); - auto ane = pushcse(B.CreateFCmp( - eq_predicate, pushcse(B.CreateFNeg(allOps[0])), - pushcse(B.CreateCall(getFunctionFromCall(S), slice)))); - auto ori = pushcse(B.CreateOr(op_checks, ane)); - if (predicate == FCmpInst::FCMP_UNE || - predicate == FCmpInst::FCMP_ONE) { - ori = pushcse(B.CreateNot(ori)); - } - replaceAndErase(cur, ori); - return "Sum2ZeroSplit"; - } - } 
- } - } - } - } - - // (zext a) + (zext b) ?= 0 -> zext a ?= - zext b - if (auto icmp = dyn_cast(cur)) { - if (icmp->getPredicate() == CmpInst::ICMP_EQ || - icmp->getPredicate() == CmpInst::ICMP_NE) { - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(icmp->getOperand(i))) - if (C->isZero()) - if (auto add = dyn_cast(icmp->getOperand(1 - i))) - if (add->getOpcode() == Instruction::Add) - if (auto a0 = dyn_cast(add->getOperand(0))) - if (auto a1 = dyn_cast(add->getOperand(1))) - if (a0->getOperand(0)->getType() == - a1->getOperand(0)->getType() && - (isa(a0) || isa(a0))) { - auto cmp2 = pushcse(B.CreateCmp( - icmp->getPredicate(), a0, pushcse(B.CreateNeg(a1)))); - replaceAndErase(cur, cmp2); - return "CmpExt0Shuffle"; - } - } - } - - // sub 0, (zext i1 to N) -> sext i1 to N - // sub 0, (sext i1 to N) -> zext i1 to N - if (auto sub = dyn_cast(cur)) - if (sub->getOpcode() == Instruction::Sub) - if (auto C = dyn_cast(sub->getOperand(0))) - if (C->isZero()) - if (auto a0 = dyn_cast(sub->getOperand(1))) - if (a0->getOperand(0)->getType()->isIntegerTy(1)) { - - Value *tmp = nullptr; - if (isa(a0)) - tmp = pushcse(B.CreateSExt(a0->getOperand(0), a0->getType())); - else if (isa(a0)) - tmp = pushcse(B.CreateZExt(a0->getOperand(0), a0->getType())); - else - assert(0); - replaceAndErase(cur, tmp); - return "NegSZExtI1"; - } - - if ((cur->getOpcode() == Instruction::LShr || - cur->getOpcode() == Instruction::SDiv || - cur->getOpcode() == Instruction::UDiv) && - cur->isExact()) - if (auto C2 = dyn_cast(cur->getOperand(1))) - if (auto mul = dyn_cast(cur->getOperand(0))) { - // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if - // C2 divides C1 - if (mul->getOpcode() == Instruction::Mul) - for (int i0 = 0; i0 < 2; i0++) - if (auto C1 = dyn_cast(mul->getOperand(i0))) { - auto lhs = C1->getValue(); - APInt rhs = C2->getValue(); - if (cur->getOpcode() == Instruction::LShr) { - rhs = APInt(rhs.getBitWidth(), 1) << rhs; - } - - APInt div, rem; - if (cur->getOpcode() == Instruction::LShr || - cur->getOpcode() == Instruction::UDiv) - APInt::udivrem(lhs, rhs, div, rem); - else - APInt::sdivrem(lhs, rhs, div, rem); - if (rem == 0) { - auto res = pushcse(B.CreateMul( - mul->getOperand(1 - i0), - ConstantInt::get(cur->getType(), div), - "mdiv." + cur->getName(), mul->hasNoUnsignedWrap(), - mul->hasNoSignedWrap())); - push(mul); - replaceAndErase(cur, res); - return "IMulDivConst"; - } - } - // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if - // C2 - if (mul->getOpcode() == Instruction::Add) - for (int i0 = 0; i0 < 2; i0++) - if (auto C1 = dyn_cast(mul->getOperand(i0))) { - auto lhs = C1->getValue(); - APInt rhs = C2->getValue(); - if (cur->getOpcode() == Instruction::LShr) { - rhs = APInt(rhs.getBitWidth(), 1) << rhs; - } - - APInt div, rem; - if (cur->getOpcode() == Instruction::LShr || - cur->getOpcode() == Instruction::UDiv) - APInt::udivrem(lhs, rhs, div, rem); - else - APInt::sdivrem(lhs, rhs, div, rem); - if (rem == 0 && ((mul->hasNoUnsignedWrap() && - (cur->getOpcode() == Instruction::LShr || - cur->getOpcode() == Instruction::UDiv)) || - (mul->hasNoSignedWrap() && - (cur->getOpcode() == Instruction::AShr || - cur->getOpcode() == Instruction::SDiv)))) { - auto res = pushcse(B.CreateAdd( - mul->getOperand(1 - i0), - ConstantInt::get(cur->getType(), div), - "madd." 
+ cur->getName(), mul->hasNoUnsignedWrap(), - mul->hasNoSignedWrap())); - push(mul); - replaceAndErase(cur, res); - return "IAddDivConst"; - } - } - } - - // mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (const1, const2) - if (cur->getOpcode() == Instruction::FMul) - if (cur->isFast()) - if (auto mul1 = dyn_cast(cur->getOperand(0))) - if (mul1->getOpcode() == Instruction::FMul && mul1->isFast()) - if (auto mul2 = dyn_cast(cur->getOperand(1))) - if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) { - for (auto i1 = 0; i1 < 2; i1++) - for (auto i2 = 0; i2 < 2; i2++) - if (isa(mul1->getOperand(i1))) - if (isa(mul2->getOperand(i2))) { - - auto n0 = pushcse( - B.CreateFMulFMF(mul1->getOperand(1 - i1), - mul2->getOperand(1 - i2), cur)); - auto n1 = pushcse(B.CreateFMulFMF( - mul1->getOperand(i1), mul2->getOperand(i2), cur)); - auto n2 = pushcse(B.CreateFMulFMF(n0, n1, cur)); - push(mul1); - push(mul2); - replaceAndErase(cur, n2); - return "MulMulConstConst"; - } - } - - // mul (mul a, const1), const2 -> mul a, (mul const1, const2) - if ((cur->getOpcode() == Instruction::FMul && cur->isFast()) || - cur->getOpcode() == Instruction::Mul) - for (auto i1 = 0; i1 < 2; i1++) - if (auto mul1 = dyn_cast(cur->getOperand(i1))) - if (((mul1->getOpcode() == Instruction::FMul && mul1->isFast())) || - mul1->getOpcode() == Instruction::FMul) - if (auto const2 = dyn_cast(cur->getOperand(1 - i1))) - for (auto i2 = 0; i2 < 2; i2++) - if (auto const1 = dyn_cast(mul1->getOperand(i2))) { - Value *res = nullptr; - if (cur->getOpcode() == Instruction::FMul) { - auto const3 = pushcse(B.CreateFMulFMF(const1, const2, mul1)); - res = pushcse( - B.CreateFMulFMF(mul1->getOperand(1 - i2), const3, cur)); - } else { - auto const3 = pushcse(B.CreateMul(const1, const2)); - res = pushcse(B.CreateMul(mul1->getOperand(1 - i2), const3)); - } - push(mul1); - replaceAndErase(cur, res); - return "MulConstConst"; - } - - if (auto fcmp = dyn_cast(cur)) { - if (fcmp->getPredicate() == FCmpInst::FCMP_OEQ) { - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(fcmp->getOperand(i))) { - if (C->isZero()) { - if (auto fmul = dyn_cast(fcmp->getOperand(1 - i))) { - // (a*b) == 0 -> (a == 0) || (b == 0) - if (fmul->getOpcode() == Instruction::FMul) { - auto ncmp1 = pushcse( - B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C)); - auto ncmp2 = pushcse( - B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(1), C)); - auto ori = pushcse(B.CreateOr(ncmp1, ncmp2)); - replaceAndErase(cur, ori); - return "CmpFMulSplit"; - } - // (a/b) == 0 -> (a == 0) - if (fmul->getOpcode() == Instruction::FDiv) { - auto ncmp1 = pushcse( - B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C)); - replaceAndErase(cur, ncmp1); - return "CmpFDivSplit"; - } - // (a - b) ?= 0 -> a ?= b - if (fmul->getOpcode() == Instruction::FSub) { - auto ncmp1 = pushcse(B.CreateFCmp(fcmp->getPredicate(), - fmul->getOperand(0), - fmul->getOperand(1))); - replaceAndErase(cur, ncmp1); - return "CmpFSubSplit"; - } - } - if (auto cast = dyn_cast(fcmp->getOperand(1 - i))) { - auto ncmp1 = pushcse(B.CreateICmp( - ICmpInst::ICMP_EQ, cast->getOperand(0), - ConstantInt::get(cast->getOperand(0)->getType(), 0))); - replaceAndErase(cur, ncmp1); - return "SFCmpToICmp"; - } - if (auto cast = dyn_cast(fcmp->getOperand(1 - i))) { - auto ncmp1 = pushcse(B.CreateICmp( - ICmpInst::ICMP_EQ, cast->getOperand(0), - ConstantInt::get(cast->getOperand(0)->getType(), 0))); - replaceAndErase(cur, ncmp1); - return "UFCmpToICmp"; - } - if (auto SI = dyn_cast(fcmp->getOperand(1 - i))) { 
- auto res = pushcse( - B.CreateSelect(SI->getCondition(), - pushcse(B.CreateCmp(fcmp->getPredicate(), C, - SI->getTrueValue())), - pushcse(B.CreateCmp(fcmp->getPredicate(), C, - SI->getFalseValue())))); - replaceAndErase(cur, res); - return "FCmpSelect"; - } - } - } - } - } - if (auto fcmp = dyn_cast(cur)) { - if (fcmp->getPredicate() == CmpInst::ICMP_EQ || - fcmp->getPredicate() == CmpInst::ICMP_NE || - fcmp->getPredicate() == CmpInst::FCMP_OEQ || - fcmp->getPredicate() == CmpInst::FCMP_ONE) { - - // a + c ?= a -> c ?= 0 , if fast - for (int i = 0; i < 2; i++) - if (auto inst = dyn_cast(fcmp->getOperand(i))) - if (inst->getOpcode() == Instruction::FAdd && inst->isFast()) - for (int i2 = 0; i2 < 2; i2++) - if (inst->getOperand(i2) == fcmp->getOperand(1 - i)) { - auto res = pushcse( - B.CreateCmp(fcmp->getPredicate(), inst->getOperand(1 - i2), - ConstantFP::get(inst->getType(), 0))); - replaceAndErase(cur, res); - return "CmpFAddSame"; - } - - // a == b -> a & b | !a & !b - // a != b -> a & !b | !a & b - if (fcmp->getOperand(0)->getType()->isIntegerTy(1)) { - auto a = fcmp->getOperand(0); - auto b = fcmp->getOperand(1); - if (fcmp->getPredicate() == CmpInst::ICMP_EQ) { - auto res = pushcse( - B.CreateOr(pushcse(B.CreateAnd(a, b)), - pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), - pushcse(B.CreateNot(b)))))); - replaceAndErase(cur, res); - return "CmpI1EQ"; - } - if (fcmp->getPredicate() == CmpInst::ICMP_NE) { - auto res = pushcse( - B.CreateOr(pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), b)), - pushcse(B.CreateAnd(a, pushcse(B.CreateNot(b)))))); - replaceAndErase(cur, res); - return "CmpI1NE"; - } - } - - for (int i = 0; i < 2; i++) - if (auto CI = dyn_cast(fcmp->getOperand(i))) - if (CI->isZero()) { - // a + a ?= 0 -> a ?= 0 - if (auto addI = dyn_cast(fcmp->getOperand(1 - i))) { - if (addI->getOperand(0) == addI->getOperand(1)) { - Value *res = pushcse( - B.CreateCmp(fcmp->getPredicate(), addI->getOperand(0), CI)); - replaceAndErase(cur, res); - return "CmpAddAdd"; - } - // (a-b) ?= 0 -> a ?= b - if (addI->getOpcode() == Instruction::Sub) { - auto ncmp1 = pushcse(B.CreateICmp(fcmp->getPredicate(), - addI->getOperand(0), - addI->getOperand(1))); - replaceAndErase(cur, ncmp1); - return "CmpISubSplit"; - } - } - } - - // (a * b) == (c * b) -> (a == c) || b == 0 - // (a * b) != (c * b) -> (a != c) && b != 0 - // auto S1 = SE.getSCEV(cur->getOperand(0)); - // auto S2 = SE.getSCEV(cur->getOperand(1)); - // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " - // S2: " << *S2 << " and " << *cur->getOperand(0) << " " << - // *cur->getOperand(1) << "\n"; - if (auto mul1 = dyn_cast(cur->getOperand(0))) - if (auto mul2 = dyn_cast(cur->getOperand(1))) { - if (mul1->getOpcode() == Instruction::Mul && - mul2->getOpcode() == Instruction::Mul && - mul1->hasNoUnsignedWrap() && mul1->hasNoSignedWrap() && - mul2->hasNoUnsignedWrap() && mul2->hasNoSignedWrap()) { - for (int i = 0; i < 2; i++) { - if (mul1->getOperand(i) == mul2->getOperand(i)) { - Value *res = pushcse(B.CreateICmp(fcmp->getPredicate(), - mul1->getOperand(1 - i), - mul2->getOperand(1 - i))); - auto b = mul1->getOperand(i); - if (fcmp->getPredicate() == CmpInst::ICMP_EQ) { - Value *bZero = pushcse(B.CreateICmp( - CmpInst::ICMP_EQ, b, ConstantInt::get(b->getType(), 0))); - res = pushcse(B.CreateOr(res, bZero)); - } else { - Value *bZero = pushcse(B.CreateICmp( - ICmpInst::ICMP_NE, b, ConstantInt::get(b->getType(), 0))); - res = pushcse(B.CreateAnd(res, bZero)); - } - replaceAndErase(cur, res); - return "CmpMulCommon"; - } - } - } 
- // same as above but now with floats - if (mul1->getOpcode() == Instruction::FMul && - mul2->getOpcode() == Instruction::FMul && mul1->isFast() && - mul2->isFast()) { - for (int i = 0; i < 2; i++) { - if (mul1->getOperand(i) == mul2->getOperand(i)) { - Value *res = pushcse(B.CreateFCmp(fcmp->getPredicate(), - mul1->getOperand(1 - i), - mul2->getOperand(1 - i))); - auto b = mul1->getOperand(i); - if (fcmp->getPredicate() == CmpInst::FCMP_OEQ) { - Value *bZero = pushcse(B.CreateCmp( - CmpInst::FCMP_OEQ, b, ConstantFP::get(b->getType(), 0))); - res = pushcse(B.CreateOr(res, bZero)); - } else { - Value *bZero = pushcse(B.CreateCmp( - CmpInst::FCMP_ONE, b, ConstantFP::get(b->getType(), 0))); - res = pushcse(B.CreateAnd(res, bZero)); - } - replaceAndErase(cur, res); - return "CmpMulfCommon"; - } - } - } - - // (uitofp a ) ?= (uitofp b) -> a ?= b - for (auto cond : {Instruction::UIToFP, Instruction::SIToFP}) - if (mul1->getOpcode() == cond && mul2->getOpcode() == cond && - mul1->getOperand(0)->getType() == - mul2->getOperand(0)->getType()) { - Value *res = pushcse(B.CreateICmp( - fcmp->getPredicate() == CmpInst::FCMP_OEQ ? CmpInst::ICMP_EQ - : CmpInst::ICMP_NE, - mul1->getOperand(0), mul2->getOperand(0))); - replaceAndErase(cur, res); - return "CmpUIToFP"; - } - - // (zext a ) ?= (zext b) -> a ?= b - if (mul1->getOpcode() == Instruction::ZExt && - mul2->getOpcode() == Instruction::ZExt && - mul1->getOperand(0)->getType() == - mul2->getOperand(0)->getType()) { - Value *res = - pushcse(B.CreateICmp(fcmp->getPredicate(), mul1->getOperand(0), - mul2->getOperand(0))); - replaceAndErase(cur, res); - return "CmpZExt"; - } - - // (zext i1 a ) == (sext i1 b) -> (!a & !b) - // (zext i1 a ) != (sext i1 b) -> (a | b) - if (auto mul1 = dyn_cast(cur->getOperand(0))) - if (auto mul2 = dyn_cast(cur->getOperand(1))) - if (((mul1->getOpcode() == Instruction::ZExt && - mul2->getOpcode() == Instruction::SExt) || - (mul1->getOpcode() == Instruction::SExt && - mul2->getOpcode() == Instruction::ZExt)) && - mul1->getOperand(0)->getType() == - mul2->getOperand(0)->getType() && - mul1->getOperand(0)->getType()->isIntegerTy(1)) { - - Value *na = mul1->getOperand(0); - Value *nb = mul2->getOperand(0); - - if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) { - na = pushcse(B.CreateNot(na)); - nb = pushcse(B.CreateNot(nb)); - } - - Value *res = nullptr; - if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) - res = pushcse(B.CreateAnd(na, nb)); - else - res = pushcse(B.CreateOr(na, nb)); - - replaceAndErase(cur, res); - return "CmpZExtSExt"; - } - } - } - if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) { - for (int i = 0; i < 2; i++) { - if (auto C = dyn_cast(fcmp->getOperand(i))) { - if (C->isZero()) { - if (auto fmul = dyn_cast(fcmp->getOperand(1 - i))) { - // (a*b) == 0 -> (a == 0) || (b == 0) - if (fmul->getOpcode() == Instruction::Mul) { - auto ncmp1 = pushcse( - B.CreateICmp(fcmp->getPredicate(), fmul->getOperand(0), C)); - auto ncmp2 = pushcse( - B.CreateICmp(fcmp->getPredicate(), fmul->getOperand(1), C)); - auto ori = pushcse(B.CreateOr(ncmp1, ncmp2)); - replaceAndErase(cur, ori); - return "CmpIMulSplit"; - } - } - } - } - } - } - } - - if (cur->getOpcode() == Instruction::FAdd) { - // add x, x -> mul 2.0 - if (cur->getOperand(0) == cur->getOperand(1) && cur->isFast()) { - auto res = pushcse(B.CreateFMulFMF( - cur->getOperand(0), ConstantFP::get(cur->getType(), 2.0), cur)); - replaceAndErase(cur, res); - return "AddToMul2"; - } - } - - if (cur->getOpcode() == Instruction::Add) { - // add x, (y * -1) -> sub x, y - for (int i = 
0; i < 2; i++) { - if (auto mul1 = dyn_cast(cur->getOperand(i))) - if (mul1->getOpcode() == Instruction::Mul) { - for (int j = 0; j < 2; j++) { - if (auto C = dyn_cast(mul1->getOperand(j))) { - if (C->isMinusOne()) { - auto res = pushcse(B.CreateSub(cur->getOperand(1 - i), - mul1->getOperand(1 - j))); - push(mul1); - - replaceAndErase(cur, res); - return "AddToSub"; - } - } - } - } - } - } - - if (auto SI = dyn_cast(cur)) { - auto shouldMove = [](Value *v) { return isa(v); }; - - /* - // select c, 0, x -> fmul (uitofp (!c)), x - if (auto C1 = dyn_cast(SI->getTrueValue())) { - if (C1->isZero()) { - auto n = pushcse(B.CreateNot(SI->getCondition())); - auto val = pushcse(B.CreateUIToFP(n, SI->getType())); - auto res = pushcse(B.CreateFMul(val, SI->getFalseValue())); - if (auto I = dyn_cast(res)) - I->setFast(true); - replaceAndErase(cur, res); - return true; - } - } - // select c, x, 0 -> fmul (uitofp c), x - if (auto C1 = dyn_cast(SI->getFalseValue())) { - if (C1->isZero()) { - auto val = pushcse(B.CreateUIToFP(SI->getCondition(), SI->getType())); - auto res = pushcse(B.CreateFMul(val, SI->getTrueValue())); - if (auto I = dyn_cast(res)) - I->setFast(true); - replaceAndErase(cur, res); - return true; - } - } - */ - - // select c, (mul x y), 0 -> mul x, (select c, y, 0) - for (int i = 0; i < 2; i++) - if (auto inst = dyn_cast(SI->getOperand(1 + i))) - if (inst->getOpcode() == Instruction::Mul) - // inst->getOpcode() == Instruction::FMul) - if (auto C = dyn_cast(SI->getOperand(1 + (1 - i)))) - if ((isa(C) && cast(C)->isZero()) || - (isa(C) && cast(C)->isZero())) - for (int j = 0; j < 2; j++) - if (shouldMove(inst->getOperand(j))) { - auto x = inst->getOperand(j); - auto y = inst->getOperand(1 - j); - auto isel = pushcse(B.CreateSelect( - SI->getCondition(), (i == 0) ? y : C, (i == 0) ? C : y, - "smulmove." + SI->getName())); - Value *imul; - if (cur->getType()->isIntegerTy()) - imul = pushcse(B.CreateMul(isel, x, "", - inst->hasNoUnsignedWrap(), - inst->hasNoSignedWrap())); - else - imul = pushcse(B.CreateFMulFMF(isel, x, inst, "")); - - replaceAndErase(cur, imul); - return "SelMulMove"; - } - - // select c, (sitofp x), (sitofp y) -> sitofp (select c, x, y) - // select c, c5, (sitofp y) -> sitofp (select c, c5, y) - { - Value *ops[2] = {nullptr, nullptr}; - bool legal = true; - for (int i = 0; i < 2; i++) { - if (isa(SI->getOperand(1 + i))) { - ops[i] = nullptr; - continue; - } - if (auto CI = dyn_cast(SI->getOperand(1 + i))) { - if (CI->getOpcode() == Instruction::SIToFP) { - ops[i] = CI->getOperand(0); - continue; - } - } - legal = false; - break; - } - for (int i = 0; i < 2; i++) { - if (!ops[i] && ops[1 - i]) - ops[i] = ConstantInt::get(ops[1 - i]->getType(), 0); - } - for (int i = 0; i < 2; i++) { - if (ops[i] == nullptr || ops[i]->getType() != ops[0]->getType()) { - legal = false; - break; - } - } - if (legal) { - auto isel = pushcse(B.CreateSelect(SI->getCondition(), ops[0], ops[1], - "seltofp." 
+ SI->getName())); - auto res = pushcse(B.CreateSIToFP(isel, SI->getType())); - - replaceAndErase(cur, res); - return "SelSIMerge"; - } - } - } - - if (cur->getOpcode() == Instruction::Mul) { - for (int i = 0; i < 2; i++) { - // mul (x, 1) -> x - if (auto C = dyn_cast(cur->getOperand(i))) - if (C->isOne()) { - replaceAndErase(cur, cur->getOperand(1 - i)); - return "MulIdent"; - } - - // mul (zext i1 x), y -> mul (zext i1 x) y[x->1] - if (auto Z = dyn_cast(cur->getOperand(i))) - if (Z->getOperand(0)->getType()->isIntegerTy(1)) { - auto prev = cur->getOperand(1 - i); - auto next = replace(prev, Z->getOperand(0), - ConstantInt::getTrue(cur->getContext())); - if (next != prev) { - auto res = pushcse(B.CreateMul(Z, next, "postmul." + cur->getName(), - cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap())); - replaceAndErase(cur, res); - return "MulReplaceZExt"; - } - } - } - - /* - // mul x, (select c, 0, y) -> select c (mul x 0), (mul x y) - for (int i=0; i<2; i++) - if (auto SI = dyn_cast(cur->getOperand(i))) - for (int j=0; j<2; j++) - if (auto CI = dyn_cast(SI->getOperand(1+j))) - if (CI->isZero()) { - auto tval = (j == 0) ? CI : - pushcse(B.CreateMul(SI->getTrueValue(), cur->getOperand(1-i), "tval." + - cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); auto - fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(), - cur->getOperand(1-i), "fval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap())); - - auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval)); - - replaceAndErase(cur, res); - return true; - } - */ - - // mul (sub x, y), -c -> mul (sub, y, x), c - for (int i = 0; i < 2; i++) - if (auto inst = dyn_cast(cur->getOperand(i))) - if (inst->getOpcode() == Instruction::Sub) - if (auto CI = dyn_cast(cur->getOperand(1 - i))) - if (CI->isNegative()) { - auto sub2 = pushcse(B.CreateSub( - inst->getOperand(1), inst->getOperand(0), "", - inst->hasNoUnsignedWrap(), inst->hasNoSignedWrap())); - auto mul2 = pushcse(B.CreateMul( - sub2, ConstantInt::get(CI->getType(), -CI->getValue()), "", - cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); - - replaceAndErase(cur, mul2); - return "MulSubNegConst"; - } - } - - if (cur->getOpcode() == Instruction::Sub) - if (auto CI = dyn_cast(cur->getOperand(0))) - if (CI->isZero()) - if (auto zext = dyn_cast(cur->getOperand(1))) { - // sub 0, (zext i1 x) -> sext x - if (zext->getOpcode() == Instruction::ZExt && - zext->getOperand(0)->getType()->isIntegerTy(1)) { - auto res = - pushcse(B.CreateSExt(zext->getOperand(0), cur->getType())); - replaceAndErase(cur, res); - return "SubZExt"; - } - // sub 0, (mul nsw nuw constant, x) -> mul nsw nuw -constant, x - if (zext->getOpcode() == Instruction::Mul && - zext->hasNoUnsignedWrap() && zext->hasNoSignedWrap()) { - for (int i = 0; i < 2; i++) - if (auto CI = dyn_cast(zext->getOperand(i))) { - auto res = pushcse(B.CreateMul( - zext->getOperand(1 - i), - ConstantInt::get(CI->getType(), -CI->getValue()), - "neg." 
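// Example IR for the rewrites above (constants chosen for illustration):
//   MulSubNegConst:  %d = sub i64 %x, %y
//                    %r = mul i64 %d, -5
//                ->  %e = sub i64 %y, %x
//                    %r = mul i64 %e, 5
//   SubZExt:  %z = zext i1 %b to i64
//             %r = sub i64 0, %z
//         ->  %r = sext i1 %b to i64        ; 0 stays 0, 1 becomes -1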
+ zext->getName(), true, true)); - replaceAndErase(cur, res); - return "SubMulConstant"; - } - } - } - - // add (zext (and c1, x) ), (zext (and c1, y)) -> select c1, (add (zext x), - // (zext y)), 0 - /* - if (cur->getOpcode() == Instruction::Add || - cur->getOpcode() == Instruction::Sub || - cur->getOpcode() == Instruction::Mul) - if (auto inst1 = dyn_cast(cur->getOperand(0))) - if (auto inst2 = dyn_cast(cur->getOperand(1))) - if (inst1->getOpcode() == Instruction::ZExt && inst2->getOpcode() == - Instruction::ZExt) if (auto and1 = - dyn_cast(inst1->getOperand(0))) if (auto and2 = - dyn_cast(inst2->getOperand(0))) if - (and1->getType()->isIntegerTy(1) && and2->getType()->isIntegerTy(1) && - and1->getOpcode() == Instruction::And && and2->getOpcode() == - Instruction::And) { bool done = false; for (int i1=0; i1<2; i1++) for (int - i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto - c1 = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = - pushcse(B.CreateZExt(x, inst1->getType())); auto y = - and2->getOperand(1-i2); - - y = pushcse(B.CreateZExt(y, inst2->getType())); - - Value *res = nullptr; - switch (cur->getOpcode()) { - case Instruction::Add: - res = pushcse(B.CreateAdd(x, y, "", cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap())); break; case Instruction::Sub: res = B.CreateSub(x, - y, - "", cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap()); break; case - Instruction::Mul: res = B.CreateMul(x, y, "", cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); break; default: llvm_unreachable("Illegal opcode"); - } - res = pushcse(B.CreateSelect(c1, res, - Constant::getNullValue(cur->getType()))); - - replaceAndErase(cur, res); - return; - } - } - */ - - // add (select %c c0, x), (select %c, c1, y) -> select %c, (add c0, c1), - // (add x, y) and for sub/mul/cmp - if (cur->getOpcode() == Instruction::Add || - cur->getOpcode() == Instruction::Sub || - cur->getOpcode() == Instruction::Mul || - cur->getOpcode() == Instruction::FAdd || - cur->getOpcode() == Instruction::FSub || - cur->getOpcode() == Instruction::FMul || - // cur->getOpcode() == Instruction::SIToFP || - // cur->getOpcode() == Instruction::UIToFP || - cur->getOpcode() == Instruction::ICmp || - cur->getOpcode() == Instruction::FCmp) { - - Value *SI1cond = nullptr; - Value *SI1tval = nullptr; - Value *SI1fval = nullptr; - if (auto SI1 = dyn_cast(cur->getOperand(0))) { - SI1cond = SI1->getCondition(); - SI1tval = SI1->getTrueValue(); - SI1fval = SI1->getFalseValue(); - } - if (auto SI1 = dyn_cast(cur->getOperand(0))) - if (SI1->getOperand(0)->getType()->isIntegerTy(1)) { - SI1cond = SI1->getOperand(0); - SI1tval = SI1; - SI1fval = ConstantInt::get(SI1->getType(), 0); - } - if (auto SI1 = dyn_cast(cur->getOperand(0))) - if (SI1->getOperand(0)->getType()->isIntegerTy(1)) { - SI1cond = SI1->getOperand(0); - SI1tval = SI1; - SI1fval = ConstantInt::get(SI1->getType(), 0); - } - Value *SI2cond = nullptr; - Value *SI2tval = nullptr; - Value *SI2fval = nullptr; - - auto op2 = cur->getOperand((cur->getOpcode() == Instruction::SIToFP || - cur->getOpcode() == Instruction::UIToFP) - ? 
0 - : 1); - if (auto SI2 = dyn_cast(op2)) { - SI2cond = SI2->getCondition(); - SI2tval = SI2->getTrueValue(); - SI2fval = SI2->getFalseValue(); - } - if (auto SI2 = dyn_cast(op2)) - if (SI2->getOperand(0)->getType()->isIntegerTy(1)) { - SI2cond = SI2->getOperand(0); - SI2tval = SI2; - SI2fval = ConstantInt::get(SI2->getType(), 0); - } - if (auto SI2 = dyn_cast(op2)) - if (SI2->getOperand(0)->getType()->isIntegerTy(1)) { - SI2cond = SI2->getOperand(0); - SI2tval = SI2; - SI2fval = ConstantInt::get(SI2->getType(), 0); - } - - if (SI1cond && SI2cond && (SI1cond == SI2cond || isNot(SI1cond, SI2cond))) - if ((SI1cond == SI2cond && - ((isa(SI1tval) && isa(SI2tval)) || - (isa(SI1fval) && isa(SI2fval)))) || - (SI1cond != SI2cond && - ((isa(SI1tval) && isa(SI2fval)) || - (isa(SI1fval) && isa(SI2tval)))) - - ) { - Value *tval = nullptr; - Value *fval = nullptr; - bool inverted = SI1cond != SI2cond; - switch (cur->getOpcode()) { - case Instruction::SIToFP: - tval = - B.CreateSIToFP(SI1tval, cur->getType(), "tval." + cur->getName()); - fval = - B.CreateSIToFP(SI1fval, cur->getType(), "fval." + cur->getName()); - break; - case Instruction::UIToFP: - tval = - B.CreateUIToFP(SI1tval, cur->getType(), "tval." + cur->getName()); - fval = - B.CreateUIToFP(SI1fval, cur->getType(), "fval." + cur->getName()); - break; - case Instruction::FAdd: - tval = B.CreateFAddFMF(SI1tval, inverted ? SI2fval : SI2tval, cur, - "tval." + cur->getName()); - fval = B.CreateFAddFMF(SI1fval, inverted ? SI2tval : SI2fval, cur, - "fval." + cur->getName()); - break; - case Instruction::FSub: - tval = B.CreateFSubFMF(SI1tval, inverted ? SI2fval : SI2tval, cur, - "tval." + cur->getName()); - fval = B.CreateFSubFMF(SI1fval, inverted ? SI2tval : SI2fval, cur, - "fval." + cur->getName()); - break; - case Instruction::FMul: - tval = B.CreateFMulFMF(SI1tval, inverted ? SI2fval : SI2tval, cur, - "tval." + cur->getName()); - fval = B.CreateFMulFMF(SI1fval, inverted ? SI2tval : SI2fval, cur, - "fval." + cur->getName()); - break; - case Instruction::Add: - tval = B.CreateAdd(SI1tval, inverted ? SI2fval : SI2tval, - "tval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); - fval = B.CreateAdd(SI1fval, inverted ? SI2tval : SI2fval, - "fval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); - break; - case Instruction::Sub: - tval = B.CreateSub(SI1tval, inverted ? SI2fval : SI2tval, - "tval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); - fval = B.CreateSub(SI1fval, inverted ? SI2tval : SI2fval, - "fval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); - break; - case Instruction::Mul: - tval = B.CreateMul(SI1tval, inverted ? SI2fval : SI2tval, - "tval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); - fval = B.CreateMul(SI1fval, inverted ? SI2tval : SI2fval, - "fval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap()); - break; - case Instruction::ICmp: - case Instruction::FCmp: - tval = B.CreateCmp(cast(cur)->getPredicate(), SI1tval, - inverted ? SI2fval : SI2tval, - "tval." + cur->getName()); - fval = B.CreateCmp(cast(cur)->getPredicate(), SI1fval, - inverted ? SI2tval : SI2fval, - "fval." + cur->getName()); - break; - default: - llvm_unreachable("illegal opcode"); - } - tval = pushcse(tval); - fval = pushcse(fval); - - auto res = pushcse( - B.CreateSelect(SI1cond, tval, fval, "selmerge." 
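// Sketch of the select-fusion rewrite built above (made-up IR; the matching
// select arms must satisfy the constant guard earlier in this block):
//   BinopSelFuse:  %a = select i1 %c, i64 3, i64 %x
//                  %b = select i1 %c, i64 5, i64 %y
//                  %r = add i64 %a, %b
//              ->  %t = add i64 3, 5
//                  %f = add i64 %x, %y
//                  %r = select i1 %c, i64 %t, i64 %f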
+ cur->getName())); - - push(cur->getOperand(0)); - push(cur->getOperand(1)); - replaceAndErase(cur, res); - return "BinopSelFuse"; - } - } - - /* - // and (i == c), (i != d) -> and (i == c) && (c != d) - if (cur->getOpcode() == Instruction::And) { - auto lhs = replace(cur->getOperand(0), cur->getOperand(1), - ConstantInt::getTrue(cur->getContext())); - auto rhs = replace(cur->getOperand(1), cur->getOperand(0), - ConstantInt::getTrue(cur->getContext())); - if (lhs != cur->getOperand(0) || rhs != cur->getOperand(1)) { - auto res = pushcse(B.CreateAnd(lhs, rhs, "postand." + cur->getName())); - replaceAndErase(cur, res); - return "AndReplace2"; - } - } - */ - - // and a, (or q, (not a)) -> and a q - if (cur->getOpcode() == Instruction::And) { - for (size_t i1 = 0; i1 < 2; i1++) - if (auto inst2 = dyn_cast(cur->getOperand(1 - i1))) - if (inst2->getOpcode() == Instruction::Or) - for (size_t i2 = 0; i2 < 2; i2++) - if (isNot(cur->getOperand(i1), inst2->getOperand(i2))) { - auto q = inst2->getOperand(1 - i2); - cur->setOperand(1 - i1, q); - push(cur); - push(q); - push(inst2); - push(cur->getOperand(i1)); - push(inst2->getOperand(i2)); - Q.insert(cur); - for (auto U : cur->users()) - push(U); - return "AndOrProp"; - } - } - - // and (and a, b), a) -> and a, b - if (cur->getOpcode() == Instruction::And) { - for (size_t i1 = 0; i1 < 2; i1++) - if (auto inst2 = dyn_cast(cur->getOperand(i1))) - if (inst2->getOpcode() == Instruction::And) - for (size_t i2 = 0; i2 < 2; i2++) - if (inst2->getOperand(i2) == cur->getOperand(1 - i1)) { - replaceAndErase(cur, inst2); - return "AndAndProp"; - } - } - - // or a, (and q, (not a)) -> and a q - if (cur->getOpcode() == Instruction::And) { - for (size_t i1 = 0; i1 < 2; i1++) - if (auto inst2 = dyn_cast(cur->getOperand(1 - i1))) - if (inst2->getOpcode() == Instruction::Or) - for (size_t i2 = 0; i2 < 2; i2++) - if (isNot(cur->getOperand(i1), inst2->getOperand(i2))) { - auto q = inst2->getOperand(1 - i2); - cur->setOperand(1 - i1, q); - push(cur); - push(q); - push(inst2); - push(cur->getOperand(i1)); - push(inst2->getOperand(i2)); - Q.insert(cur); - for (auto U : cur->users()) - push(U); - return "OrAndProp"; - } - } - - // and ( (a +/- b) != c ), ( (d +/- b) != c ) -> and ( a != (c -/+ b) ), ( - // d != (c -/+ b) ) - // also with or - if (cur->getOpcode() == Instruction::And || - cur->getOpcode() == Instruction::Or) { - for (auto cmpOp : {ICmpInst::ICMP_EQ, ICmpInst::ICMP_NE}) - for (auto interOp : {Instruction::Add, Instruction::Sub}) - if (auto cmp1 = dyn_cast(cur->getOperand(0))) - if (auto cmp2 = dyn_cast(cur->getOperand(1))) - for (size_t i1 = 0; i1 < 2; i1++) - for (size_t i2 = 0; i2 < 2; i2++) - if (cmp1->getOperand(1 - i1) == cmp2->getOperand(1 - i2) && - cmp1->getPredicate() == cmpOp && - cmp2->getPredicate() == cmpOp) - if (auto add1 = dyn_cast(cmp1->getOperand(i1))) - if (auto add2 = dyn_cast(cmp2->getOperand(i2))) - if (add1->getOpcode() == interOp && - add2->getOpcode() == interOp) - for (size_t ia = 0; ia < 2; ia++) - if (add1->getOperand(ia) == add2->getOperand(ia)) { - - auto b = add1->getOperand(ia); - auto c = cmp1->getOperand(1 - i1); - auto a = add1->getOperand(1 - ia); - auto d = add2->getOperand(1 - ia); - - Value *res = nullptr; - if (interOp == Instruction::Add) - res = pushcse(B.CreateSub(ia == 0 ? b : c, - ia == 0 ? c : b)); - else - res = pushcse(B.CreateAdd(ia == 0 ? b : c, - ia == 0 ? 
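// Boolean identities implemented above, shown on hypothetical i1 values:
//   AndOrProp:   %n = xor i1 %a, true
//                %o = or  i1 %q, %n
//                %r = and i1 %a, %o
//            ->  %r = and i1 %a, %q         ; a & (q | !a) == a & q
//   AndAndProp:  %i = and i1 %a, %b
//                %r = and i1 %i, %a
//            ->  %r = and i1 %a, %b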
c : b)); - - auto lhs = pushcse(B.CreateCmp(cmpOp, a, res)); - auto rhs = pushcse(B.CreateCmp(cmpOp, d, res)); - - Value *fres = nullptr; - if (cur->getOpcode() == Instruction::And) - fres = pushcse(B.CreateAnd(lhs, rhs)); - else - fres = pushcse(B.CreateOr(lhs, rhs)); - - replaceAndErase(cur, fres); - return "AndLinearShift"; - } - } - - // and ( expr == c1 ), ( expr == c2 ) and c1 != c2 -> false - if (cur->getOpcode() == Instruction::And) { - for (auto cmpOp : {ICmpInst::ICMP_EQ}) - if (auto cmp1 = dyn_cast(cur->getOperand(0))) - if (auto cmp2 = dyn_cast(cur->getOperand(1))) - for (size_t i1 = 0; i1 < 2; i1++) - for (size_t i2 = 0; i2 < 2; i2++) - if (cmp1->getOperand(1 - i1) == cmp2->getOperand(1 - i2) && - cmp1->getPredicate() == cmpOp && - cmp2->getPredicate() == cmpOp) { - auto c1 = SE.getSCEV(cmp1->getOperand(i1)); - auto c2 = SE.getSCEV(cmp2->getOperand(i2)); - auto m = SE.getMinusSCEV(c1, c2, SCEV::NoWrapMask); - if (auto C = dyn_cast(m)) { - // if c1 == c2 don't need the and they are equivalent - if (C->getValue()->isZero()) { - push(cmp1); - push(cmp2); - replaceAndErase(cur, cmp1); - return "AndEQExpr"; - } else { - // if non one constant they must be distinct. - replaceAndErase(cur, - ConstantInt::getFalse(cur->getContext())); - return "AndNEExpr"; - } - } - } - } - - // (a | b) == 0 -> a == 0 & b == 0 - if (auto icmp = dyn_cast(cur)) - if (icmp->getPredicate() == ICmpInst::ICMP_EQ && - cur->getType()->isIntegerTy(1)) - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(icmp->getOperand(i))) - if (C->isZero()) - if (auto z = dyn_cast(icmp->getOperand(1 - i))) - if (z->getOpcode() == BinaryOperator::Or) { - auto a0 = pushcse(B.CreateICmpEQ(z->getOperand(0), C)); - auto b0 = pushcse(B.CreateICmpEQ(z->getOperand(1), C)); - auto res = pushcse(B.CreateAnd(a0, b0)); - push(z); - push(icmp); - replaceAndErase(cur, res); - return "OrEQZero"; - } - - // add (mul a b), (mul c, b) -> mul (add a, c), b - if (cur->getOpcode() == Instruction::Sub || - cur->getOpcode() == Instruction::Add) { - if (auto mul1 = dyn_cast(cur->getOperand(0))) - if (auto mul2 = dyn_cast(cur->getOperand(1))) - if ((mul1->getOpcode() == Instruction::Mul && - mul2->getOpcode() == Instruction::Mul) || - (mul1->getOpcode() == Instruction::FMul && - mul2->getOpcode() == Instruction::FMul && mul1->isFast() && - mul2->isFast() && cur->isFast())) { - for (int i1 = 0; i1 < 2; i1++) - for (int i2 = 0; i2 < 2; i2++) { - if (mul1->getOperand(i1) == mul2->getOperand(i2)) { - Value *res = nullptr; - switch (cur->getOpcode()) { - case Instruction::Add: - res = B.CreateAdd(mul1->getOperand(1 - i1), - mul2->getOperand(1 - i2)); - break; - case Instruction::Sub: - res = B.CreateSub(mul1->getOperand(1 - i1), - mul2->getOperand(1 - i2)); - break; - case Instruction::FAdd: - res = B.CreateFAddFMF(mul1->getOperand(1 - i1), - mul2->getOperand(1 - i2), cur); - break; - case Instruction::FSub: - res = B.CreateFSubFMF(mul1->getOperand(1 - i1), - mul2->getOperand(1 - i2), cur); - break; - default: - llvm_unreachable("Illegal opcode"); - } - res = pushcse(res); - Value *res2 = nullptr; - if (cur->getType()->isIntegerTy()) - res2 = B.CreateMul( - res, mul1->getOperand(i1), "", - mul1->hasNoUnsignedWrap() && mul1->hasNoUnsignedWrap(), - mul2->hasNoSignedWrap() && mul2->hasNoSignedWrap()); - else - res2 = B.CreateFMulFMF(res, mul1->getOperand(i1), cur); - - res2 = pushcse(res2); - - replaceAndErase(cur, res2); - return "InvDistributive"; - } - } - } - } - - // fadd (ext a), (ext b) -> ext (a + b) - // fsub (ext a), (ext b) -> ext (a - b) - // 
fmul (ext a), (ext b) -> ext (a * b) - if (cur->getOpcode() == Instruction::FSub || - cur->getOpcode() == Instruction::FAdd || - cur->getOpcode() == Instruction::FMul || - cur->getOpcode() == Instruction::FNeg || - (isSum(cur) && callOperands(cast(cur)).size() == 2)) { - auto opcode = cur->getOpcode(); - if (isSum(cur)) - opcode = Instruction::FAdd; - auto Ty = B.getInt64Ty(); - SmallPtrSet temporaries; - SmallVector precasts; - Value *lhs = nullptr; - - Value *prelhs = (cur->getOpcode() == Instruction::FNeg) - ? ConstantFP::get(cur->getType(), 0.0) - : cur->getOperand(0); - Value *prerhs = (cur->getOpcode() == Instruction::FNeg) - ? cur->getOperand(0) - : cur->getOperand(1); - - APInt minval(64, 0); - APInt maxval(64, 0); - if (auto C = dyn_cast(prelhs)) { - APSInt Tmp(64); - bool isExact = false; - C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero, - &isExact); - if (isExact || C->isZero()) { - minval = maxval = Tmp; - lhs = ConstantInt::get(Ty, Tmp); - } - } - if (auto ext = dyn_cast(prelhs)) { - if (ext->getOpcode() == Instruction::UIToFP || - ext->getOpcode() == Instruction::SIToFP) { - precasts.push_back(ext); - auto ity = cast(ext->getOperand(0)->getType()); - bool md = false; - if (auto I = dyn_cast(ext->getOperand(0))) - if (auto MD = hasMetadata(I, LLVMContext::MD_range)) { - md = true; - minval = - cast( - cast(MD->getOperand(0))->getValue()) - ->getValue() - .zextOrTrunc(64); - maxval = - cast( - cast(MD->getOperand(1))->getValue()) - ->getValue() - .zextOrTrunc(64); - } - if (!md) { - if (ext->getOpcode() == Instruction::UIToFP) - maxval = APInt::getMaxValue(ity->getBitWidth()).zextOrTrunc(64); - else { - maxval = - APInt::getSignedMaxValue(ity->getBitWidth()).zextOrTrunc(64); - minval = - APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64); - } - } - if (ext->getOperand(0)->getType() == Ty) - lhs = ext->getOperand(0); - else if (ity->getBitWidth() < Ty->getBitWidth()) { - if (ext->getOpcode() == Instruction::UIToFP) - lhs = B.CreateZExt(ext->getOperand(0), Ty); - else - lhs = B.CreateSExt(ext->getOperand(0), Ty); - if (auto I = dyn_cast(lhs)) - if (I != ext->getOperand(0)) - temporaries.insert(I); - } - } - } - - Value *rhs = nullptr; - - if (auto C = dyn_cast(prerhs)) { - APSInt Tmp(64); - bool isExact = false; - C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero, - &isExact); - if (isExact || C->isZero()) { - rhs = ConstantInt::get(Ty, Tmp); - switch (opcode) { - case Instruction::FAdd: - minval += Tmp; - maxval += Tmp; - break; - case Instruction::FSub: - case Instruction::FNeg: - minval -= Tmp; - maxval -= Tmp; - break; - case Instruction::FMul: - minval *= Tmp; - maxval *= Tmp; - break; - default: - llvm_unreachable("Illegal opcode"); - } - } - } - if (auto ext = dyn_cast(prerhs)) { - if (ext->getOpcode() == Instruction::UIToFP || - ext->getOpcode() == Instruction::SIToFP) { - precasts.push_back(ext); - auto ity = cast(ext->getOperand(0)->getType()); - bool md = false; - APInt rhsMin(64, 0); - APInt rhsMax(64, 0); - if (auto I = dyn_cast(ext->getOperand(0))) - if (auto MD = hasMetadata(I, LLVMContext::MD_range)) { - md = true; - rhsMin = - cast( - cast(MD->getOperand(0))->getValue()) - ->getValue() - .zextOrTrunc(64); - rhsMax = - cast( - cast(MD->getOperand(1))->getValue()) - ->getValue() - .zextOrTrunc(64); - } - if (!md) { - if (ext->getOpcode() == Instruction::UIToFP) { - rhsMax = APInt::getMaxValue(ity->getBitWidth()).zextOrTrunc(64); - rhsMin = APInt(64, 0); - } else { - rhsMax = - 
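// Examples for the OrEQZero and InvDistributive rewrites above (made-up IR):
//   OrEQZero:  %o = or i64 %a, %b
//              %r = icmp eq i64 %o, 0
//          ->  %p = icmp eq i64 %a, 0
//              %q = icmp eq i64 %b, 0
//              %r = and i1 %p, %q
//   InvDistributive:  %m = mul i64 %a, %b
//                     %n = mul i64 %c, %b
//                     %r = add i64 %m, %n
//                 ->  %s = add i64 %a, %c
//                     %r = mul i64 %s, %b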
APInt::getSignedMaxValue(ity->getBitWidth()).zextOrTrunc(64); - rhsMin = - APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64); - } - } - switch (opcode) { - case Instruction::FAdd: - minval += rhsMin; - maxval += rhsMax; - break; - case Instruction::FSub: - case Instruction::FNeg: - minval -= rhsMax; - maxval -= rhsMin; - break; - case Instruction::FMul: { - auto minf = [&](APInt a, APInt b) { return a.sle(b) ? a : b; }; - auto maxf = [&](APInt a, APInt b) { return a.sle(b) ? b : b; }; - minval = minf( - minval * rhsMin, - minf(minval * rhsMax, minf(maxval * rhsMin, maxval * rhsMax))); - maxval = maxf( - minval * rhsMin, - maxf(minval * rhsMax, maxf(maxval * rhsMin, maxval * rhsMax))); - break; - } - default: - llvm_unreachable("Illegal opcode"); - } - if (ext->getOperand(0)->getType() == Ty) - rhs = ext->getOperand(0); - else if (ity->getBitWidth() < Ty->getBitWidth()) { - if (ext->getOpcode() == Instruction::UIToFP) - rhs = B.CreateZExt(ext->getOperand(0), Ty); - else - rhs = B.CreateSExt(ext->getOperand(0), Ty); - if (auto I = dyn_cast(rhs)) - if (I != ext->getOperand(0)) - temporaries.insert(I); - } - } - } - - if (lhs && rhs) { - Value *res = nullptr; - if (temporaries.count(dyn_cast(lhs))) - lhs = pushcse(lhs); - if (temporaries.count(dyn_cast(rhs))) - rhs = pushcse(rhs); - switch (opcode) { - case Instruction::FAdd: - res = B.CreateAdd(lhs, rhs, "", false, true); - break; - case Instruction::FSub: - case Instruction::FNeg: - res = B.CreateSub(lhs, rhs, "", false, true); - break; - case Instruction::FMul: - res = B.CreateMul(lhs, rhs, "", false, true); - break; - default: - llvm_unreachable("Illegal opcode"); - } - res = pushcse(res); - for (auto I : precasts) - push(I); - /* - if (auto I = dyn_cast(res)) { - Q.insert(I); - Metadata *vals[] = {(Metadata *)ConstantAsMetadata::get( - ConstantInt::get(Ty, minval)), - (Metadata *)ConstantAsMetadata::get( - ConstantInt::get(Ty, maxval))}; - I->setMetadata(LLVMContext::MD_range, - MDNode::get(I->getContext(), vals)); - } - */ - auto ext = pushcse(B.CreateSIToFP(res, cur->getType())); - replaceAndErase(cur, ext); - return "BinopExtToExtBinop"; - - } else { - for (auto I : temporaries) - I->eraseFromParent(); - } - } - - // select(cond, const1, b) ?= const2 -> select(cond, const1 ?= const2, b ?= - // const2) - if (auto fcmp = dyn_cast(cur)) - for (int i = 0; i < 2; i++) - if (auto const2 = dyn_cast(fcmp->getOperand(i))) - if (auto sel = dyn_cast(fcmp->getOperand(1 - i))) - if (isa(sel->getTrueValue()) || - isa(sel->getFalseValue())) { - auto tval = pushcse(B.CreateFCmp(fcmp->getPredicate(), - sel->getTrueValue(), const2)); - auto fval = pushcse(B.CreateFCmp(fcmp->getPredicate(), - sel->getFalseValue(), const2)); - auto res = pushcse(B.CreateSelect(sel->getCondition(), tval, fval)); - replaceAndErase(cur, res); - return "FCmpSelectConst"; - } - - // mul (mul a, const), b:not_sparse_or_const -> mul (mul a, b), const - // note we avoid the case where b = (mul a, const) since otherwise - // we create an infinite recursion - // and also we make sure b isn't sparse, since sparse is the first - // precedence for pushing, then constant, then others - if (cur->getOpcode() == Instruction::FMul) - if (cur->isFast() && cur->getOperand(0) != cur->getOperand(1)) - for (auto ic = 0; ic < 2; ic++) - if (auto mul = dyn_cast(cur->getOperand(ic))) - if (mul->getOpcode() == Instruction::FMul && mul->isFast()) { - auto b = cur->getOperand(1 - ic); - if (!isa(b) && !directlySparse(b)) { - - for (int i = 0; i < 2; i++) - if (auto C = 
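// Sketch of the BinopExtToExtBinop fusion implemented above (hypothetical
// 32-bit inputs; the pass widens the integer math to i64 and re-applies a
// single sitofp):
//   %x = sitofp i32 %a to double
//   %y = sitofp i32 %b to double
//   %r = fadd double %x, %y
// ->
//   %aw = sext i32 %a to i64
//   %bw = sext i32 %b to i64
//   %s  = add nsw i64 %aw, %bw
//   %r  = sitofp i64 %s to double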
dyn_cast(mul->getOperand(i))) { - auto n0 = - pushcse(B.CreateFMulFMF(mul->getOperand(1 - i), b, mul)); - auto n1 = pushcse(B.CreateFMulFMF(n0, C, cur)); - push(mul); - - replaceAndErase(cur, n1); - return "MulMulConst"; - } - } - } - - // (mul c, a) +/- (mul c, b) -> mul c, (a +/- b) - if (cur->getOpcode() == Instruction::FAdd || - cur->getOpcode() == Instruction::FSub) { - if (auto mul1 = dyn_cast(cur->getOperand(0))) { - if (mul1->getOpcode() == Instruction::FMul && mul1->isFast()) { - if (auto mul2 = dyn_cast(cur->getOperand(1))) { - if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) { - for (int i = 0; i < 2; i++) { - for (int j = 0; j < 2; j++) { - if (mul1->getOperand(i) == mul2->getOperand(j)) { - auto c = mul1->getOperand(i); - auto a = mul1->getOperand(1 - i); - auto b = mul2->getOperand(1 - j); - Value *intermediate = nullptr; - - if (cur->getOpcode() == Instruction::FAdd) - intermediate = pushcse(B.CreateFAddFMF(a, b, cur)); - else - intermediate = pushcse(B.CreateFSubFMF(a, b, cur)); - - auto res = pushcse(B.CreateFMulFMF(c, intermediate, cur)); - push(mul1); - push(mul2); - replaceAndErase(cur, res); - return "FAddMulConstMulConst"; - } - } - } - } - } - } - } - } - - // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), - // (sitofp b) - - if (cur->getOpcode() == Instruction::FMul && cur->isFast()) { - for (int i = 0; i < 2; i++) - if (auto z = dyn_cast(cur->getOperand(i))) - if (isa(z) || isa(z)) - if (auto imul = dyn_cast(z->getOperand(0))) - if (imul->getOpcode() == Instruction::Mul) - for (int j = 0; j < 2; j++) - if (auto c = dyn_cast(imul->getOperand(j))) { - auto b = imul->getOperand(1 - j); - auto a = cur->getOperand(1 - i); - - auto c_fp = pushcse(B.CreateSIToFP(c, cur->getType())); - auto b_fp = pushcse(B.CreateSIToFP(b, cur->getType())); - auto n_mul = pushcse(B.CreateFMulFMF(a, c_fp, cur)); - auto res = pushcse( - B.CreateFMulFMF(n_mul, b_fp, cur, cur->getName())); - push(imul); - push(z); - replaceAndErase(cur, res); - return "FMulIMulConstRotate"; - } - } - - if (cur->getOpcode() == Instruction::FDiv) { - Value *prelhs = cur->getOperand(0); - Value *b = cur->getOperand(1); - - // fdiv (sitofp a), b -> select (a == 0), 0 [ (fdiv 1 / b) * sitofp a] - if (auto ext = dyn_cast(prelhs)) { - if (ext->getOpcode() == Instruction::UIToFP || - ext->getOpcode() == Instruction::SIToFP) { - push(ext); - - Value *condition = pushcse( - B.CreateICmpEQ(ext->getOperand(0), - ConstantInt::get(ext->getOperand(0)->getType(), 0), - "sdivcmp." + cur->getName())); - - Value *fdiv = pushcse( - B.CreateFMulFMF(pushcse(B.CreateFDivFMF( - ConstantFP::get(cur->getType(), 1.0), b, cur)), - ext, cur)); - - Value *sel = pushcse( - B.CreateSelect(condition, ConstantFP::get(cur->getType(), 0.0), - fdiv, "sfdiv." + cur->getName())); - - replaceAndErase(cur, sel); - return "FDivSIToFPProp"; - } - } - // fdiv (select c, 0, a), b -> select c, 0 (fdiv a, b) - if (auto SI = dyn_cast(prelhs)) { - auto tvalC = dyn_cast(SI->getTrueValue()); - auto fvalC = dyn_cast(SI->getFalseValue()); - if ((tvalC && tvalC->isZero()) || (fvalC && fvalC->isZero())) { - push(SI); - auto ntval = - (tvalC && tvalC->isZero()) - ? tvalC - : pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur, - "sfdiv2_t." + cur->getName())); - auto nfval = - (fvalC && fvalC->isZero()) - ? fvalC - : pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur, - "sfdiv2_f." 
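// Sketch of the fdiv rewrite above (invented values; fast-math flags are
// copied from the original instruction). The select keeps the result exactly
// zero whenever the integer numerator is zero:
//   FDivSIToFPProp:  %x = sitofp i64 %a to double
//                    %r = fdiv fast double %x, %b
//                ->  %z = icmp eq i64 %a, 0
//                    %q = fdiv fast double 1.000000e+00, %b
//                    %m = fmul fast double %q, %x
//                    %r = select i1 %z, double 0.000000e+00, double %m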
+ cur->getName())); - - // Work around bad fdivfmf, fixed in LLVM 16+ - // https://github.com/llvm/llvm-project/commit/4f3b1c6dd6ef6c7b5bb79f058e3b7ba4bcdf4566 -#if LLVM_VERSION_MAJOR < 16 - for (auto v : {ntval, nfval}) - if (auto I = dyn_cast(v)) - I->setFastMathFlags(cur->getFastMathFlags()); -#endif - - auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval, - "sfdiv2." + cur->getName())); - - replaceAndErase(cur, res); - return "FDivSelectProp"; - } - } - } - - // div (mul a:not_sparse, b:is_sparse), c -> mul (div, a, c), b:is_sparse - if (cur->getOpcode() == Instruction::FDiv) { - auto c = cur->getOperand(1); - if (auto z = dyn_cast(cur->getOperand(0))) { - if (z->getOpcode() == Instruction::FMul) { - for (int i = 0; i < 2; i++) { - - Value *a = z->getOperand(i); - Value *b = z->getOperand(1 - i); - if (directlySparse(a)) - continue; - if (!directlySparse(b)) - continue; - - Value *inner_fdiv = pushcse(B.CreateFDivFMF(a, c, cur)); - Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fdiv, b, z)); - push(z); - replaceAndErase(cur, outer_fmul); - return "FDivFMulSparseProp"; - } - } - } - } - - if (cur->getOpcode() == Instruction::FMul) - for (int i = 0; i < 2; i++) { - - Value *prelhs = cur->getOperand(i); - Value *b = cur->getOperand(1 - i); - - // fmul (fmul x:constant, y):z, b:constant . - if (isa(b)) - if (auto z = dyn_cast(prelhs)) { - if (z->getOpcode() == Instruction::FMul) { - for (int j = 0; j < 2; j++) { - auto x = z->getOperand(i); - if (!isa(x)) - continue; - auto y = z->getOperand(1 - i); - Value *inner_fmul = pushcse(B.CreateFMulFMF(x, b, cur)); - Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fmul, y, z)); - push(z); - replaceAndErase(cur, outer_fmul); - return "FMulFMulConstantReorder"; - } - } - } - - auto integralFloat = [](Value *z) { - if (auto C = dyn_cast(z)) { - APSInt Tmp(64); - bool isExact = false; - C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero, - &isExact); - if (isExact || C->isZero()) { - return true; - } - } - return false; - }; - - // fmul (fmul x:sparse, y):z, b - // 1) If x and y are both sparse, do nothing and let the inner fmul be - // simplified into a single sparse instruction. Thus, we may assume - // y is not sparse. - // 2) if b is sparse, swap it to be fmul (fmul x, b), y so the inner - // sparsity can be simplified. - // 3) otherwise b is not sparse and we should push the sparsity to - // be the outermost value - if (auto z = dyn_cast(prelhs)) { - if (z->getOpcode() == Instruction::FMul) { - for (int j = 0; j < 2; j++) { - auto x = z->getOperand(j); - if (!directlySparse(x)) - continue; - auto y = z->getOperand(1 - j); - if (directlySparse(y)) - continue; - - if (directlySparse(b) || integralFloat(b)) { - push(z); - Value *inner_fmul = pushcse( - B.CreateFMulFMF(x, b, cur, "mulisr." + cur->getName())); - Value *outer_fmul = pushcse( - B.CreateFMulFMF(inner_fmul, y, z, "mulisr." + z->getName())); - replaceAndErase(cur, outer_fmul); - return "FMulFMulSparseReorder"; - } else { - push(z); - Value *inner_fmul = pushcse( - B.CreateFMulFMF(y, b, cur, "mulisp." + cur->getName())); - Value *outer_fmul = pushcse( - B.CreateFMulFMF(inner_fmul, x, z, "mulisp." 
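// Hedged sketch of the sparse-operand reordering above; "%xs" stands for a
// value this pass has marked sparse (directlySparse), everything else dense:
//   FMulFMulSparsePush:  %i = fmul fast double %xs, %y
//                        %r = fmul fast double %i, %b
//                    ->  %j = fmul fast double %y, %b
//                        %r = fmul fast double %j, %xs  ; sparse factor outermost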
+ z->getName())); - replaceAndErase(cur, outer_fmul); - return "FMulFMulSparsePush"; - } - } - } - } - - /* - auto contains = [](MDNode *MD, Value *V) { - if (!MD) - return false; - for (auto &op : MD->operands()) { - auto V2 = cast(op)->getValue(); - if (V == V2) - return true; - } - return false; - }; - - // fmul (sitofp a), b -> select (a == 0), 0 [noprop fmul ( sitofp a), b] - if (true || !contains(hasMetadata(cur, "enzyme_fmulnoprop"), prelhs)) - if (auto ext = dyn_cast(prelhs)) { - if (ext->getOpcode() == Instruction::UIToFP || - ext->getOpcode() == Instruction::SIToFP) { - push(ext); - - Value *condition = pushcse(B.CreateICmpEQ( - ext->getOperand(0), - ConstantInt::get(ext->getOperand(0)->getType(), 0), - "mulcsicmp." + cur->getName())); - - Value *fmul = pushcse(B.CreateFMulFMF(ext, b, cur)); - if (auto I = dyn_cast(fmul)) { - SmallVector nodes; - if (auto MD = hasMetadata(cur, "enzyme_fmulnoprop")) { - for (auto &M : MD->operands()) { - nodes.push_back(M.get()); - } - } - nodes.push_back(ValueAsMetadata::get(ext)); - I->setMetadata("enzyme_fmulnoprop", - MDNode::get(I->getContext(), nodes)); - } - - Value *sel = pushcse( - B.CreateSelect(condition, ConstantFP::get(cur->getType(), - 0.0), fmul, "mulcsi." + cur->getName())); - - replaceAndErase(cur, sel); - return "FMulSIToFPProp"; - } - } - */ - - // fmul (select c, 0, a), b -> select c, 0 (fmul a, b) - if (auto SI = dyn_cast(prelhs)) { - auto tvalC = dyn_cast(SI->getTrueValue()); - auto fvalC = dyn_cast(SI->getFalseValue()); - if ((tvalC && tvalC->isZero()) || (fvalC && fvalC->isZero())) { - push(SI); - auto ntval = - (tvalC && tvalC->isZero()) - ? tvalC - : pushcse(B.CreateFMulFMF(SI->getTrueValue(), b, cur)); - auto nfval = - (fvalC && fvalC->isZero()) - ? fvalC - : pushcse(B.CreateFMulFMF(SI->getFalseValue(), b, cur)); - auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval, - "mulsi." + cur->getName())); - - replaceAndErase(cur, res); - return "FMulSelectProp"; - } - } - } - - if (auto icmp = dyn_cast(cur)) { - if (icmp->getOpcode() == Instruction::Xor) { - for (int i = 0; i < 2; i++) { - if (auto C = dyn_cast(icmp->getOperand(i))) { - // !(cmp a, b) -> inverse(cmp), a, b - if (C->isOne()) { - if (auto scmp = dyn_cast(icmp->getOperand(1 - i))) { - auto next = pushcse( - B.CreateCmp(scmp->getInversePredicate(), scmp->getOperand(0), - scmp->getOperand(1), "not." 
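// The rewrites above on example IR (names invented):
//   FMulSelectProp:  %s = select i1 %c, double 0.000000e+00, double %a
//                    %r = fmul fast double %s, %b
//                ->  %m = fmul fast double %a, %b
//                    %r = select i1 %c, double 0.000000e+00, double %m
//   NotCmp:  %p = icmp slt i64 %a, %b
//            %n = xor i1 %p, true
//        ->  %n = icmp sge i64 %a, %b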
+ scmp->getName())); - replaceAndErase(cur, next); - return "NotCmp"; - } - } - } - } - } - } - - // select cmp, (ext tval), (ext fval) -> (cmp & tval) | (!cmp & fval) - if (auto SI = dyn_cast(cur)) { - - Value *trueVal = nullptr; - if (auto C = dyn_cast(SI->getTrueValue())) { - if (C->isZero()) { - trueVal = ConstantInt::getFalse(SI->getContext()); - } - if (C->isExactlyValue(1.0)) { - trueVal = ConstantInt::getTrue(SI->getContext()); - } - } - if (auto ext = dyn_cast(SI->getTrueValue())) { - if (ext->getOperand(0)->getType()->isIntegerTy(1)) - trueVal = ext->getOperand(0); - } - Value *falseVal = nullptr; - if (auto C = dyn_cast(SI->getFalseValue())) { - if (C->isZero()) { - falseVal = ConstantInt::getFalse(SI->getContext()); - } - if (C->isExactlyValue(1.0)) { - falseVal = ConstantInt::getTrue(SI->getContext()); - } - } - if (auto ext = dyn_cast(SI->getFalseValue())) { - if (ext->getOperand(0)->getType()->isIntegerTy(1)) - falseVal = ext->getOperand(0); - } - if (trueVal && falseVal) { - auto ncmp1 = pushcse(B.CreateAnd(SI->getCondition(), trueVal)); - auto notV = pushcse(B.CreateNot(SI->getCondition())); - auto ncmp2 = pushcse(B.CreateAnd(notV, falseVal)); - auto ori = pushcse(B.CreateOr(ncmp1, ncmp2)); - auto ext = pushcse(B.CreateUIToFP(ori, SI->getType())); - replaceAndErase(cur, ext); - return "SelectI1Ext"; - } - } - // select cmp, (i1 tval), (i1 fval) -> (cmp & tval) | (!cmp & fval) - if (cur->getType()->isIntegerTy(1)) - if (auto SI = dyn_cast(cur)) { - auto ncmp1 = pushcse(B.CreateAnd(SI->getCondition(), SI->getTrueValue())); - auto notV = pushcse(B.CreateNot(SI->getCondition())); - auto ncmp2 = pushcse(B.CreateAnd(notV, SI->getFalseValue())); - auto ori = pushcse(B.CreateOr(ncmp1, ncmp2)); - replaceAndErase(cur, ori); - return "SelectI1"; - } - - if (auto PN = dyn_cast(cur)) { - B.SetInsertPoint(PN->getParent()->getFirstNonPHI()); - if (SE.isSCEVable(PN->getType())) { - auto S = SE.getSCEV(PN); - - bool legal = false; - if (auto SV = dyn_cast(S)) { - auto val = SV->getValue(); - legal |= isa(val) || isa(val); - if (auto I = dyn_cast(val)) { - auto L = LI.getLoopFor(I->getParent()); - if ((!L || L->getCanonicalInductionVariable() != I) && I != PN) - legal = true; - } - } - if (isa(S)) { - auto L = LI.getLoopFor(PN->getParent()); - assert(L); - if (L->getCanonicalInductionVariable() != PN) - legal = true; - } - - if (legal) { - for (auto U : cur->users()) { - push(U); - } - auto point = PN->getParent()->getFirstNonPHI(); - auto tmp = cast(pushcse(B.CreatePHI(cur->getType(), 1))); - cur->replaceAllUsesWith(tmp); - cur->eraseFromParent(); - - Value *newIV = nullptr; - { - SCEVExpander Exp(SE, DL, "sparseenzyme"); - // We place that at first non phi as it may produce a non-phi - // instruction and must thus be expanded after all phi's - newIV = Exp.expandCodeFor(S, tmp->getType(), point); - // sadly this doesn't exist on 11 - for (auto I : Exp.getAllInsertedInstructions()) - Q.insert(I); - } - - tmp->replaceAllUsesWith(newIV); - tmp->eraseFromParent(); - return "InductVarSCEV"; - } - } - // phi a, a -> a - { - bool legal = true; - for (size_t i = 1; i < PN->getNumIncomingValues(); i++) { - auto v = PN->getIncomingValue(i); - if (v != PN->getIncomingValue(0)) { - legal = false; - break; - } - } - if (legal) { - auto val = PN->getIncomingValue(0); - replaceAndErase(cur, val); - return "PhiMerge"; - } - } - // phi (idx=0) ? 
b, a, a -> select (idx == 0), b, a - if (auto L = LI.getLoopFor(PN->getParent())) - if (L->getHeader() == PN->getParent()) - if (auto idx = L->getCanonicalInductionVariable()) - if (auto PH = L->getLoopPreheader()) { - bool legal = idx != PN; - auto ph_idx = PN->getBasicBlockIndex(PH); - assert(ph_idx >= 0); - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - if ((int)i == ph_idx) - continue; - auto v = PN->getIncomingValue(i); - if (v != PN->getIncomingValue(1 - ph_idx)) { - legal = false; - break; - } - // The given var must dominate the loop - if (isa(v)) - continue; - if (isa(v)) - continue; - // exception for the induction itself, which we handle specially - if (v == idx) - continue; - auto I = cast(v); - if (!DT.dominates(I, PN)) { - legal = false; - break; - } - } - if (legal) { - auto val = PN->getIncomingValue(1 - ph_idx); - push(val); - if (val == idx) { - val = pushcse( - B.CreateSub(idx, ConstantInt::get(idx->getType(), 1))); - } - - auto val2 = PN->getIncomingValue(ph_idx); - push(val2); - - auto c0 = ConstantInt::get(idx->getType(), 0); - // if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) { - // val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val); - //} else { - auto eq = pushcse(B.CreateICmpEQ(idx, c0)); - val = pushcse( - B.CreateSelect(eq, val2, val, "phisel." + cur->getName())); - //} - - replaceAndErase(cur, val); - return "PhiLoop0Sel"; - } - } - // phi (sitofp a), (sitofp b) -> sitofp (phi a, b) - { - SmallVector negOps; - SmallVector prevNegOps; - bool legal = true; - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - auto v = PN->getIncomingValue(i); - if (auto C = dyn_cast(v)) { - APSInt Tmp(64); - bool isExact = false; - C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero, - &isExact); - if (isExact || C->isZero()) { - negOps.push_back(ConstantInt::get(B.getInt64Ty(), Tmp)); - continue; - } - } - if (auto fneg = dyn_cast(v)) { - if (fneg->getOpcode() == Instruction::SIToFP && - cast(fneg->getOperand(0)->getType()) - ->getBitWidth() == 64) { - negOps.push_back(fneg->getOperand(0)); - prevNegOps.push_back(fneg); - continue; - } - } - legal = false; - } - if (legal) { - auto PN2 = cast( - pushcse(B.CreatePHI(B.getInt64Ty(), PN->getNumIncomingValues()))); - PN2->takeName(PN); - for (auto val : llvm::enumerate(negOps)) - PN2->addIncoming(val.value(), PN->getIncomingBlock(val.index())); - - push(PN2); - - auto fneg = pushcse(B.CreateSIToFP(PN2, PN->getType())); - - for (auto I : prevNegOps) - push(I); - replaceAndErase(cur, fneg); - return "PhiSIToFP"; - } - } - // phi (fneg a), (fneg b) -> fneg (phi a, b) - { - SmallVector negOps; - SmallVector prevNegOps; - bool legal = true; - bool hasNeg = false; - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - auto v = PN->getIncomingValue(i); - if (auto C = dyn_cast(v)) { - negOps.push_back(C->isZero() ? 
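// Sketch of the phi rewrites above for a loop with canonical induction
// variable %iv (hypothetical blocks; %a is loop invariant):
//   PhiLoop0Sel:  header:  %p = phi double [ %b, %preheader ], [ %a, %latch ]
//             ->  %z = icmp eq i64 %iv, 0
//                 %p = select i1 %z, double %b, double %a
//   PhiSIToFP:  %p = phi double [ %x, %bb0 ], [ %y, %bb1 ]
//               ; where %x = sitofp i64 %i ..., %y = sitofp i64 %j ...
//           ->  %q = phi i64 [ %i, %bb0 ], [ %j, %bb1 ]
//               %p = sitofp i64 %q to double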
C : pushcse(B.CreateFNeg(C))); - continue; - } - if (auto fneg = dyn_cast(v)) { - if (fneg->getOpcode() == Instruction::FNeg) { - negOps.push_back(fneg->getOperand(0)); - prevNegOps.push_back(fneg); - continue; - } - } - legal = false; - } - if (legal && hasNeg) { - for (auto val : llvm::enumerate(negOps)) - PN->setIncomingValue(val.index(), val.value()); - - push(PN); - - auto fneg = pushcse(B.CreateFNeg(PN)); - - for (auto &U : cur->uses()) { - if (U.getUser() == fneg) - continue; - push(U.getUser()); - U.set(fneg); - } - for (auto I : prevNegOps) - push(I); - return "PhiFNeg"; - } - } - // phi (neg a), (neg b) -> neg (phi a, b) - { - SmallVector negOps; - SmallVector prevNegOps; - bool legal = true; - bool hasNeg = false; - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - auto v = PN->getIncomingValue(i); - if (auto C = dyn_cast(v)) { - negOps.push_back(pushcse(B.CreateNeg(C))); - continue; - } - if (auto fneg = dyn_cast(v)) { - if (auto CI = dyn_cast(fneg->getOperand(0))) - if (fneg->getOpcode() == Instruction::Sub && CI->isZero()) { - negOps.push_back(fneg->getOperand(1)); - prevNegOps.push_back(fneg); - hasNeg = true; - continue; - } - } - legal = false; - } - if (legal && hasNeg) { - for (auto val : llvm::enumerate(negOps)) - PN->setIncomingValue(val.index(), val.value()); - - push(PN); - - auto fneg = pushcse(B.CreateNeg(PN)); - - for (auto &U : cur->uses()) { - if (U.getUser() == fneg) - continue; - push(U.getUser()); - U.set(fneg); - } - for (auto I : prevNegOps) - push(I); - return "PHINeg"; - } - } - // p = phi (mul a, c), (mul b, d) -> mul (phi a, b), (phi c, d) if - // a,b,c != p - { - for (auto code : - {(unsigned)Instruction::Mul, (unsigned)Instruction::Sub, - (unsigned)Instruction::Add, (unsigned)Instruction::ZExt, - (unsigned)Instruction::UIToFP, (unsigned)Instruction::ICmp, - (unsigned)Instruction::FMul, (unsigned)Instruction::Or, - (unsigned)Instruction::And}) { - SmallVector lhsOps; - SmallVector rhsOps; - SmallVector prevOps; - bool legal = true; - bool fast = false; - bool NUW = false; - bool NSW = false; - size_t numOps = 0; - std::optional cmpPredicate; - switch (code) { - case Instruction::FMul: - case Instruction::FSub: - case Instruction::FAdd: - fast = true; - numOps = 2; - break; - case Instruction::Mul: - case Instruction::Add: - NUW = NSW = true; - numOps = 2; - break; - case Instruction::Sub: - NSW = true; - numOps = 2; - break; - case Instruction::ICmp: - case Instruction::FCmp: - case Instruction::Or: - case Instruction::And: - numOps = 2; - break; - case Instruction::ZExt: - case Instruction::UIToFP: - numOps = 1; - break; - default:; - llvm_unreachable("unknown opcode"); - } - bool changed = false; - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - auto v = PN->getIncomingValue(i); - if (auto C = dyn_cast(v)) { - if (code == Instruction::ZExt) { - lhsOps.push_back(ConstantInt::getFalse(C->getContext())); - continue; - } else if (C->isZero()) { - rhsOps.push_back(C); - lhsOps.push_back(C); - continue; - } - } - if (auto C = dyn_cast(v)) { - if (code == Instruction::UIToFP) { - if (C->isZero()) { - lhsOps.push_back(ConstantInt::getFalse(C->getContext())); - } - } else if (code == Instruction::FMul || code == Instruction::FSub || - code == Instruction::FAdd) { - if (C->isZero()) { - rhsOps.push_back(C); - lhsOps.push_back(C); - continue; - } - } - } - if (auto fneg = dyn_cast(v)) { - if (fneg->getOpcode() == code) { - switch (code) { - case Instruction::FMul: - case Instruction::FSub: - case Instruction::FAdd: - fast &= 
fneg->isFast(); - if (fneg->getOperand(0) == PN) - legal = false; - if (fneg->getOperand(1) == PN) - legal = false; - lhsOps.push_back(fneg->getOperand(0)); - rhsOps.push_back(fneg->getOperand(1)); - break; - case Instruction::Mul: - case Instruction::Sub: - case Instruction::Add: - NUW &= fneg->hasNoUnsignedWrap(); - NSW &= fneg->hasNoSignedWrap(); - if (fneg->getOperand(0) == PN) - legal = false; - if (fneg->getOperand(1) == PN) - legal = false; - lhsOps.push_back(fneg->getOperand(0)); - rhsOps.push_back(fneg->getOperand(1)); - break; - case Instruction::Or: - case Instruction::And: - if (fneg->getOperand(0) == PN) - legal = false; - if (fneg->getOperand(1) == PN) - legal = false; - lhsOps.push_back(fneg->getOperand(0)); - rhsOps.push_back(fneg->getOperand(1)); - break; - case Instruction::ICmp: - case Instruction::FCmp: - if (fneg->getOperand(0) == PN) - legal = false; - if (fneg->getOperand(1) == PN) - legal = false; - if (cmpPredicate) { - if (*cmpPredicate != cast(fneg)->getPredicate()) - legal = false; - } else { - cmpPredicate = cast(fneg)->getPredicate(); - } - lhsOps.push_back(fneg->getOperand(0)); - rhsOps.push_back(fneg->getOperand(1)); - break; - case Instruction::ZExt: - case Instruction::UIToFP: - if (cast(fneg->getOperand(0)->getType()) - ->getBitWidth() != 1) - legal = false; - lhsOps.push_back(fneg->getOperand(0)); - break; - default: - llvm_unreachable("unhandled opcode"); - } - prevOps.push_back(fneg); - changed = true; - continue; - } - } - legal = false; - } - - int preheader_fix = -1; - - if (code == Instruction::ICmp || code == Instruction::FCmp) { - if (!cmpPredicate) - legal = false; - auto L = LI.getLoopFor(PN->getParent()); - if (legal && L && L->getLoopPreheader() && - L->getCanonicalInductionVariable() && - L->getHeader() == PN->getParent()) { - auto ph_idx = PN->getBasicBlockIndex(L->getLoopPreheader()); - if (isa(PN->getIncomingValue(ph_idx))) { - lhsOps[ph_idx] = - Constant::getNullValue(lhsOps[1 - ph_idx]->getType()); - rhsOps[ph_idx] = - Constant::getNullValue(rhsOps[1 - ph_idx]->getType()); - preheader_fix = ph_idx; - } - } - for (auto v : lhsOps) - if (v->getType() != lhsOps[0]->getType()) - legal = false; - for (auto v : rhsOps) - if (v->getType() != rhsOps[0]->getType()) - legal = false; - } - - if (legal && changed) { - auto lhsPN = cast(pushcse( - B.CreatePHI(lhsOps[0]->getType(), PN->getNumIncomingValues()))); - PHINode *rhsPN = nullptr; - if (numOps == 2) - rhsPN = cast(pushcse( - B.CreatePHI(rhsOps[0]->getType(), PN->getNumIncomingValues()))); - - for (auto val : llvm::enumerate(lhsOps)) - lhsPN->addIncoming(val.value(), PN->getIncomingBlock(val.index())); - - if (numOps == 2) { - for (auto val : llvm::enumerate(rhsOps)) - rhsPN->addIncoming(val.value(), - PN->getIncomingBlock(val.index())); - } - - Value *fneg = nullptr; - switch (code) { - case Instruction::FMul: - fneg = B.CreateFMul(lhsPN, rhsPN); - if (auto I = dyn_cast(fneg)) - I->setFast(fast); - break; - case Instruction::FAdd: - fneg = B.CreateFAdd(lhsPN, rhsPN); - if (auto I = dyn_cast(fneg)) - I->setFast(fast); - break; - case Instruction::FSub: - fneg = B.CreateFSub(lhsPN, rhsPN); - if (auto I = dyn_cast(fneg)) - I->setFast(fast); - break; - case Instruction::Mul: - fneg = B.CreateMul(lhsPN, rhsPN, "", NUW, NSW); - break; - case Instruction::Add: - fneg = B.CreateAdd(lhsPN, rhsPN, "", NUW, NSW); - break; - case Instruction::Sub: - fneg = B.CreateSub(lhsPN, rhsPN, "", NUW, NSW); - break; - case Instruction::ZExt: - fneg = B.CreateZExt(lhsPN, PN->getType()); - break; - case 
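// Sketch of the phi-of-binops rewrite this switch implements (invented IR;
// wrap and fast-math flags are intersected across the incoming operations):
//   PHIBinop:  %m = mul i64 %a, %b        ; in %bb0
//              %n = mul i64 %c, %d        ; in %bb1
//              %p = phi i64 [ %m, %bb0 ], [ %n, %bb1 ]
//          ->  %pl = phi i64 [ %a, %bb0 ], [ %c, %bb1 ]
//              %pr = phi i64 [ %b, %bb0 ], [ %d, %bb1 ]
//              %p  = mul i64 %pl, %pr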
Instruction::FCmp: - case Instruction::ICmp: - fneg = B.CreateCmp(*cmpPredicate, lhsPN, rhsPN); - break; - case Instruction::UIToFP: - fneg = B.CreateUIToFP(lhsPN, PN->getType()); - break; - case Instruction::Or: - fneg = B.CreateOr(lhsPN, rhsPN); - break; - case Instruction::And: - fneg = B.CreateAnd(lhsPN, rhsPN); - break; - default: - llvm_unreachable("unhandled opcode"); - } - - push(fneg); - - if (preheader_fix != -1) { - auto L = LI.getLoopFor(PN->getParent()); - auto idx = L->getCanonicalInductionVariable(); - auto eq = pushcse( - B.CreateICmpEQ(idx, ConstantInt::get(idx->getType(), 0))); - fneg = - pushcse(B.CreateSelect(eq, PN->getIncomingValue(preheader_fix), - fneg, "phphisel." + cur->getName())); - } - - replaceAndErase(cur, fneg); - return "PHIBinop"; - } - } - } - // phi -> select - if (PN->getNumIncomingValues() == 2) { - for (int i = 0; i < 2; i++) { - auto prev = PN->getIncomingBlock(i); - if (!DT.dominates(prev, PN->getParent())) { - continue; - } - auto br = dyn_cast(prev->getTerminator()); - if (!br) { - continue; - } - if (!br->isConditional()) { - continue; - } - if (br->getSuccessor(0) != PN->getParent()) { - continue; - } - if (br->getSuccessor(1) != PN->getIncomingBlock(1 - i)) { - continue; - } - - Value *specVal = PN->getIncomingValue(1 - i); - SetVector> todo; - todo.insert(specVal); - SetVector toMove; - bool legal = true; - while (!todo.empty()) { - auto cur = *todo.begin(); - todo.erase(todo.begin()); - auto I = dyn_cast(cur); - if (!I) - continue; - if (I->mayReadOrWriteMemory()) { - legal = false; - break; - } - if (DT.dominates(I, PN)) - continue; - for (size_t i = 0; i < I->getNumOperands(); i++) - todo.insert(I->getOperand(i)); - toMove.insert(I); - } - if (!legal) - continue; - for (auto iter = toMove.rbegin(), end = toMove.rend(); iter != end; - iter++) { - (*iter)->moveBefore(br); - } - auto sel = pushcse(B.CreateSelect( - br->getCondition(), PN->getIncomingValueForBlock(prev), - PN->getIncomingValueForBlock(br->getSuccessor(1)), - "tphisel." + cur->getName())); - - replaceAndErase(cur, sel); - return "TPhiSel"; - } - } - } - - if (auto SI = dyn_cast(cur)) { - auto tval = replace(SI->getTrueValue(), SI->getCondition(), - ConstantInt::getTrue(SI->getContext())); - auto fval = replace(SI->getFalseValue(), SI->getCondition(), - ConstantInt::getFalse(SI->getContext())); - if (tval != SI->getTrueValue() || fval != SI->getFalseValue()) { - auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval, - "postsel." + SI->getName())); - replaceAndErase(cur, res); - return "SelectReplace"; - } - } - - // and a, b -> and a b[with a true] - if (cur->getOpcode() == Instruction::And) { - auto lhs = replace(cur->getOperand(0), cur->getOperand(1), - ConstantInt::getTrue(cur->getContext())); - if (lhs != cur->getOperand(0)) { - auto res = pushcse( - B.CreateAnd(lhs, cur->getOperand(1), "postand." + cur->getName())); - replaceAndErase(cur, res); - return "AndReplaceLHS"; - } - auto rhs = replace(cur->getOperand(1), cur->getOperand(0), - ConstantInt::getTrue(cur->getContext())); - if (rhs != cur->getOperand(1)) { - auto res = pushcse( - B.CreateAnd(cur->getOperand(0), rhs, "postand." + cur->getName())); - replaceAndErase(cur, res); - return "AndReplaceRHS"; - } - } - - // or a, b -> or a b[with a false] - if (cur->getOpcode() == Instruction::Or) { - auto lhs = replace(cur->getOperand(0), cur->getOperand(1), - ConstantInt::getFalse(cur->getContext())); - if (lhs != cur->getOperand(0)) { - auto res = pushcse( - B.CreateOr(lhs, cur->getOperand(1), "postor." 
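// Sketch of the TPhiSel and operand-replacement rewrites above (hypothetical
// blocks and values; %y must be side-effect free so it can be hoisted):
//   TPhiSel:  bbA:    br i1 %c, label %merge, label %bbB
//             bbB:    br label %merge
//             merge:  %p = phi i64 [ %x, %bbA ], [ %y, %bbB ]
//         ->  merge:  %p = select i1 %c, i64 %x, i64 %y
//   AndReplaceLHS:  %o = or i1 %b, %x
//                   %r = and i1 %o, %b
//               ->  %o2 = or i1 true, %x    ; %b assumed true inside the lhs
//                   %r  = and i1 %o2, %b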
+ cur->getName())); - replaceAndErase(cur, res); - return "OrReplaceLHS"; - } - auto rhs = replace(cur->getOperand(1), cur->getOperand(0), - ConstantInt::getFalse(cur->getContext())); - if (rhs != cur->getOperand(1)) { - auto res = pushcse( - B.CreateOr(cur->getOperand(0), rhs, "postor." + cur->getName())); - replaceAndErase(cur, res); - return "OrReplaceRHS"; - } - } - return {}; -} - -class Constraints; -raw_ostream &operator<<(raw_ostream &os, const Constraints &c); - -struct ConstraintComparator { - bool operator()(std::shared_ptr lhs, - std::shared_ptr rhs) const; -}; - -struct ConstraintContext { - ScalarEvolution &SE; - const Loop *loopToSolve; - const SmallVectorImpl &Assumptions; - DominatorTree &DT; - using InnerTy = std::shared_ptr; - using SetTy = std::set; - SetTy seen; - ConstraintContext(ScalarEvolution &SE, const Loop *loopToSolve, - const SmallVectorImpl &Assumptions, - DominatorTree &DT) - : SE(SE), loopToSolve(loopToSolve), Assumptions(Assumptions), DT(DT) { - assert(loopToSolve); - } - ConstraintContext(const ConstraintContext &) = delete; - ConstraintContext(const ConstraintContext &ctx, InnerTy lhs) - : SE(ctx.SE), loopToSolve(ctx.loopToSolve), Assumptions(ctx.Assumptions), - DT(ctx.DT), seen(ctx.seen) { - seen.insert(lhs); - } - ConstraintContext(const ConstraintContext &ctx, InnerTy lhs, InnerTy rhs) - : SE(ctx.SE), loopToSolve(ctx.loopToSolve), Assumptions(ctx.Assumptions), - DT(ctx.DT), seen(ctx.seen) { - seen.insert(lhs); - seen.insert(rhs); - } - bool contains(InnerTy x) const { return seen.count(x) != 0; } -}; - -bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) { - assert(L); - if (isa(S)) - return true; - if (auto M = dyn_cast(S)) { - for (auto o : M->operands()) - if (!cannotDependOnLoopIV(o, L)) - return false; - return true; - } - if (auto M = dyn_cast(S)) { - for (auto o : M->operands()) - if (!cannotDependOnLoopIV(o, L)) - return false; - return true; - } - if (auto M = dyn_cast(S)) { - for (auto o : {M->getLHS(), M->getRHS()}) - if (!cannotDependOnLoopIV(o, L)) - return false; - return true; - } - if (auto UV = dyn_cast(S)) { - auto U = UV->getValue(); - if (isa(U)) - return true; - if (isa(U)) - return true; - auto I = cast(U); - return !L->contains(I->getParent()); - } - if (auto addrec = dyn_cast(S)) { - if (addrec->getLoop() == L) - return false; - for (auto o : addrec->operands()) - if (!cannotDependOnLoopIV(o, L)) - return false; - return true; - } - if (auto expr = dyn_cast(S)) { - return cannotDependOnLoopIV(expr->getOperand(), L); - } - llvm::errs() << " cannot tell if depends on loop iv: " << *S << "\n"; - return false; -} - -const SCEV *evaluateAtLoopIter(const SCEV *V, ScalarEvolution &SE, - const Loop *find, const SCEV *replace) { - assert(find); - if (cannotDependOnLoopIV(V, find)) - return V; - if (auto addrec = dyn_cast(V)) { - if (addrec->getLoop() == find) { - auto V2 = addrec->evaluateAtIteration(replace, SE); - return evaluateAtLoopIter(V2, SE, find, replace); - } - } - if (auto div = dyn_cast(V)) { - auto lhs = evaluateAtLoopIter(div->getLHS(), SE, find, replace); - if (!lhs) - return nullptr; - auto rhs = evaluateAtLoopIter(div->getRHS(), SE, find, replace); - if (!rhs) - return nullptr; - return SE.getUDivExpr(lhs, rhs); - } - return nullptr; -} - -class Constraints : public std::enable_shared_from_this { -public: - const enum class Type { - Union = 0, - Intersect = 1, - Compare = 2, - All = 3, - None = 4 - } ty; - - using InnerTy = std::shared_ptr; - - using SetTy = std::set; - - const SetTy values; - - const SCEV *const 
node; - // whether equal to the node, or not equal to the node - bool isEqual; - // the loop of the iv comparing against. - const llvm::Loop *const Loop; - // using SetTy = SmallVector; - // using SetTy = SetVector, - // std::set>; - - Constraints() - : ty(Type::Union), values(), node(nullptr), isEqual(false), - Loop(nullptr) {} - -private: - Constraints(const SCEV *v, bool isEqual, const llvm::Loop *Loop, bool) - : ty(Type::Compare), values(), node(v), isEqual(isEqual), Loop(Loop) {} - -public: - static InnerTy make_compare(const SCEV *v, bool isEqual, - const llvm::Loop *Loop, - const ConstraintContext &ctx); - - Constraints(Type t) - : ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) { - assert(t == Type::All || t == Type::None); - } - Constraints(Type t, const SetTy &c, bool check = true) - : ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) { - assert(t != Type::All); - assert(t != Type::None); - assert(c.size() != 0); - assert(c.size() != 1); -#ifndef NDEBUG - SmallVector tmp(c.begin(), c.end()); - for (unsigned i = 0; i < tmp.size(); i++) - for (unsigned j = 0; j < i; j++) - assert(*tmp[i] != *tmp[j]); - if (t == Type::Intersect) { - for (auto &v : c) { - assert(v->ty != Type::Intersect); - } - } - if (t == Type::Union) { - for (auto &v : c) { - assert(v->ty != Type::Union); - } - } - if (t == Type::Intersect && check) { - for (unsigned i = 0; i < tmp.size(); i++) - if (tmp[i]->ty == Type::Compare && tmp[i]->isEqual && tmp[i]->Loop) - for (unsigned j = 0; j < tmp.size(); j++) - if (tmp[j]->ty == Type::Compare) - if (auto s = dyn_cast(tmp[j]->node)) - assert(s->getLoop() != tmp[i]->Loop); - } -#endif - } - - bool operator==(const Constraints &rhs) const { - if (ty != rhs.ty) { - return false; - } - if (node != rhs.node) { - return false; - } - if (isEqual != rhs.isEqual) { - return false; - } - if (Loop != rhs.Loop) { - return false; - } - if (values.size() != rhs.values.size()) { - return false; - } - for (auto pair : llvm::zip(values, rhs.values)) { - if (*std::get<0>(pair) != *std::get<1>(pair)) - return false; - } - return true; - //) && !(rhs.values < values) - /* -for (size_t i=0; i(const Constraints &rhs) const { return rhs < *this; } - bool operator<(const Constraints &rhs) const { - if (ty < rhs.ty) { - return true; - } - if (ty > rhs.ty) { - return false; - } - if (node < rhs.node) { - return true; - } - if (node > rhs.node) { - return false; - } - if (isEqual < rhs.isEqual) { - return true; - } - if (isEqual > rhs.isEqual) { - return false; - } - if (Loop < rhs.Loop) { - return true; - } - if (Loop > rhs.Loop) { - return false; - } - if (values.size() < rhs.values.size()) { - return true; - } - if (values.size() > rhs.values.size()) { - return false; - } - for (auto pair : llvm::zip(values, rhs.values)) { - if (*std::get<0>(pair) < *std::get<1>(pair)) - return true; - if (*std::get<0>(pair) > *std::get<1>(pair)) - return false; - } - return false; - } - unsigned hash() const { - unsigned res = 5 * (unsigned)ty + - DenseMapInfo::getHashValue(node) + isEqual; - res = llvm::detail::combineHashValue(res, (unsigned)(size_t)Loop); - for (auto v : values) - res = llvm::detail::combineHashValue(res, v->hash()); - return res; - } - bool operator!=(const Constraints &rhs) const { return !(*this == rhs); } - static InnerTy all() { - static auto allv = std::make_shared(Type::All); - return allv; - } - static InnerTy none() { - static auto nonev = std::make_shared(Type::None); - return nonev; - } - bool isNone() const { return ty == Type::None; } - bool 
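// Informal reading of this class (inferred from the fields and their uses):
// a Compare node records that the canonical induction variable of `Loop` is
// equal (isEqual) or not equal to the SCEV `node`; a null Loop is used for
// conditions of the form `node == 0` / `node != 0`. Union and Intersect
// combine child constraints, and All / None are the trivially-true and
// unsatisfiable elements. For example, "iv == n - 1" would be represented as
// Compare{node = n - 1, isEqual = true, Loop = L}.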
isAll() const { return ty == Type::All; } - static void insert(SetTy &set, InnerTy ty) { - set.insert(ty); - int mcount = 0; - for (auto &v : set) - if (*v == *ty) - mcount++; - assert(mcount == 1); - /* - for (auto &v : set) - if (*v == *ty) - return; - set.push_back(ty); - */ - } - static SetTy intersect(const SetTy &lhs, const SetTy &rhs) { - SetTy res; - for (auto &v : lhs) - if (rhs.count(v)) - res.insert(v); - return res; - } - static void set_subtract(SetTy &set, const SetTy &rhs) { - for (auto &v : rhs) - if (set.count(v)) - set.erase(v); - /* - for (const auto &val : rhs) - for (auto I = set.begin(); I != set.end(); I++) { - if (**I == *val) { - set.erase(I); - break; - } - } -*/ - } - __attribute__((noinline)) void dump() const { llvm::errs() << *this << "\n"; } - InnerTy notB(const ConstraintContext &ctx) const { - switch (ty) { - case Type::None: - return Constraints::all(); - case Type::All: - return Constraints::none(); - case Type::Compare: - return make_compare(node, !isEqual, Loop, ctx); - case Type::Union: { - // not of or's is and of not's - SetTy next; - for (const auto &v : values) - insert(next, v->notB(ctx)); - if (next.size() == 1) - llvm::errs() << " uold : " << *this << "\n"; - return std::make_shared(Type::Intersect, next); - } - case Type::Intersect: { - // not of and's is or of not's - SetTy next; - for (const auto &v : values) - insert(next, v->notB(ctx)); - if (next.size() == 1) - llvm::errs() << " old : " << *this << "\n"; - return std::make_shared(Type::Union, next); - } - } - return Constraints::none(); - } - InnerTy orB(InnerTy rhs, const ConstraintContext &ctx) const { - auto notLHS = notB(ctx); - if (!notLHS) - return nullptr; - auto notRHS = rhs->notB(ctx); - if (!notRHS) - return nullptr; - auto andV = notLHS->andB(notRHS, ctx); - if (!andV) - return nullptr; - auto res = andV->notB(ctx); - return res; - } - InnerTy andB(const InnerTy rhs, const ConstraintContext &ctx) const { - assert(rhs); - if (*rhs == *this) - return shared_from_this(); - if (rhs->isNone()) - return rhs; - if (rhs->isAll()) - return shared_from_this(); - if (isNone()) - return shared_from_this(); - if (isAll()) - return rhs; - - // llvm::errs() << " anding: " << *this << " with " << *rhs << "\n"; - if (ctx.contains(shared_from_this()) || ctx.contains(rhs)) { - // llvm::errs() << " %%% stopping recursion\n"; - return nullptr; - } - if (ty == Type::Compare && rhs->ty == Type::Compare) { - auto sub = ctx.SE.getMinusSCEV(node, rhs->node); - if (Loop == rhs->Loop) { - // llvm::errs() << " + sameloop, sub=" << *sub << "\n"; - if (auto cst = dyn_cast(sub)) { - // the two solves are equivalent to each other - if (cst->getValue()->isZero()) { - // iv = a and iv = a - // also iv != a and iv != a - if (isEqual == rhs->isEqual) - return shared_from_this(); - else { - // iv = a and iv != a - return Constraints::none(); - } - } else { - // the two solves are guaranteed to be distinct - // iv == 0 and iv == 1 - if (isEqual && rhs->isEqual) { - return Constraints::none(); - - } else if (!isEqual && !rhs->isEqual) { - // iv != 0 and iv != 1 - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - return std::make_shared(Type::Intersect, vals); - } else if (!isEqual) { - assert(rhs->isEqual); - // iv != 0 and iv == 1 - return rhs; - ; - } else { - // iv == 0 and iv != 1 - assert(isEqual); - assert(!rhs->isEqual); - return shared_from_this(); - } - } - } else if (isEqual || rhs->isEqual) { - // llvm::errs() << " + botheq\n"; - // eq(i, a) & i ?= b -> eq(i, a) & (a ?= b) - if (auto 
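// Worked examples of the constant-difference cases handled above, where iv
// denotes the induction variable of the shared loop (values are invented):
//   (iv == 5) and (iv == 5)  ->  iv == 5
//   (iv == 5) and (iv != 5)  ->  None        (contradiction)
//   (iv == 5) and (iv == 7)  ->  None        (iv cannot take two values)
//   (iv != 5) and (iv != 7)  ->  Intersect{iv != 5, iv != 7}
//   (iv != 5) and (iv == 7)  ->  iv == 7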
addrec = dyn_cast(sub)) { - // we want a ?= b, but we can only represent loopvar ?= something - // so suppose a-b is of the form X + Y * lv then a-b ?= 0 is - // X + Y * lv ?= 0 -> lv ?= - X / Y - if (addrec->isAffine()) { - auto X = addrec->getStart(); - auto Y = addrec->getStepRecurrence(ctx.SE); - auto MinusX = X; - - if (isa(Y) && - cast(Y)->getAPInt().isNegative()) - Y = ctx.SE.getNegativeSCEV(Y); - else - MinusX = ctx.SE.getNegativeSCEV(X); - - auto div = ctx.SE.getUDivExpr(MinusX, Y); - auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y); - // in case of inexact division, check that these exactly equal - // for replacement - - if (div == div_e) { - if (isEqual) { - auto res = make_compare(div, /*isEqual*/ rhs->isEqual, - addrec->getLoop(), ctx); - // llvm::errs() << " simplified rhs to: " << *res << "\n"; - return andB(res, ctx); - } else { - assert(rhs->isEqual); - auto res = make_compare(div, /*isEqual*/ isEqual, - addrec->getLoop(), ctx); - // llvm::errs() << " simplified lhs to: " << *res << "\n"; - return rhs->andB(res, ctx); - } - } - } - } - if (isEqual && rhs->Loop && - cannotDependOnLoopIV(sub, ctx.loopToSolve)) { - auto res = make_compare(sub, /*isEqual*/ rhs->isEqual, - /*loop*/ nullptr, ctx); - // llvm::errs() << " simplified(noloop) rhs from " << *rhs - // << " to: " << *res << "\n"; - return andB(res, ctx); - } - if (rhs->isEqual && Loop && - cannotDependOnLoopIV(sub, ctx.loopToSolve)) { - auto res = - make_compare(sub, /*isEqual*/ isEqual, /*loop*/ nullptr, ctx); - // llvm::errs() << " simplified(noloop) lhs from " << *rhs - // << " to: " << *res << "\n"; - return rhs->andB(res, ctx); - } - - llvm::errs() << " warning: potential but unhandled simplification of " - "equalities: " - << *this << " and " << *rhs << " sub: " << *sub << "\n"; - } - } - - if (isEqual) { - if (Loop) - if (auto rep = evaluateAtLoopIter(rhs->node, ctx.SE, Loop, node)) - if (rep != rhs->node) { - auto newrhs = make_compare(rep, rhs->isEqual, rhs->Loop, ctx); - return andB(newrhs, ctx); - } - - // not loop -> node == 0 - if (!Loop) { - for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node), - ctx.SE.getMinusSCEV(rhs->node, node)}) { - // llvm::errs() << " maybe replace lhs: " << *this << " rhs: " << - // *rhs - // << " sub1: " << *sub1 << "\n"; - auto newrhs = make_compare(sub1, rhs->isEqual, rhs->Loop, ctx); - if (*newrhs == *this) - return shared_from_this(); - if (!isa(rhs->node) && isa(sub1)) { - return andB(newrhs, ctx); - } - } - } - } - - if (rhs->isEqual) { - if (rhs->Loop) - if (auto rep = evaluateAtLoopIter(node, ctx.SE, rhs->Loop, rhs->node)) - if (rep != node) { - auto newlhs = make_compare(rep, isEqual, Loop, ctx); - return newlhs->andB(rhs, ctx); - } - - // not loop -> node == 0 - if (!rhs->Loop) { - for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node), - ctx.SE.getMinusSCEV(rhs->node, node)}) { - // llvm::errs() << " maybe replace lhs2: " << *this << " rhs: " << - // *rhs - // << " sub1: " << *sub1 << "\n"; - auto newlhs = make_compare(sub1, isEqual, Loop, ctx); - if (*newlhs == *this) - return shared_from_this(); - if (!isa(node) && isa(sub1)) { - return newlhs->andB(rhs, ctx); - } - } - } - } - - if (!Loop && !rhs->Loop && isEqual == rhs->isEqual) { - if (node == ctx.SE.getNegativeSCEV(rhs->node)) - return shared_from_this(); - } - - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - if (vals.size() == 1) { - llvm::errs() << "this: " << *this << " rhs: " << *rhs << "\n"; - } - auto res = std::make_shared(Type::Intersect, vals); - // llvm::errs() << " naiive comp 
merge: " << *res << "\n"; - return res; - } - if (ty == Type::Intersect && rhs->ty == Type::Intersect) { - auto tmp = shared_from_this(); - for (const auto &v : rhs->values) { - auto tmp2 = tmp->andB(v, ctx); - if (!tmp2) - return nullptr; - tmp = std::move(tmp2); - } - return tmp; - } - if (ty == Type::Intersect && rhs->ty == Type::Compare) { - SetTy vals; - // Force internal merging to do individual compares - bool foldedIn = false; - for (auto en : llvm::enumerate(values)) { - auto i = en.index(); - auto v = en.value(); - assert(v->ty != Type::Intersect); - assert(v->ty != Type::All); - assert(v->ty != Type::None); - assert(v->ty == Type::Compare || v->ty == Type::Union); - if (foldedIn) { - insert(vals, v); - continue; - } - // this is either a compare or a union - auto tmp = rhs->andB(v, ctx); - if (!tmp) - return nullptr; - switch (tmp->ty) { - case Type::Union: - case Type::All: - llvm_unreachable("Impossible"); - case Type::None: - return Constraints::none(); - case Type::Compare: - insert(vals, tmp); - foldedIn = true; - break; - // if intersected, these two were not foldable, try folding into later - case Type::Intersect: { - SetTy fuse; - insert(fuse, rhs); - insert(fuse, v); - - Constraints trivialFuse(Type::Intersect, fuse, false); - - // If this is not just making an intersect of the two operands, - // remerge. - if (trivialFuse != *tmp) { - InnerTy newlhs = Constraints::all(); - bool legal = true; - for (auto en2 : llvm::enumerate(values)) { - auto i2 = en2.index(); - auto v2 = en2.value(); - if (i2 == i) - continue; - auto newlhs2 = newlhs->andB(v2, ctx); - if (!newlhs2) { - legal = false; - break; - } - newlhs = std::move(newlhs2); - } - if (legal) { - return newlhs->andB(tmp, ctx); - } - } - insert(vals, v); - } - } - } - if (!foldedIn) { - insert(vals, rhs); - return std::make_shared(Type::Intersect, vals); - } else { - auto cur = Constraints::all(); - for (auto &iv : vals) { - auto cur2 = cur->andB(iv, ctx); - if (!cur2) - return nullptr; - cur = std::move(cur2); - } - return cur; - } - } - if ((ty == Type::Intersect || ty == Type::Compare) && - rhs->ty == Type::Union) { - SetTy unionVals = rhs->values; - bool changed = false; - SetTy ivVals; - if (ty == Type::Intersect) - ivVals = values; - else - insert(ivVals, shared_from_this()); - - ConstraintContext ctxd(ctx, shared_from_this(), rhs); - - for (const auto &iv : ivVals) { - SetTy nextunionVals; - bool midchanged = false; - for (auto &uv : unionVals) { - auto tmp = iv->andB(uv, ctxd); - if (!tmp) { - midchanged = false; - nextunionVals = unionVals; - break; - } - switch (tmp->ty) { - case Type::None: - case Type::Compare: - case Type::Union: - insert(nextunionVals, tmp); - changed |= tmp != uv; - break; - case Type::Intersect: { - SetTy fuse; - if (uv->ty == Type::Intersect) - fuse = uv->values; - else { - assert(uv->ty == Type::Compare); - insert(fuse, uv); - } - insert(fuse, iv); - - Constraints trivialFuse(Type::Intersect, fuse, false); - if (trivialFuse != *tmp) { - insert(nextunionVals, tmp); - midchanged = true; - break; - } - - insert(nextunionVals, uv); - break; - } - case Type::All: - llvm_unreachable("Impossible"); - } - } - if (midchanged) { - unionVals = nextunionVals; - changed = true; - } - } - - if (changed) { - auto cur = Constraints::none(); - for (auto uv : unionVals) { - cur = cur->orB(uv, ctxd); - if (!cur) - break; - } - - if (*cur != *rhs) - return andB(cur, ctx); - } - - SetTy vals = ivVals; - insert(vals, rhs); - return std::make_shared(Type::Intersect, vals); - } - // Handled above via 
symmetry - if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) { - return rhs->andB(shared_from_this(), ctx); - } - // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) - // and (c or e)) - if (ty == Type::Union && rhs->ty == Type::Union) { - if (*this == *rhs->notB(ctx)) { - return Constraints::none(); - } - SetTy intersection = intersect(values, rhs->values); - if (intersection.size() != 0) { - InnerTy other_lhs = remove(intersection); - InnerTy other_rhs = rhs->remove(intersection); - InnerTy remainder; - if (intersection.size() == 1) - remainder = *intersection.begin(); - else { - remainder = std::make_shared(Type::Union, intersection); - } - return remainder->orB(other_lhs->andB(other_rhs, ctx), ctx); - } - - bool changed = false; - SetTy lhsVals = values; - SetTy rhsVals = rhs->values; - - ConstraintContext ctxd(ctx, shared_from_this(), rhs); - - SetTy distributedVals; - for (const auto &l1 : lhsVals) { - bool subchanged = false; - SetTy subDistributedVals; - for (auto &r1 : rhsVals) { - auto tmp = l1->andB(r1, ctxd); - if (!tmp) { - subchanged = false; - break; - } - - if (l1->ty == Type::Intersect || r1->ty == Type::Intersect) { - subchanged = true; - insert(subDistributedVals, tmp); - } else { - - SetTy fuse; - insert(fuse, l1); - insert(fuse, r1); - assert(fuse.size() == 2); - Constraints trivialFuse(Type::Intersect, fuse); - if ((trivialFuse != *tmp) || distributedVals.count(tmp)) { - subchanged = true; - } - insert(subDistributedVals, tmp); - } - } - if (subchanged) { - for (auto sub : subDistributedVals) - insert(distributedVals, sub); - changed = true; - } else { - auto midand = l1->andB(rhs, ctxd); - if (!midand) { - changed = false; - break; - } - insert(distributedVals, midand); - } - } - - if (changed) { - auto cur = Constraints::none(); - bool legal = true; - for (auto &uv : distributedVals) { - auto cur2 = cur->orB(uv, ctxd); - if (!cur2) { - legal = false; - break; - } - cur = std::move(cur2); - } - if (legal) { - return cur; - } - } - - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - auto res = std::make_shared(Type::Intersect, vals); - return res; - } - llvm::errs() << " andB this: " << *this << " rhs: " << *rhs << "\n"; - llvm_unreachable("Illegal predicate state"); - } - // what this would be like when removing the following list of constraints - InnerTy remove(const SetTy &sub) const { - assert(ty == Type::Union || ty == Type::Intersect); - SetTy res = values; - set_subtract(res, sub); - // res.set_subtract(sub); - if (res.size() == 0) { - if (ty == Type::Union) - return Constraints::none(); - else - return Constraints::all(); - } else if (res.size() == 1) { - return *res.begin(); - } else { - return std::make_shared(ty, res); - } - } - SmallVector, 1> - allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, - const ConstraintContext &ctx, IRBuilder<> &B) const; -}; - -void dump(const Constraints &c) { c.dump(); } -void dump(std::shared_ptr c) { c->dump(); } - -bool ConstraintComparator::operator()( - std::shared_ptr lhs, - std::shared_ptr rhs) const { - return *lhs < *rhs; -} - -raw_ostream &operator<<(raw_ostream &os, const Constraints &c) { - switch (c.ty) { - case Constraints::Type::All: - return os << "All"; - case Constraints::Type::None: - return os << "None"; - case Constraints::Type::Union: { - os << "(Union "; - for (auto v : c.values) - os << *v << ", "; - os << ")"; - return os; - } - case Constraints::Type::Intersect: { - os << "(Intersect "; - for (auto v : c.values) - os << *v << ", "; - os 
<< ")"; - return os; - } - case Constraints::Type::Compare: { - if (c.isEqual) - os << "(eq "; - else - os << "(ne "; - os << *c.node << ", L="; - if (c.Loop) - os << c.Loop->getHeader()->getName(); - else - os << "nullptr"; - return os << ")"; - } - } - return os; -} - -SmallVector, 1> -Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, - const ConstraintContext &ctx, IRBuilder<> &B) const { - switch (ty) { - case Type::None: - return {}; - case Type::All: - llvm::errs() << *this << "\n"; - llvm_unreachable("All not handled"); - case Type::Compare: { - Value *cond = ConstantInt::getTrue(T->getContext()); - if (ctx.loopToSolve != Loop) { - assert(ctx.loopToSolve); - Value *ivVal = Exp.expandCodeFor(node, T, IP); - Value *iv = nullptr; - if (Loop) { - iv = Loop->getCanonicalInductionVariable(); - assert(iv); - } else { - iv = ConstantInt::getNullValue(ivVal->getType()); - } - if (isEqual) - cond = B.CreateICmpEQ(ivVal, iv); - else - cond = B.CreateICmpNE(ivVal, iv); - return {std::make_pair((Value *)nullptr, cond)}; - } - if (isEqual) { - return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)}; - } - EmitFailure("NoSparsification", IP->getDebugLoc(), IP, - "Negated solution not handled: ", *this); - assert(0); - return {}; - } - case Type::Union: { - SmallVector, 1> vals; - for (auto v : values) - for (auto sol : v->allSolutions(Exp, T, IP, ctx, B)) - vals.push_back(sol); - return vals; - } - case Type::Intersect: { - { - SmallVector vals(values.begin(), values.end()); - ssize_t unionidx = -1; - for (unsigned i = 0; i < vals.size(); i++) { - if (vals[i]->ty == Type::Union) { - unionidx = i; - bool allne = true; - for (auto &v : vals[i]->values) { - if (v->ty != Type::Compare || v->isEqual) { - allne = false; - break; - } - } - if (allne) - break; - } - } - if (unionidx != -1) { - auto others = Constraints::all(); - for (unsigned j = 0; j < vals.size(); j++) - if (unionidx != j) - others = others->andB(vals[j], ctx); - SmallVector, 1> resvals; - for (auto &v : vals[unionidx]->values) { - auto tmp = v->andB(others, ctx); - for (const auto &sol : tmp->allSolutions(Exp, T, IP, ctx, B)) - resvals.push_back(sol); - } - return resvals; - } - } - Value *solVal = nullptr; - Value *cond = ConstantInt::getTrue(T->getContext()); - for (auto v : values) { - auto sols = v->allSolutions(Exp, T, IP, ctx, B); - if (sols.size() != 1) { - llvm::errs() << *this << "\n"; - for (auto s : sols) - if (s.first) - llvm::errs() << " + sol: " << *s.first << " " << *s.second << "\n"; - else - llvm::errs() << " + sol: " << s.first << " " << *s.second << "\n"; - llvm::errs() << " v: " << *v << " this: " << *this << "\n"; - llvm_unreachable("Intersect not handled (solsize>1)"); - } - auto sol = sols[0]; - if (sol.first) { - if (solVal != nullptr) { - llvm::errs() << *this << "\n"; - llvm::errs() << " prevsolVal: " << *solVal << "\n"; - llvm_unreachable("Intersect not handled (prevsolval)"); - } - assert(solVal == nullptr); - solVal = sol.first; - } - cond = B.CreateAnd(cond, sol.second); - } - return {std::make_pair(solVal, cond)}; - } - } - return {}; -} - -constexpr bool SparseDebug = false; -std::shared_ptr -getSparseConditions(bool &legal, Value *val, - std::shared_ptr defaultFloat, - Instruction *scope, const ConstraintContext &ctx) { - if (auto I = dyn_cast(val)) { - // Binary `and` is a bit-wise `umin`. 
- if (I->getOpcode() == Instruction::And) { - auto lhs = getSparseConditions(legal, I->getOperand(0), - Constraints::all(), I, ctx); - auto rhs = getSparseConditions(legal, I->getOperand(1), - Constraints::all(), I, ctx); - auto res = lhs->andB(rhs, ctx); - assert(res); - assert(ctx.seen.size() == 0); - if (SparseDebug) { - llvm::errs() << " getSparse(and, " << *I << "), lhs(" - << *I->getOperand(0) << ") = " << *lhs << "\n"; - llvm::errs() << " getSparse(and, " << *I << "), rhs(" - << *I->getOperand(1) << ") = " << *rhs << "\n"; - llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n"; - } - return res; - } - - // Binary `or` is a bit-wise `umax`. - if (I->getOpcode() == Instruction::Or) { - auto lhs = getSparseConditions(legal, I->getOperand(0), - Constraints::none(), I, ctx); - auto rhs = getSparseConditions(legal, I->getOperand(1), - Constraints::none(), I, ctx); - auto res = lhs->orB(rhs, ctx); - if (SparseDebug) { - llvm::errs() << " getSparse(or, " << *I << "), lhs(" - << *I->getOperand(0) << ") = " << *lhs << "\n"; - llvm::errs() << " getSparse(or, " << *I << "), rhs(" - << *I->getOperand(1) << ") = " << *rhs << "\n"; - llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n"; - } - return res; - } - - if (I->getOpcode() == Instruction::Xor) { - for (int i = 0; i < 2; i++) { - if (auto C = dyn_cast(I->getOperand(i))) - if (C->isOne()) { - auto pres = - getSparseConditions(legal, I->getOperand(1 - i), - defaultFloat->notB(ctx), scope, ctx); - auto res = pres->notB(ctx); - if (SparseDebug) { - llvm::errs() << " getSparse(not, " << *I << "), prev (" - << *I->getOperand(0) << ") = " << *pres << "\n"; - llvm::errs() << " getSparse(not, " << *I << ") = " << *res - << "\n"; - } - return res; - } - } - } - - if (auto icmp = dyn_cast(I)) { - auto L = ctx.loopToSolve; - auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L); - auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L); - if (SparseDebug) { - llvm::errs() << " lhs: " << *lhs << "\n"; - llvm::errs() << " rhs: " << *rhs << "\n"; - } - - auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs); - - if (icmp->getPredicate() == ICmpInst::ICMP_EQ || - icmp->getPredicate() == ICmpInst::ICMP_NE) { - if (auto add = dyn_cast(sub1)) { - if (add->isAffine()) { - // 0 === A + B * inc -> -A / B = inc - auto A = add->getStart(); - if (auto B = - dyn_cast(add->getStepRecurrence(ctx.SE))) { - - auto MA = A; - if (B->getAPInt().isNegative()) - B = cast(ctx.SE.getNegativeSCEV(B)); - else - MA = ctx.SE.getNegativeSCEV(A); - auto div = ctx.SE.getUDivExpr(MA, B); - auto div_e = ctx.SE.getUDivExactExpr(MA, B); - if (div == div_e) { - auto res = Constraints::make_compare( - div, icmp->getPredicate() == ICmpInst::ICMP_EQ, - add->getLoop(), ctx); - if (SparseDebug) { - llvm::errs() - << " getSparse(icmp, " << *I << ") = " << *res << "\n"; - } - return res; - } - } - } - } - if (cannotDependOnLoopIV(sub1, ctx.loopToSolve)) { - auto res = Constraints::make_compare( - sub1, icmp->getPredicate() == ICmpInst::ICMP_EQ, nullptr, ctx); - llvm::errs() << " getSparse(icmp_noloop, " << *I << ") = " << *res - << "\n"; - return res; - } - } - if (scope) - EmitWarning("NoSparsification", *I, - " No sparsification: not sparse solvable(icmp): ", *I, - " via ", *sub1); - if (SparseDebug) { - llvm::errs() << " getSparse(icmp_dflt, " << *I - << ") = " << *defaultFloat << "\n"; - } - return defaultFloat; - } - - // cmp x, 1.0 -> false/true - if (auto fcmp = dyn_cast(I)) { - auto res = defaultFloat; - if (SparseDebug) { - llvm::errs() << " getSparse(fcmp, " << *I << 
") = " << *res << "\n"; - } - return res; - - if (fcmp->getPredicate() == CmpInst::FCMP_OEQ || - fcmp->getPredicate() == CmpInst::FCMP_UEQ) { - return Constraints::all(); - } else if (fcmp->getPredicate() == CmpInst::FCMP_ONE || - fcmp->getPredicate() == CmpInst::FCMP_UNE) { - return Constraints::none(); - } - } - } - - if (scope) { - EmitFailure("NoSparsification", scope->getDebugLoc(), scope, - " No sparsification: not sparse solvable: ", *val); - } - legal = false; - return defaultFloat; -} - -Constraints::InnerTy Constraints::make_compare(const SCEV *v, bool isEqual, - const llvm::Loop *Loop, - const ConstraintContext &ctx) { - if (!Loop) { - assert(!isa(v)); - SmallVector noassumption; - ConstraintContext ctx2(ctx.SE, ctx.loopToSolve, noassumption, ctx.DT); - for (auto I : ctx.Assumptions) { - bool legal = true; - auto parsedCond = getSparseConditions(legal, I->getOperand(0), - Constraints::none(), nullptr, ctx2); - bool dominates = ctx.DT.dominates(I, ctx.loopToSolve->getHeader()); - if (legal && dominates) { - if (parsedCond->ty == Type::Compare && !parsedCond->Loop) { - if (parsedCond->node == v || - parsedCond->node == ctx.SE.getNegativeSCEV(v)) { - InnerTy res; - if (parsedCond->isEqual == isEqual) - res = Constraints::all(); - else - res = Constraints::none(); - return res; - } - } - } - } - } - // cannot have negative loop canonical induction var - if (Loop) - if (auto C = dyn_cast(v)) - if (C->getAPInt().isNegative()) { - if (isEqual) - return Constraints::none(); - else - return Constraints::all(); - } - return InnerTy(new Constraints(v, isEqual, Loop, false)); -} - -void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, - SetVector &toDenseBlocks) { - - auto &DT = FAM.getResult(F); - auto &SE = FAM.getResult(F); - auto &LI = FAM.getResult(F); - auto &DL = F.getParent()->getDataLayout(); - - QueueType Q(DT, LI); - { - llvm::SetVector todoBlocks; - for (auto b : toDenseBlocks) { - auto L = LI.getLoopFor(b); - if (L) { - for (auto B : L->getBlocks()) - todoBlocks.insert(B); - } - } - for (auto BB : todoBlocks) - for (auto &I : *BB) - if (!I.getType()->isVoidTy()) { - Q.insert(&I); - assert(Q.contains(&I)); - } - } - - // llvm::errs() << " pre fix inner: " << F << "\n"; - - // Full simplification - while (!Q.empty()) { - auto cur = Q.pop_back_val(); - /* - std::set prev; - for (auto v : Q) - prev.insert(v); - // llvm::errs() << "\n\n\n\n" << F << "\n"; - llvm::errs() << "cur: " << *cur << "\n"; - */ - auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL); - (void)changed; - /* - if (changed) { - llvm::errs() << "changed: " << *changed << "\n"; - - for (auto I : Q) - if (!prev.count(I)) - llvm::errs() << " + " << *I << "\n"; - // llvm::errs() << F << "\n\n"; - } - */ - } - - // llvm::errs() << " post fix inner " << F << "\n"; - - SmallVector, 1> sparseBlocks; - bool legalToSparse = true; - for (auto &B : F) - if (auto br = dyn_cast(B.getTerminator())) - if (br->isConditional()) - for (int bidx = 0; bidx < 2; bidx++) - if (auto uncond_br = - dyn_cast(br->getSuccessor(bidx)->getTerminator())) - if (!uncond_br->isConditional()) - if (uncond_br->getSuccessor(0) == br->getSuccessor(1 - bidx)) { - auto blk = br->getSuccessor(bidx); - int countSparse = 0; - for (auto &I : *blk) { - if (auto CI = dyn_cast(&I)) { - if (auto F = CI->getCalledFunction()) { - if (F->hasFnAttribute("enzyme_sparse_accumulate")) { - countSparse++; - } - } - } - } - if (countSparse == 0) - continue; - if (countSparse > 1) { - legalToSparse = false; - EmitFailure( - "NoSparsification", 
br->getDebugLoc(), br, "F: ", F, - "\nMultiple distinct sparse stores in same block: ", - *blk); - break; - } - - for (auto &I : *blk) { - if (auto CI = dyn_cast(&I)) { - if (auto F = CI->getCalledFunction()) { - if (F->hasFnAttribute("enzyme_sparse_accumulate")) { - continue; - } - } - if (isReadOnly(CI)) - continue; - } - if (!I.mayWriteToMemory()) - continue; - - legalToSparse = false; - EmitFailure( - "NoSparsification", br->getDebugLoc(), br, "F: ", F, - "\nIllegal writing instruction in sparse block: ", I); - break; - } - - if (!legalToSparse) { - break; - } - - auto L = LI.getLoopFor(blk); - if (!L) { - legalToSparse = false; - EmitFailure("NoSparsification", br->getDebugLoc(), br, - "F: ", F, "\nCould not find loop for: ", *blk); - break; - } - auto idx = L->getCanonicalInductionVariable(); - if (!idx) { - legalToSparse = false; - EmitFailure("NoSparsification", br->getDebugLoc(), br, - "F: ", F, "\nL:", *L, - "\nCould not find loop index: ", *L->getHeader()); - break; - } - assert(idx); - auto preheader = L->getLoopPreheader(); - if (!preheader) { - legalToSparse = false; - EmitFailure("NoSparsification", br->getDebugLoc(), br, - "F: ", F, "\nL:", *L, - "\nCould not find loop preheader"); - break; - } - sparseBlocks.emplace_back(blk, br); - } - - if (!legalToSparse) { - return; - } - - // block, bound, scev for indexset - std::map, - SmallVector>, - 1>>> - forSparsification; - - SmallVector Assumptions; - for (auto &BB : F) - for (auto &I : BB) - if (auto II = dyn_cast(&I)) - if (II->getIntrinsicID() == Intrinsic::assume) - Assumptions.push_back(II); - - bool sawError = false; - - for (auto [blk, br] : sparseBlocks) { - auto L = LI.getLoopFor(blk); - assert(L); - auto idx = L->getCanonicalInductionVariable(); - assert(idx); - auto preheader = L->getLoopPreheader(); - assert(preheader); - - // default is condition avoids sparse, negated is condition goes - // to sparse - auto cond = br->getCondition(); - bool negated = br->getSuccessor(0) == blk; - - bool legal = true; - // Whether the i1 value does not contain any icmp's - std::function onlyDataDependentValues = [&](Value *val) { - auto I = cast(val); - if (I->getOpcode() == Instruction::Or) { - return onlyDataDependentValues(I->getOperand(0)) && - onlyDataDependentValues(I->getOperand(1)); - } - if (I->getOpcode() == Instruction::And) { - return onlyDataDependentValues(I->getOperand(0)) && - onlyDataDependentValues(I->getOperand(1)); - } - if (isa(I)) - return true; - if (isa(I)) - return false; - EmitFailure("NoSparsification", I->getDebugLoc(), I, - " No sparsification: bad datadepedent values check: ", *I); - legal = false; - return true; - }; - - // Simplify variable val which is known to branch away from the - // actual store (if not negated) or to the store (if negated) - // if! negated the result may become more false if negated the - // result may become more true - - // - - // default is condition avoids sparse, negated is condition goes - // to sparse - Instruction *context = - isa(cond) ? cast(cond) : idx; - ConstraintContext cctx(SE, L, Assumptions, DT); - auto solutions = getSparseConditions( - legal, cond, negated ? 
Constraints::all() : Constraints::none(), - context, cctx); - // llvm::errs() << " solutions pre negate: " << *solutions << "\n"; - if (!negated) { - solutions = solutions->notB(cctx); - } - // llvm::errs() << " solutions post negate: " << *solutions << "\n"; - if (!legal) { - sawError = true; - continue; - } - - if (solutions == Constraints::none() || solutions == Constraints::all()) { - EmitFailure( - "NoSparsification", context->getDebugLoc(), context, "F: ", F, - "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, - "\n No sparsification: not sparse solvable(nosoltn): solutions:", - *solutions); - sawError = true; - } - // llvm::errs() << " found solvable solutions " << *solutions << "\n"; - - if (forSparsification.count(L) == 0) { - { - IRBuilder<> PB(preheader->getTerminator()); - forSparsification[L].first = - std::make_pair(PB.CreatePHI(idx->getType(), 0, "ph.idx"), - PB.CreatePHI(idx->getType(), 0, "loop.idx")); - } - - Value *LoopCount = nullptr; - - IRBuilder<> B(L->getHeader()->getFirstNonPHI()); - { - SCEVExpander Exp(SE, DL, "sparseenzyme"); - auto LoopCountS = SE.getBackedgeTakenCount(L); - LoopCount = B.CreateAdd( - ConstantInt::get(idx->getType(), 1), - Exp.expandCodeFor(LoopCountS, idx->getType(), &blk->front())); - } - Value *inbounds = B.CreateAnd( - B.CreateICmpSLT(idx, LoopCount), - B.CreateICmpSGE(idx, ConstantInt::get(idx->getType(), 0))); - Value *args[] = {inbounds, forSparsification[L].first.second}; - B.CreateCall(F.getParent()->getOrInsertFunction( - "enzyme.sparse.inbounds", B.getVoidTy(), - inbounds->getType(), idx->getType()), - args); - } - - IRBuilder<> B(br); - B.SetInsertPoint(br); - auto nidx = B.CreateICmpEQ( - forSparsification[L].first.first, - ConstantInt::get(idx->getType(), forSparsification[L].second.size())); - // TODO check direction - if (!negated) - nidx = B.CreateNot(nidx); - - br->setCondition(nidx); - forSparsification[L].second.emplace_back(blk, solutions); - } - - if (sawError) { - for (auto &pair : forSparsification) { - for (auto PN : {pair.second.first.first, pair.second.first.second}) { - PN->replaceAllUsesWith(UndefValue::get(PN->getType())); - PN->eraseFromParent(); - } - } - if (llvm::verifyFunction(F, &llvm::errs())) { - llvm::errs() << F << "\n"; - report_fatal_error("function failed verification (6)"); - } - return; - } - - if (forSparsification.size() == 0) { - auto context = &F.getEntryBlock().front(); - EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, - "\n Found no stores for sparsification"); - return; - } - - for (const auto &pair : forSparsification) { - auto L = pair.first; - auto [PN, inductPN] = pair.second.first; - - auto ph = L->getLoopPreheader(); -#if LLVM_VERSION_MAJOR >= 20 - CodeExtractor ext(L->getBlocks(), &DT); -#else - CodeExtractor ext(DT, *L); -#endif - CodeExtractorAnalysisCache cache(F); - SetVector Inputs, Outputs; - auto F2 = ext.extractCodeRegion(cache, Inputs, Outputs); - assert(F2); - F2->addFnAttr(Attribute::AlwaysInline); - - for (auto U : F2->users()) - cast(U)->eraseFromParent(); - - ssize_t induct_idx = -1; - ssize_t off_idx = -1; - for (auto en : llvm::enumerate(Inputs)) { - if (en.value() == inductPN) - induct_idx = en.index(); - if (en.value() == PN) - off_idx = en.index(); - } - assert(induct_idx != -1); - assert(off_idx != -1); - - auto L2 = LI.getLoopFor(F2->getEntryBlock().getSingleSuccessor()); - auto new_idx = F2->getArg(induct_idx); - auto L2Header = L2->getHeader(); - auto new_lidx = L2->getCanonicalInductionVariable(); - - auto idxty = 
new_idx->getType(); - - auto new_pn = F2->getArg(off_idx); - // Find all sparse accumulates we weren't meant to handle - { - SmallVector toErase; - // First delete any accumulates in sub loops - for (auto SL : L2->getSubLoops()) - for (auto B : SL->getBlocks()) - for (auto &I : *B) - if (auto CI = dyn_cast(&I)) - if (auto F = CI->getCalledFunction()) { - if (F->hasFnAttribute("enzyme_sparse_accumulate")) { - toErase.push_back(CI); - continue; - } - } - for (auto C : toErase) - C->eraseFromParent(); - toErase.clear(); - // Next delete any accumulates not in latchany loops - for (auto B : L2->getBlocks()) { - bool guarded = false; - if (auto P = B->getSinglePredecessor()) - if (auto S = B->getSingleSuccessor()) - if (auto BI = dyn_cast(P->getTerminator())) - if (BI->isConditional()) - for (size_t i = 0; i < 2; i++) - if (BI->getSuccessor(i) == B && - BI->getSuccessor(1 - i) == S) { - auto val = BI->getCondition(); - if (auto xori = dyn_cast(val)) - if (xori->getOpcode() == Instruction::Xor) - val = xori->getOperand(0); - if (auto cmp = dyn_cast(val)) - if (cmp->getOperand(0) == new_pn || - cmp->getOperand(1) == new_pn) - guarded = true; - } - if (guarded) - continue; - for (auto &I : *B) - if (auto CI = dyn_cast(&I)) - if (auto F = CI->getCalledFunction()) { - if (F->hasFnAttribute("enzyme_sparse_accumulate")) { - toErase.push_back(CI); - continue; - } - } - } - for (auto C : toErase) - C->eraseFromParent(); - toErase.clear(); - } - - auto guard = L2->getLoopLatch()->getTerminator(); - assert(guard); - IRBuilder<> G(guard); - G.CreateRetVoid(); - guard->eraseFromParent(); - new_lidx->replaceAllUsesWith(new_idx); - new_lidx->eraseFromParent(); - - auto phterm = ph->getTerminator(); - IRBuilder<> B(phterm); - - // We extracted code, reset analyses. - /* - DT.reset(); - SE.forgetAllLoops(); - */ - - for (auto en : llvm::enumerate(pair.second.second)) { - auto off = en.index(); - auto &solutions = en.value().second; - ConstraintContext ctx(SE, L, Assumptions, DT); - SCEVExpander Exp(SE, DL, "sparseenzyme", /*preservelcssa*/ false); - auto sols = solutions->allSolutions(Exp, idxty, phterm, ctx, B); - SmallVector prevSols; - for (auto [sol, condition] : sols) { - SmallVector args(Inputs.begin(), Inputs.end()); - args[off_idx] = ConstantInt::get(idxty, off); - args[induct_idx] = sol; - for (auto sol2 : prevSols) - condition = B.CreateAnd(condition, B.CreateICmpNE(sol, sol2)); - prevSols.push_back(sol); - auto BB = B.GetInsertBlock(); - auto B2 = BB->splitBasicBlock(B.GetInsertPoint(), "poststore"); - B2->moveAfter(BB); - BB->getTerminator()->eraseFromParent(); - B.SetInsertPoint(BB); - auto callB = BasicBlock::Create(BB->getContext(), "tostore", - BB->getParent(), B2); - B.CreateCondBr(condition, callB, B2); - B.SetInsertPoint(callB); - B.CreateCall(F2, args); - B.CreateBr(B2); - B.SetInsertPoint(B2->getTerminator()); - } - auto blk = en.value().first; - auto term = blk->getTerminator(); - IRBuilder<> B2(blk); - B2.CreateRetVoid(); - term->eraseFromParent(); - } - - PN->eraseFromParent(); - - for (auto &I : *L2Header) { - auto boundsCheck = dyn_cast(&I); - if (!boundsCheck) - continue; - auto BF = boundsCheck->getCalledFunction(); - if (!BF) - continue; - if (BF->getName() != "enzyme.sparse.inbounds") - continue; - - auto boundsCond = boundsCheck->getArgOperand(0); - - auto next = L2Header->splitBasicBlock(boundsCheck); - - auto exit = BasicBlock::Create(F2->getContext(), "bounds.exit", F2, - L2Header->getNextNode()); - { - IRBuilder B(exit); - B.CreateRetVoid(); - } - 
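Every `iv == expr` solution expanded by this transform originates from the affine rewrite used in the constraint code above: a comparison `A + B*iv == 0` is turned into `iv == -A/B`, and the rewrite is only accepted when the unsigned division is exact (the removed code compares getUDivExpr against getUDivExactExpr before substituting). The same guard in plain integer arithmetic, as a hedged sketch independent of SCEV (the function name is assumed):

    #include <cstdint>
    #include <iostream>
    #include <optional>

    // Solve A + B*iv == 0 for a non-negative induction variable iv.
    // Mirrors the removed logic only in spirit: normalize the sign of the
    // step, then require the division to be exact, otherwise give up.
    std::optional<uint64_t> solveAffine(int64_t A, int64_t B) {
      if (B == 0)
        return std::nullopt;          // not actually affine in iv
      if (B < 0) { A = -A; B = -B; }  // normalize so the step is positive
      int64_t minusA = -A;
      if (minusA < 0)
        return std::nullopt;          // canonical IVs are never negative
      if (minusA % B != 0)
        return std::nullopt;          // inexact division: no integer solution
      return static_cast<uint64_t>(minusA / B);
    }

    int main() {
      // 12 - 3*iv == 0  ->  iv == 4
      if (auto iv = solveAffine(12, -3))
        std::cout << "iv = " << *iv << "\n";
      // 7 + 2*iv == 0 has no non-negative integer solution.
      if (!solveAffine(7, 2))
        std::cout << "no solution\n";
    }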
L2Header->getTerminator()->eraseFromParent(); - - { - IRBuilder B(L2Header); - B.CreateCondBr(boundsCond, next, exit); - } - boundsCheck->eraseFromParent(); - inductPN->eraseFromParent(); - - break; - } - } - - for (auto &F2 : F.getParent()->functions()) { - if (startsWith(F2.getName(), "__enzyme_product")) { - SmallVector toErase; - for (llvm::User *I : F2.users()) { - auto CB = cast(I); - IRBuilder<> B(CB); - B.setFastMathFlags(getFast()); - Value *res = nullptr; - for (auto v : callOperands(CB)) { - if (res == nullptr) - res = v; - else { - res = B.CreateFMul(res, v); - } - } - CB->replaceAllUsesWith(res); - toErase.push_back(CB); - } - for (auto CB : toErase) - CB->eraseFromParent(); - } else if (startsWith(F2.getName(), "__enzyme_sum")) { - SmallVector toErase; - for (llvm::User *I : F2.users()) { - auto CB = cast(I); - IRBuilder<> B(CB); - B.setFastMathFlags(getFast()); - Value *res = nullptr; - for (auto v : callOperands(CB)) { - if (res == nullptr) - res = v; - else { - res = B.CreateFAdd(res, v); - } - } - CB->replaceAllUsesWith(res); - toErase.push_back(CB); - } - for (auto CB : toErase) - CB->eraseFromParent(); - } - } -} - -void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F, - const llvm::DataLayout &DL) { - auto load_fn = cast(getBaseObject(CI->getArgOperand(0))); - auto store_fn = cast(getBaseObject(CI->getArgOperand(1))); - size_t argstart = 2; - size_t num_args = CI->arg_size(); - SmallVector, 1> users; - - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - IntegerType *intTy = IntegerType::get(CI->getContext(), 64); - auto toInt = [&](IRBuilder<> &B, llvm::Value *V) { - if (auto PT = dyn_cast(V->getType())) { - if (PT->getAddressSpace() != 0) { -#if LLVM_VERSION_MAJOR < 17 - if (CI->getContext().supportsTypedPointers()) { - V = B.CreateAddrSpaceCast( - V, PointerType::getUnqual(PT->getPointerElementType())); - } else { - V = B.CreateAddrSpaceCast(V, - PointerType::getUnqual(PT->getContext())); - } -#else - V = B.CreateAddrSpaceCast(V, PointerType::getUnqual(PT->getContext())); -#endif - } - return B.CreatePtrToInt(V, intTy); - } - auto IT = cast(V->getType()); - if (IT == intTy) - return V; - return B.CreateZExtOrTrunc(V, intTy); - }; - SmallVector toErase; - - ValueToValueMapTy replacements; - replacements[CI] = Constant::getNullValue(CI->getType()); - Instruction *remaining = nullptr; - while (users.size()) { - auto pair = users.back(); - users.pop_back(); - auto U = pair.first; - auto val = pair.second; - if (replacements.count(U)) - continue; - - IRBuilder B(U); - if (auto CI = dyn_cast(U)) { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - auto rep = - B.CreateCast(CI->getOpcode(), replacements[val], CI->getDestTy()); - if (auto I = dyn_cast(rep)) - I->setDebugLoc(CI->getDebugLoc()); - replacements[CI] = rep; - continue; - } - if (auto SI = dyn_cast(U)) { - for (auto U : SI->users()) { - users.push_back(std::make_pair(cast(U), SI)); - } - auto tval = SI->getTrueValue(); - auto fval = SI->getFalseValue(); - auto rep = B.CreateSelect( - SI->getCondition(), - replacements.count(tval) ? (Value *)replacements[tval] : tval, - replacements.count(fval) ? 
(Value *)replacements[fval] : fval); - if (auto I = dyn_cast(rep)) - I->setDebugLoc(SI->getDebugLoc()); - replacements[SI] = rep; - continue; - } - /* - if (auto CI = dyn_cast(U)) { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - continue; - } - */ - if (auto CI = dyn_cast(U)) { - auto funcName = getFuncNameFromCall(CI); - if (funcName == "julia.pointer_from_objref") { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - auto *F = CI->getCalledOperand(); - - SmallVector args; - for (auto &arg : CI->args()) - args.push_back(replacements[arg]); - - auto FT = CI->getFunctionType(); - - auto cal = cast(B.CreateCall(FT, F, args)); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(CI->getDebugLoc()); - replacements[CI] = cal; - continue; - } - } - if (auto CI = dyn_cast(U)) { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - SmallVector inds; - bool allconst = true; - for (auto &ind : CI->indices()) { - if (!isa(ind)) { - allconst = false; - } - inds.push_back(ind); - } - Value *gep; - - if (inds.size() == 1) { - gep = ConstantInt::get( - intTy, (DL.getTypeSizeInBits(CI->getSourceElementType()) + 7) / 8); - gep = B.CreateMul(intTy == inds[0]->getType() - ? inds[0] - : B.CreateZExtOrTrunc(inds[0], intTy), - gep, "", true, true); - gep = B.CreateAdd(B.CreatePtrToInt(replacements[val], intTy), gep); - gep = B.CreateIntToPtr(gep, CI->getType()); - } else if (!allconst) { - gep = B.CreateGEP(CI->getSourceElementType(), replacements[val], inds); - if (auto ge = cast(gep)) - ge->setIsInBounds(CI->isInBounds()); - } else { - APInt ai(64, 0); - CI->accumulateConstantOffset(DL, ai); - gep = B.CreateIntToPtr(ConstantInt::get(intTy, ai), CI->getType()); - } - if (auto I = dyn_cast(gep)) - I->setDebugLoc(CI->getDebugLoc()); - replacements[CI] = gep; - continue; - } - if (auto LI = dyn_cast(U)) { - auto diff = toInt(B, replacements[LI->getPointerOperand()]); - SmallVector args; - args.push_back(diff); - for (size_t i = argstart; i < num_args; i++) - args.push_back(CI->getArgOperand(i)); - - if (load_fn->getFunctionType()->getNumParams() != args.size()) { - auto fnName = load_fn->getName(); - auto found_numargs = load_fn->getFunctionType()->getNumParams(); - auto expected_numargs = args.size(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect number of arguments to loader function ", - fnName, " expected ", expected_numargs, " found ", - found_numargs, " - ", *load_fn->getFunctionType()); - continue; - } else { - bool tocontinue = false; - for (size_t i = 0; i < args.size(); i++) { - if (load_fn->getFunctionType()->getParamType(i) != - args[i]->getType()) { - auto fnName = load_fn->getName(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect type of argument ", i, - " to loader function ", fnName, " expected ", - *args[i]->getType(), " found ", - load_fn->getFunctionType()->params()[i]); - tocontinue = true; - args[i] = UndefValue::get(args[i]->getType()); - } - } - if (tocontinue) - continue; - } - CallInst *call = B.CreateCall(load_fn, args); - call->setDebugLoc(LI->getDebugLoc()); - Value *tmp = call; - if (tmp->getType() != LI->getType()) { - if (CastInst::castIsValid(Instruction::BitCast, tmp, LI->getType())) - tmp = B.CreateBitCast(tmp, LI->getType()); - else { - auto fnName = load_fn->getName(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect return type of loader function ", fnName, - " expected ", *LI->getType(), " found ", - 
*call->getType()); - tmp = UndefValue::get(LI->getType()); - } - } - LI->replaceAllUsesWith(tmp); - - if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) { - InlineFunctionInfo IFI; - InlineFunction(*call, IFI); - } - toErase.push_back(LI); - continue; - } - if (auto SI = dyn_cast(U)) { - assert(SI->getValueOperand() != val); - auto diff = toInt(B, replacements[SI->getPointerOperand()]); - SmallVector args; - args.push_back(SI->getValueOperand()); - auto sty = store_fn->getFunctionType()->getParamType(0); - if (args[0]->getType() != store_fn->getFunctionType()->getParamType(0)) { - if (CastInst::castIsValid(Instruction::BitCast, args[0], sty)) - args[0] = B.CreateBitCast(args[0], sty); - else { - auto args0ty = args[0]->getType(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " first argument of store function must be the type of " - "the store found fn arg type ", - *sty, " expected ", *args0ty); - args[0] = UndefValue::get(sty); - } - } - args.push_back(diff); - for (size_t i = argstart; i < num_args; i++) - args.push_back(CI->getArgOperand(i)); - - if (store_fn->getFunctionType()->getNumParams() != args.size()) { - auto fnName = store_fn->getName(); - auto found_numargs = store_fn->getFunctionType()->getNumParams(); - auto expected_numargs = args.size(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect number of arguments to store function ", fnName, - " expected ", expected_numargs, " found ", found_numargs, - " - ", *store_fn->getFunctionType()); - continue; - } else { - bool tocontinue = false; - for (size_t i = 0; i < args.size(); i++) { - if (store_fn->getFunctionType()->getParamType(i) != - args[i]->getType()) { - auto fnName = store_fn->getName(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect type of argument ", i, - " to storeer function ", fnName, " expected ", - *args[i]->getType(), " found ", - store_fn->getFunctionType()->params()[i]); - tocontinue = true; - args[i] = UndefValue::get(args[i]->getType()); - } - } - if (tocontinue) - continue; - } - auto call = B.CreateCall(store_fn, args); - call->setDebugLoc(SI->getDebugLoc()); - if (store_fn->hasFnAttribute(Attribute::AlwaysInline)) { - InlineFunctionInfo IFI; - InlineFunction(*call, IFI); - } - toErase.push_back(SI); - continue; - } - remaining = U; - } - for (auto U : toErase) - U->eraseFromParent(); - - if (!remaining) { - CI->replaceAllUsesWith(Constant::getNullValue(CI->getType())); - CI->eraseFromParent(); - } else if (replaceAll) { - EmitFailure("IllegalSparse", remaining->getDebugLoc(), remaining, - " Illegal remaining use (", *remaining, ") of todense (", *CI, - ") in function ", *F); - } -} - -bool LowerSparsification(llvm::Function *F, bool replaceAll) { - auto &DL = F->getParent()->getDataLayout(); - bool changed = false; - SmallVector todo; - SetVector toDenseBlocks; - for (auto &BB : *F) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - if (getFuncNameFromCall(CI).contains("__enzyme_todense")) { - todo.push_back(CI); - toDenseBlocks.insert(&BB); - } - } - } - } - for (auto CI : todo) { - changed = true; - replaceToDense(CI, replaceAll, F, DL); - } - todo.clear(); - - if (changed && EnzymeAutoSparsity) { - PassBuilder PB; - LoopAnalysisManager LAM; - FunctionAnalysisManager FAM; - CGSCCAnalysisManager CGAM; - ModuleAnalysisManager MAM; - PB.registerModuleAnalyses(MAM); - PB.registerFunctionAnalyses(FAM); - PB.registerLoopAnalyses(LAM); - PB.registerCGSCCAnalyses(CGAM); - PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); - - 
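For reference while reading the surrounding pipeline: the removed replaceToDense above rewrites every load through a pointer produced by __enzyme_todense into a call to the user-supplied loader, and every store into a call to the user-supplied storer, passing the byte offset obtained via ptrtoint (offset first for loads; value first, then offset, for stores; any remaining operands of the __enzyme_todense call are appended). A schematic of the resulting calling pattern in ordinary C++ rather than IR, with invented callback names and no forwarded extra arguments:

    #include <cstdint>
    #include <cstdio>
    #include <map>

    // Illustrative stand-ins for the user-supplied callbacks that the
    // rewrite wires loads and stores into; a map plays the sparse backing.
    static std::map<int64_t, double> backing;

    double sparse_load(int64_t byteOffset) {
      auto it = backing.find(byteOffset);
      return it == backing.end() ? 0.0 : it->second;
    }

    void sparse_store(double value, int64_t byteOffset) {
      backing[byteOffset] = value;
    }

    int main() {
      // What the pass conceptually turns   dense[i] += 3.0;   into,
      // where `dense` came from __enzyme_todense: the load and the store
      // both address the same byte offset of the fake dense pointer.
      int64_t i = 5;
      int64_t off = i * (int64_t)sizeof(double);
      sparse_store(sparse_load(off) + 3.0, off);
      std::printf("dense[5] = %f\n", sparse_load(off)); // 3.000000
    }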
SimplifyCFGPass(SimplifyCFGOptions()).run(*F, FAM); - InstCombinePass().run(*F, FAM); - // required to make preheaders - LoopSimplifyPass().run(*F, FAM); - fixSparseIndices(*F, FAM, toDenseBlocks); - } - - for (auto &BB : *F) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - if (getFuncNameFromCall(CI).contains("__enzyme_post_sparse_todense")) { - todo.push_back(CI); - } - } - } - } - for (auto CI : todo) { - changed = true; - replaceToDense(CI, replaceAll, F, DL); - } - return changed; -} diff --git a/enzyme/Enzyme/FunctionUtils.h b/enzyme/Enzyme/FunctionUtils.h deleted file mode 100644 index 4d751c102ea9..000000000000 --- a/enzyme/Enzyme/FunctionUtils.h +++ /dev/null @@ -1,406 +0,0 @@ -//===- FunctionUtils.h - Declaration of function utilities ---------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares utilities on LLVM Functions that are used as part of the -// AD process. -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_FUNCTION_UTILS_H -#define ENZYME_FUNCTION_UTILS_H - -#include -#include - -#include - -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "Utils.h" - -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/LoopAnalysisManager.h" -#include "llvm/Analysis/TargetLibraryInfo.h" - -#include "llvm/IR/Function.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" - -#include "llvm/IR/Instructions.h" -#include "llvm/Transforms/Utils/ValueMapper.h" - -#include "llvm/ADT/STLExtras.h" - -//; - -extern "C" { -extern llvm::cl::opt EnzymeAlwaysInlineDiff; -} - -class PreProcessCache { -public: - PreProcessCache(); - PreProcessCache(PreProcessCache &) = delete; - // Using the default move constructor will botch the FAM/MAM proxy passes - // since now the new location of FAM/MAM will not be used. Therefore, use a - // custom move constructor and default initialize these, and move the - // cache/origin maps. 
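The comment above is the load-bearing constraint on PreProcessCache: the FAM/MAM proxy passes hold the addresses of these analysis managers, so a defaulted move would leave those proxies pointing into the moved-from object. A reduced illustration of the same pattern outside LLVM (all types below are placeholders):

    #include <map>
    #include <string>
    #include <utility>

    // Placeholder for an analysis manager whose address is captured
    // elsewhere (as the FAM/MAM proxies do) and must never be moved.
    struct AnalysisManagerLike {
      AnalysisManagerLike() = default;
      AnalysisManagerLike(const AnalysisManagerLike &) = delete;
      AnalysisManagerLike(AnalysisManagerLike &&) = delete;
    };

    struct CacheLike {
      AnalysisManagerLike FAM;          // re-created fresh on move
      AnalysisManagerLike MAM;          // re-created fresh on move
      std::map<std::string, int> cache; // plain data: safe to move

      CacheLike() = default;
      CacheLike(const CacheLike &) = delete;

      // Custom move: default-initialize the managers (so any
      // address-based registration is redone against the new object)
      // and move only the data maps.
      CacheLike(CacheLike &&prev) : CacheLike() {
        cache = std::move(prev.cache);
      }
    };

    int main() {
      CacheLike a;
      a.cache["f"] = 1;
      CacheLike b(std::move(a)); // managers rebuilt, cache carried over
      return b.cache.at("f") - 1;
    }

The constructor that follows in the removed header applies exactly this shape to the cache and CloneOrigin maps.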
- PreProcessCache(PreProcessCache &&prev) : PreProcessCache() { - cache = std::move(prev.cache); - CloneOrigin = std::move(prev.CloneOrigin); - }; - - llvm::LoopAnalysisManager LAM; - llvm::FunctionAnalysisManager FAM; - llvm::ModuleAnalysisManager MAM; - - std::map, llvm::Function *> cache; - std::map CloneOrigin; - - llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode); - - llvm::AAResults &getAAResultsFromFunction(llvm::Function *NewF); - - llvm::Function *CloneFunctionWithReturns( - DerivativeMode mode, unsigned width, llvm::Function *&F, - llvm::ValueToValueMapTy &ptrInputs, - llvm::ArrayRef constant_args, - llvm::SmallPtrSetImpl &constants, - llvm::SmallPtrSetImpl &nonconstant, - llvm::SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE returnType, const llvm::Twine &name, - llvm::ValueMap *VMapO, - bool diffeReturnArg, llvm::Type *additionalArg = nullptr); - - void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false); - void LowerAllocAddr(llvm::Function *NewF); - void AlwaysInline(llvm::Function *NewF); - void optimizeIntermediate(llvm::Function *F); - - void clear(); -}; - -class GradientUtils; - -static inline void -getExitBlocks(const llvm::Loop *L, - llvm::SmallPtrSetImpl &ExitBlocks) { - llvm::SmallVector PotentialExitBlocks; - L->getExitBlocks(PotentialExitBlocks); - for (auto a : PotentialExitBlocks) { - - llvm::SmallVector tocheck; - llvm::SmallPtrSet checked; - tocheck.push_back(a); - - bool isExit = false; - - while (tocheck.size()) { - auto foo = tocheck.back(); - tocheck.pop_back(); - if (checked.count(foo)) { - isExit = true; - goto exitblockcheck; - } - checked.insert(foo); - if (auto bi = llvm::dyn_cast(foo->getTerminator())) { - for (auto nb : bi->successors()) { - if (L->contains(nb)) - continue; - tocheck.push_back(nb); - } - } else if (llvm::isa(foo->getTerminator())) { - continue; - } else { - isExit = true; - goto exitblockcheck; - } - } - - exitblockcheck: - if (isExit) { - ExitBlocks.insert(a); - } - } -} - -static inline llvm::SmallVector -getLatches(const llvm::Loop *L, - const llvm::SmallPtrSetImpl &ExitBlocks) { - llvm::BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) { - llvm::errs() << *L->getHeader()->getParent() << "\n"; - llvm::errs() << *L->getHeader() << "\n"; - llvm::errs() << *L << "\n"; - } - assert(Preheader && "requires preheader"); - - // Find latch, defined as a (perhaps unique) block in loop that branches to - // exit block - llvm::SmallVector Latches; - for (llvm::BasicBlock *ExitBlock : ExitBlocks) { - for (llvm::BasicBlock *pred : llvm::predecessors(ExitBlock)) { - if (L->contains(pred)) { - if (std::find(Latches.begin(), Latches.end(), pred) != Latches.end()) - continue; - Latches.push_back(pred); - } - } - } - return Latches; -} - -// TODO note this doesn't go through [loop, unreachable], and we could get more -// performance by doing this can consider doing some domtree magic potentially -static inline llvm::SmallPtrSet -getGuaranteedUnreachable(llvm::Function *F) { - llvm::SmallPtrSet knownUnreachables; - if (F->empty()) - return knownUnreachables; - std::deque todo; - for (auto &BB : *F) { - todo.push_back(&BB); - } - - while (!todo.empty()) { - llvm::BasicBlock *next = todo.front(); - todo.pop_front(); - - if (knownUnreachables.find(next) != knownUnreachables.end()) - continue; - - if (llvm::isa(next->getTerminator())) - continue; - - if (llvm::isa(next->getTerminator())) { - knownUnreachables.insert(next); - for (llvm::BasicBlock *Pred : predecessors(next)) { - 
todo.push_back(Pred); - } - continue; - } - - // Assume resumes don't happen - // TODO consider EH - if (llvm::isa(next->getTerminator())) { - knownUnreachables.insert(next); - for (llvm::BasicBlock *Pred : predecessors(next)) { - todo.push_back(Pred); - } - continue; - } - - bool unreachable = true; - for (llvm::BasicBlock *Succ : llvm::successors(next)) { - if (knownUnreachables.find(Succ) == knownUnreachables.end()) { - unreachable = false; - break; - } - } - - if (!unreachable) - continue; - knownUnreachables.insert(next); - for (llvm::BasicBlock *Pred : llvm::predecessors(next)) { - todo.push_back(Pred); - } - continue; - } - - return knownUnreachables; -} - -enum class UseReq { - Need, - Recur, - Cached, -}; -static inline void calculateUnusedValues( - const llvm::Function &oldFunc, - llvm::SmallPtrSetImpl &unnecessaryValues, - llvm::SmallPtrSetImpl &unnecessaryInstructions, - bool returnValue, llvm::function_ref valneeded, - llvm::function_ref instneeded, - llvm::function_ref - useneeded) { - - std::deque todo; - - for (const llvm::BasicBlock &BB : oldFunc) { - if (auto ri = llvm::dyn_cast(BB.getTerminator())) { - if (!returnValue) { - unnecessaryInstructions.insert(ri); - } - unnecessaryValues.insert(ri); - } - for (auto &inst : BB) { - if (&inst == BB.getTerminator()) - continue; - todo.push_back(&inst); - } - } - - while (!todo.empty()) { - auto inst = todo.front(); - todo.pop_front(); - - if (unnecessaryInstructions.count(inst)) { - assert(unnecessaryValues.count(inst)); - continue; - } - - if (!unnecessaryValues.count(inst)) { - - if (valneeded(inst)) { - continue; - } - - bool necessaryUse = false; - - llvm::SmallPtrSet seen; - std::deque users; - - for (auto user_dtx : inst->users()) { - if (auto cst = llvm::dyn_cast(user_dtx)) { - if (useneeded(cst, inst)) - users.push_back(cst); - } - } - - while (users.size()) { - auto val = users.front(); - users.pop_front(); - - if (seen.count(val)) - continue; - seen.insert(val); - - if (unnecessaryInstructions.count(val)) - continue; - - switch (instneeded(val)) { - case UseReq::Need: - necessaryUse = true; - break; - case UseReq::Recur: - for (auto user_dtx : val->users()) { - if (auto cst = llvm::dyn_cast(user_dtx)) { - if (useneeded(cst, val)) - users.push_back(cst); - } - } - break; - case UseReq::Cached: - break; - } - if (necessaryUse) - break; - } - - if (necessaryUse) - continue; - - unnecessaryValues.insert(inst); - - for (auto user : inst->users()) { - if (auto usedinst = llvm::dyn_cast(user)) - todo.push_back(usedinst); - } - } - - if (instneeded(inst) == UseReq::Need) - continue; - - unnecessaryInstructions.insert(inst); - - for (auto &operand : inst->operands()) { - if (auto usedinst = llvm::dyn_cast(operand.get())) { - todo.push_back(usedinst); - } - } - } - - if (false && endsWith(oldFunc.getName(), "subfn")) { - llvm::errs() << "Prepping values for: " << oldFunc.getName() - << " returnValue: " << returnValue << "\n"; - for (auto v : unnecessaryInstructions) { - llvm::errs() << "+ unnecessaryInstructions: " << *v << "\n"; - } - for (auto v : unnecessaryValues) { - llvm::errs() << "+ unnecessaryValues: " << *v << "\n"; - } - llvm::errs() << "\n"; - } -} - -static inline void calculateUnusedStores( - const llvm::Function &oldFunc, - llvm::SmallPtrSetImpl &unnecessaryStores, - llvm::function_ref needStore) { - - std::deque todo; - - for (const llvm::BasicBlock &BB : oldFunc) { - for (auto &inst : BB) { - if (&inst == BB.getTerminator()) - continue; - todo.push_back(&inst); - } - } - - while (!todo.empty()) { - auto inst = 
todo.front(); - todo.pop_front(); - - if (unnecessaryStores.count(inst)) { - continue; - } - - if (needStore(inst)) - continue; - - unnecessaryStores.insert(inst); - } -} - -void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep, - bool legal); - -void ReplaceFunctionImplementation(llvm::Module &M); - -/// Is the use of value val as an argument of call CI potentially captured -bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val); - -llvm::FunctionType *getFunctionTypeForClone( - llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, - llvm::Type *additionalArg, llvm::ArrayRef constant_args, - bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType); - -/// Lower __enzyme_todense, returning if changed. -bool LowerSparsification(llvm::Function *F, bool replaceAll = true); - -#endif diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp deleted file mode 100644 index f97ded97719c..000000000000 --- a/enzyme/Enzyme/GradientUtils.cpp +++ /dev/null @@ -1,9704 +0,0 @@ -//===- GradientUtils.cpp - Helper class and utilities for AD ---------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file define two helper classes GradientUtils and subclass -// DiffeGradientUtils. These classes contain utilities for managing the cache, -// recomputing statements, and in the case of DiffeGradientUtils, managing -// adjoint values and shadow pointers. 
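Stepping back before the GradientUtils implementation: the calculateUnusedValues and calculateUnusedStores helpers deleted just above share one backward-worklist shape, namely keep a value only if it, or some transitive user of it, is independently required. A deliberately reduced sketch of that shape (plain structs stand in for llvm::Instruction; the removed helper additionally distinguishes Need/Recur/Cached uses and tracks unnecessary instructions separately from unnecessary values):

    #include <cstdio>
    #include <deque>
    #include <set>
    #include <vector>

    struct Node {
      int id;
      bool neededForItself = false; // e.g. a store that must be preserved
      std::vector<Node *> users;    // edges from a value to its users
    };

    std::set<Node *> computeUnnecessary(std::vector<Node *> &all) {
      std::set<Node *> unnecessary;
      for (Node *n : all) {
        if (n->neededForItself)
          continue;
        // Walk the transitive users looking for one that must be kept.
        std::deque<Node *> todo(n->users.begin(), n->users.end());
        std::set<Node *> seen;
        bool needed = false;
        while (!todo.empty() && !needed) {
          Node *u = todo.front();
          todo.pop_front();
          if (!seen.insert(u).second)
            continue;
          if (u->neededForItself)
            needed = true;
          else
            todo.insert(todo.end(), u->users.begin(), u->users.end());
        }
        if (!needed)
          unnecessary.insert(n);
      }
      return unnecessary;
    }

    int main() {
      Node a{0}, b{1}, c{2}, d{3};
      a.users = {&b};           // a feeds b
      b.users = {&c};           // b feeds c
      c.neededForItself = true; // c must be kept, so a and b are too
      // d has no users and is not needed for itself: it is unnecessary.
      std::vector<Node *> all = {&a, &b, &c, &d};
      std::printf("unnecessary: %zu\n", computeUnnecessary(all).size()); // 1
    }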
-// -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include - -#include "GradientUtils.h" -#include "MustExitScalarEvolution.h" -#include "Utils.h" - -#include "DifferentialUseAnalysis.h" -#include "LibraryFuncs.h" -#include "TypeAnalysis/TBAA.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" - -#include "llvm/Support/AMDGPUMetadata.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/TimeProfiler.h" - -#if LLVM_VERSION_MAJOR >= 14 -#define addAttribute addAttributeAtIndex -#define hasAttribute hasAttributeAtIndex -#endif - -using namespace llvm; - -StringMap &, CallInst *, ArrayRef, - GradientUtils *)>> - shadowHandlers; -StringMap &, Value *)>> shadowErasers; - -StringMap< - std::pair &, CallInst *, GradientUtils &, - Value *&, Value *&, Value *&)>, - std::function &, CallInst *, - DiffeGradientUtils &, Value *)>>> - customCallHandlers; - -StringMap &, CallInst *, GradientUtils &, - Value *&, Value *&)>> - customFwdCallHandlers; - -extern "C" { -llvm::cl::opt - EnzymeNewCache("enzyme-new-cache", cl::init(true), cl::Hidden, - cl::desc("Use new cache decision algorithm")); - -llvm::cl::opt EnzymeMinCutCache("enzyme-mincut-cache", cl::init(true), - cl::Hidden, - cl::desc("Use Enzyme Mincut algorithm")); - -llvm::cl::opt EnzymeLoopInvariantCache( - "enzyme-loop-invariant-cache", cl::init(true), cl::Hidden, - cl::desc("Attempt to hoist cache outside of loop")); - -llvm::cl::opt EnzymeInactiveDynamic( - "enzyme-inactive-dynamic", cl::init(true), cl::Hidden, - cl::desc("Force wholy inactive dynamic loops to have 0 iter reverse pass")); - -llvm::cl::opt - EnzymeSharedForward("enzyme-shared-forward", cl::init(false), cl::Hidden, - cl::desc("Forward Shared Memory from definitions")); - -llvm::cl::opt - EnzymeRegisterReduce("enzyme-register-reduce", cl::init(false), cl::Hidden, - cl::desc("Reduce the amount of register reduce")); -llvm::cl::opt - EnzymeSpeculatePHIs("enzyme-speculate-phis", cl::init(false), cl::Hidden, - cl::desc("Speculatively execute phi computations")); -llvm::cl::opt EnzymeFreeInternalAllocations( - "enzyme-free-internal-allocations", cl::init(true), cl::Hidden, - cl::desc("Always free internal allocations (disable if allocation needs " - "access outside)")); - -llvm::cl::opt - EnzymeRematerialize("enzyme-rematerialize", cl::init(true), cl::Hidden, - cl::desc("Rematerialize allocations/shadows in the " - "reverse rather than caching")); - -llvm::cl::opt - EnzymeVectorSplitPhi("enzyme-vector-split-phi", cl::init(true), cl::Hidden, - cl::desc("Split phis according to vector size")); - -llvm::cl::opt - EnzymePrintDiffUse("enzyme-print-diffuse", cl::init(false), cl::Hidden, - cl::desc("Print differential use analysis")); -} - -SmallVector MD_ToCopy = { - LLVMContext::MD_dbg, - LLVMContext::MD_tbaa, - LLVMContext::MD_tbaa_struct, - LLVMContext::MD_range, - LLVMContext::MD_nonnull, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null}; - -static bool isPotentialLastLoopValue(llvm::Value *val, - const llvm::BasicBlock *loc, - const llvm::LoopInfo &LI) { - if 
(llvm::Instruction *inst = llvm::dyn_cast(val)) { - const llvm::Loop *InstLoop = LI.getLoopFor(inst->getParent()); - if (InstLoop == nullptr) { - return false; - } - for (const llvm::Loop *L = LI.getLoopFor(loc); L; L = L->getParentLoop()) { - if (L == InstLoop) - return false; - } - return true; - } - return false; -} - -GradientUtils::GradientUtils( - EnzymeLogic &Logic, Function *newFunc_, Function *oldFunc_, - TargetLibraryInfo &TLI_, TypeAnalysis &TA_, TypeResults TR_, - ValueToValueMapTy &invertedPointers_, - const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &activevals_, DIFFE_TYPE ReturnActivity, - bool shadowReturnUsed_, ArrayRef ArgDiffeTypes_, - llvm::ValueMap &originalToNewFn_, - DerivativeMode mode, bool runtimeActivity, unsigned width, bool omp) - : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), - invertedPointers(), - OrigDT(oldFunc_->empty() - ? ((DominatorTree *)nullptr) - : &Logic.PPC.FAM.getResult( - *oldFunc_)), - OrigPDT(oldFunc_->empty() - ? ((PostDominatorTree *)nullptr) - : &Logic.PPC.FAM.getResult( - *oldFunc_)), - OrigLI(oldFunc_->empty() - ? ((LoopInfo *)nullptr) - : &Logic.PPC.FAM.getResult(*oldFunc_)), - OrigSE(oldFunc_->empty() - ? ((ScalarEvolution *)nullptr) - : &Logic.PPC.FAM.getResult( - *oldFunc_)), - notForAnalysis(getGuaranteedUnreachable(oldFunc_)), - ATA(oldFunc_->empty() - ? nullptr - : new ActivityAnalyzer( - Logic.PPC, Logic.PPC.getAAResultsFromFunction(oldFunc_), - notForAnalysis, TLI_, constantvalues_, activevals_, - ReturnActivity)), - overwritten_args_map_ptr(nullptr), unnecessaryValuesP(nullptr), - tid(nullptr), numThreads(nullptr), - OrigAA(oldFunc_->empty() ? ((AAResults *)nullptr) - : &Logic.PPC.getAAResultsFromFunction(oldFunc_)), - TA(TA_), TR(TR_), omp(omp), runtimeActivity(runtimeActivity), - width(width), shadowReturnUsed(shadowReturnUsed_), - ArgDiffeTypes(ArgDiffeTypes_) { - if (oldFunc_->empty()) - return; - if (oldFunc_->getSubprogram()) { - assert(originalToNewFn_.hasMD()); - } - - for (BasicBlock &BB : *oldFunc) { - for (Instruction &I : BB) { - if (auto CI = dyn_cast(&I)) { - originalCalls.push_back(CI); - } - } - } - - originalToNewFn.getMDMap() = originalToNewFn_.getMDMap(); - - if (oldFunc_->getSubprogram()) { - assert(originalToNewFn.hasMD()); - } - for (auto pair : invertedPointers_) { - invertedPointers.insert(std::make_pair( - (const Value *)pair.first, InvertedPointerVH(this, pair.second))); - } - originalToNewFn.insert(originalToNewFn_.begin(), originalToNewFn_.end()); - for (BasicBlock &oBB : *oldFunc) { - for (Instruction &oI : oBB) { - newToOriginalFn[originalToNewFn[&oI]] = &oI; - } - newToOriginalFn[originalToNewFn[&oBB]] = &oBB; - } - for (Argument &oArg : oldFunc->args()) { - newToOriginalFn[originalToNewFn[&oArg]] = &oArg; - } - for (BasicBlock &BB : *newFunc) { - originalBlocks.push_back(&BB); - } - tape = nullptr; - tapeidx = 0; - assert(originalBlocks.size() > 0); - - SmallVector ReturningBlocks; - for (BasicBlock &BB : *oldFunc) { - if (isa(BB.getTerminator())) - ReturningBlocks.push_back(&BB); - } - for (BasicBlock &BB : *oldFunc) { - bool legal = true; - for (auto BRet : ReturningBlocks) { - if (!(BRet == &BB || OrigDT->dominates(&BB, BRet))) { - legal = false; - break; - } - } - if (legal) - BlocksDominatingAllReturns.insert(&BB); - } -} - -// Whether a particular value is neded in rooting the reverse pass -bool GradientUtils::usedInRooting(const llvm::CallBase *orig, - ArrayRef types, - const llvm::Value *val, bool shadow) const { - SmallVector OrigDefs; - 
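// Sketch for illustration (not from the deleted file): the constructor above
// finishes by collecting the blocks that dominate every return. A standalone
// condensed version of that computation, assuming well-formed IR and an
// already-built DominatorTree:
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"

static llvm::SmallPtrSet<llvm::BasicBlock *, 8>
blocksDominatingAllReturns(llvm::Function &F, const llvm::DominatorTree &DT) {
  using namespace llvm;
  SmallVector<BasicBlock *, 4> Returns;
  for (BasicBlock &BB : F)
    if (isa<ReturnInst>(BB.getTerminator()))
      Returns.push_back(&BB);

  SmallPtrSet<BasicBlock *, 8> Result;
  for (BasicBlock &BB : F) {
    bool DominatesAll = true;
    for (BasicBlock *Ret : Returns)
      if (!(Ret == &BB || DT.dominates(&BB, Ret))) {
        DominatesAll = false;
        break;
      }
    if (DominatesAll)
      Result.insert(&BB);
  }
  return Result;
}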
orig->getOperandBundlesAsDefs(OrigDefs); - SmallVector Defs; - for (auto bund : OrigDefs) { - // Only handle jl_roots tag (for now). - if (bund.getTag() != "jl_roots") { - errs() << "unsupported tag " << bund.getTag() << " for " << *orig << "\n"; - llvm_unreachable("unsupported tag"); - } - - // In the future we can reduce the number of roots - // we preserve by identifying which operands they - // correspond to. For now, fall back and preserve all - // primals and shadows - // assert(bund.inputs().size() == types.size()); - for (auto inp : bund.inputs()) { - if (inp != val) - continue; - bool anyPrimal = false; - bool anyShadow = false; - for (auto ty : types) { - if (ty == ValueType::Primal || ty == ValueType::Both) - anyPrimal = true; - if (ty == ValueType::Shadow || ty == ValueType::Both) - anyShadow = true; - } - - if (anyPrimal && !shadow) - return true; - if (anyShadow && shadow) - return true; - } - } - return false; -} - -SmallVector -GradientUtils::getInvertedBundles(CallInst *orig, ArrayRef types, - IRBuilder<> &Builder2, bool lookup, - const ValueToValueMapTy &available) { - assert(!(lookup && (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError))); - - SmallVector OrigDefs; - orig->getOperandBundlesAsDefs(OrigDefs); - SmallVector Defs; - for (auto bund : OrigDefs) { - // Only handle jl_roots tag (for now). - if (bund.getTag() != "jl_roots") { - errs() << "unsupported tag " << bund.getTag() << " for " << *orig << "\n"; - llvm_unreachable("unsupported tag"); - } - SmallVector bunds; - // In the future we can reduce the number of roots - // we preserve by identifying which operands they - // correspond to. For now, fall back and preserve all - // primals and shadows - // assert(bund.inputs().size() == types.size()); - for (auto inp : bund.inputs()) { - bool anyPrimal = false; - bool anyShadow = false; - for (auto ty : types) { - if (ty == ValueType::Primal || ty == ValueType::Both) - anyPrimal = true; - if (ty == ValueType::Shadow || ty == ValueType::Both) - anyShadow = true; - } - - if (anyPrimal) { - Value *newv = getNewFromOriginal(inp); - if (lookup) - newv = lookupM(newv, Builder2, available); - bunds.push_back(newv); - } - if (anyShadow && !isConstantValue(inp)) { - Value *shadow = invertPointerM(inp, Builder2); - if (lookup) - shadow = lookupM(shadow, Builder2); - bunds.push_back(shadow); - } - } - Defs.push_back(OperandBundleDef(bund.getTag().str(), bunds)); - } - return Defs; -} - -Value *GradientUtils::getNewIfOriginal(Value *originst) const { - assert(originst); - auto f = originalToNewFn.find(originst); - if (f == originalToNewFn.end()) { - return originst; - } - assert(f != originalToNewFn.end()); - if (f->second == nullptr) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *originst << "\n"; - } - assert(f->second); - return f->second; -} - -Value *GradientUtils::ompThreadId() { - if (tid) - return tid; - IRBuilder<> B(inversionAllocs); - - auto FT = FunctionType::get(Type::getInt64Ty(B.getContext()), - ArrayRef(), false); - auto FN = newFunc->getParent()->getOrInsertFunction("omp_get_thread_num", FT); - auto CI = B.CreateCall(FN); - if (auto F = getFunctionFromCall(CI)) { -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesInaccessibleMemory(); - F->setOnlyReadsMemory(); -#else - F->addFnAttr(Attribute::InaccessibleMemOnly); - F->addFnAttr(Attribute::ReadOnly); -#endif - } -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyAccessesInaccessibleMemory(); - CI->setOnlyReadsMemory(); -#else - 
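// Sketch for illustration (not from the deleted file): ompThreadId() above
// (and ompNumThreads() below) memoize a single call to an OpenMP runtime
// query, inserted once into the allocation block, and then mark the callee
// and the call as read-only / inaccessible-memory-only. A minimal standalone
// version of the "insert once and cache" part, with the attribute handling
// omitted for brevity:
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

static llvm::Value *getOrCreateThreadIdCall(llvm::Function *F,
                                            llvm::Instruction *insertBefore,
                                            llvm::Value *&cached) {
  using namespace llvm;
  if (cached) // already materialized once for this function
    return cached;
  IRBuilder<> B(insertBefore);
  auto *FT = FunctionType::get(Type::getInt64Ty(B.getContext()),
                               /*isVarArg=*/false);
  // getOrInsertFunction reuses an existing declaration if one is present.
  FunctionCallee FC =
      F->getParent()->getOrInsertFunction("omp_get_thread_num", FT);
  cached = B.CreateCall(FC);
  return cached;
}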
CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); -#endif - return tid = CI; -} - -Value *GradientUtils::ompNumThreads() { - if (numThreads) - return numThreads; - IRBuilder<> B(inversionAllocs); - - auto FT = FunctionType::get(Type::getInt64Ty(B.getContext()), - ArrayRef(), false); - auto FN = - newFunc->getParent()->getOrInsertFunction("omp_get_max_threads", FT); - auto CI = B.CreateCall(FN); - if (auto F = getFunctionFromCall(CI)) { -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesInaccessibleMemory(); - F->setOnlyReadsMemory(); -#else - F->addFnAttr(Attribute::InaccessibleMemOnly); - F->addFnAttr(Attribute::ReadOnly); -#endif - } -#if LLVM_VERSION_MAJOR >= 16 - CI->setOnlyAccessesInaccessibleMemory(); - CI->setOnlyReadsMemory(); -#else - CI->addAttribute(AttributeList::FunctionIndex, - Attribute::InaccessibleMemOnly); - CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); -#endif - return numThreads = CI; -} - -Value *GradientUtils::getOrInsertTotalMultiplicativeProduct(Value *val, - LoopContext &lc) { - // TODO optimize if val is invariant to loopContext - assert(val->getType()->isFPOrFPVectorTy()); - for (auto &I : *lc.header) { - if (auto PN = dyn_cast(&I)) { - if (PN->getType() != val->getType()) - continue; - Value *ival = PN->getIncomingValueForBlock(lc.preheader); - if (auto CDV = dyn_cast(ival)) { - if (CDV->isSplat()) - ival = CDV->getSplatValue(); - } - if (auto C = dyn_cast(ival)) { - if (!C->isExactlyValue(APFloat(C->getType()->getFltSemantics(), "1"))) { - continue; - } - } else - continue; - for (auto IB : PN->blocks()) { - if (IB == lc.preheader) - continue; - - if (auto BO = - dyn_cast(PN->getIncomingValueForBlock(IB))) { - if (BO->getOpcode() != BinaryOperator::FMul) - goto continueOutermost; - if (BO->getOperand(0) == PN && BO->getOperand(1) == val) - return BO; - if (BO->getOperand(1) == PN && BO->getOperand(0) == val) - return BO; - } else - goto continueOutermost; - } - } else - break; - continueOutermost:; - } - - IRBuilder<> lbuilder(lc.header, lc.header->begin()); - auto PN = lbuilder.CreatePHI(val->getType(), 2); - Constant *One = ConstantFP::get(val->getType()->getScalarType(), "1"); - if (VectorType *VTy = dyn_cast(val->getType())) { - One = ConstantVector::getSplat(VTy->getElementCount(), One); - } - PN->addIncoming(One, lc.preheader); - lbuilder.SetInsertPoint(lc.header->getFirstNonPHI()); - if (auto inst = dyn_cast(val)) { - if (DT.dominates(PN, inst)) - lbuilder.SetInsertPoint(inst->getNextNode()); - } - Value *red = lbuilder.CreateFMul(PN, val); - for (auto pred : predecessors(lc.header)) { - if (pred == lc.preheader) - continue; - PN->addIncoming(red, pred); - } - return red; -} - -Value *GradientUtils::getOrInsertConditionalIndex(Value *val, LoopContext &lc, - bool pickTrue) { - assert(val->getType()->isIntOrIntVectorTy(1)); - // TODO optimize if val is invariant to loopContext - for (auto &I : *lc.header) { - if (auto PN = dyn_cast(&I)) { - if (PN->getNumIncomingValues() == 0) - continue; - if (PN->getType() != lc.incvar->getType()) - continue; - Value *ival = PN->getIncomingValueForBlock(lc.preheader); - if (auto C = dyn_cast(ival)) { - if (!C->isNullValue()) { - continue; - } - } else - continue; - for (auto IB : PN->blocks()) { - if (IB == lc.preheader) - continue; - - if (auto SI = dyn_cast(PN->getIncomingValueForBlock(IB))) { - if (SI->getCondition() != val) - goto continueOutermost; - if (pickTrue && SI->getFalseValue() == PN) 
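// Sketch for illustration (not from the deleted file):
// getOrInsertTotalMultiplicativeProduct above either finds an existing
// loop-carried running product of `val` or builds one. Built from scratch it
// is a header phi seeded with 1.0 plus an fmul fed back along the back-edge.
// Assumptions in this sketch: a single-latch loop, `val` is a floating-point
// scalar defined before the header's first non-phi instruction.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

static llvm::Value *buildRunningProduct(llvm::BasicBlock *header,
                                        llvm::BasicBlock *preheader,
                                        llvm::BasicBlock *latch,
                                        llvm::Value *val) {
  using namespace llvm;
  IRBuilder<> B(header, header->begin());
  PHINode *Prod = B.CreatePHI(val->getType(), 2, "prod");
  Prod->addIncoming(ConstantFP::get(val->getType(), 1.0), preheader);
  // Multiply after the header phis; the result feeds the back-edge.
  B.SetInsertPoint(header->getFirstNonPHI());
  Value *Next = B.CreateFMul(Prod, val, "prod.next");
  Prod->addIncoming(Next, latch);
  return Next;
}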
{ - // TODO handle vector of - if (SI->getTrueValue() == lc.incvar) - return SI; - } - if (!pickTrue && SI->getTrueValue() == PN) { - // TODO handle vector of - if (SI->getFalseValue() == lc.incvar) - return SI; - } - } else - goto continueOutermost; - } - } else - break; - continueOutermost:; - } - - IRBuilder<> lbuilder(lc.header, lc.header->begin()); - auto PN = lbuilder.CreatePHI(lc.incvar->getType(), 2); - Constant *Zero = - Constant::getNullValue(lc.incvar->getType()->getScalarType()); - PN->addIncoming(Zero, lc.preheader); - lbuilder.SetInsertPoint(lc.incvar->getNextNode()); - Value *red = lc.incvar; - if (VectorType *VTy = dyn_cast(val->getType())) { -#if LLVM_VERSION_MAJOR >= 12 - red = lbuilder.CreateVectorSplat(VTy->getElementCount(), red); -#else - red = lbuilder.CreateVectorSplat(VTy->getNumElements(), red); -#endif - } - if (auto inst = dyn_cast(val)) { - if (DT.dominates(PN, inst)) - lbuilder.SetInsertPoint(inst->getNextNode()); - } - assert(red->getType() == PN->getType()); - red = lbuilder.CreateSelect(val, pickTrue ? red : PN, pickTrue ? PN : red); - for (auto pred : predecessors(lc.header)) { - if (pred == lc.preheader) - continue; - PN->addIncoming(red, pred); - } - return red; -} - -bool GradientUtils::assumeDynamicLoopOfSizeOne(Loop *L) const { - if (!EnzymeInactiveDynamic) - return false; - auto OL = OrigLI->getLoopFor(isOriginal(L->getHeader())); - assert(OL); - for (auto OB : OL->getBlocks()) { - for (auto &OI : *OB) { - if (!isConstantInstruction(&OI)) - return false; - if (auto SI = dyn_cast(&OI)) { - if (!isConstantValue(SI->getPointerOperand())) - return false; - } - if (auto MTI = dyn_cast(&OI)) { - if (!isConstantValue(MTI->getArgOperand(0))) - return false; - } - } - } - return true; -} - -DebugLoc GradientUtils::getNewFromOriginal(const DebugLoc L) const { - if (L.get() == nullptr) - return nullptr; - if (!oldFunc->getSubprogram()) - return L; - assert(originalToNewFn.hasMD()); - auto opt = originalToNewFn.getMappedMD(L.getAsMDNode()); - if (!opt) - return L; - assert(opt); - return DebugLoc(cast(*opt)); -} - -Value *GradientUtils::getNewFromOriginal(const Value *originst) const { - assert(originst); - if (isa(originst)) - return const_cast(originst); - auto f = originalToNewFn.find(originst); - if (f == originalToNewFn.end()) { - errs() << *oldFunc << "\n"; - errs() << *newFunc << "\n"; - dumpMap(originalToNewFn, [&](const Value *const &v) -> bool { - if (isa(originst)) - return isa(v); - if (isa(originst)) - return isa(v); - if (isa(originst)) - return isa(v); - if (isa(originst)) - return isa(v); - if (isa(originst)) - return isa(v); - return true; - }); - llvm::errs() << *originst << "\n"; - } - assert(f != originalToNewFn.end()); - if (f->second == nullptr) { - errs() << *oldFunc << "\n"; - errs() << *newFunc << "\n"; - errs() << *originst << "\n"; - } - assert(f->second); - return f->second; -} - -Instruction * -GradientUtils::getNewFromOriginal(const Instruction *newinst) const { - auto ninst = getNewFromOriginal((Value *)newinst); - if (!isa(ninst)) { - errs() << *oldFunc << "\n"; - errs() << *newFunc << "\n"; - errs() << *ninst << " - " << *newinst << "\n"; - } - return cast(ninst); -} - -BasicBlock *GradientUtils::getNewFromOriginal(const BasicBlock *newinst) const { - return cast(getNewFromOriginal((Value *)newinst)); -} - -Value *GradientUtils::hasUninverted(const Value *inverted) const { - for (auto v : invertedPointers) { - if (v.second == inverted) - return const_cast(v.first); - } - return nullptr; -} - -BasicBlock 
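// Sketch for illustration (not from the deleted file): getOrInsertConditionalIndex
// above records "the latest loop index at which a condition held" as a
// loop-carried phi updated through a select. A from-scratch version for a
// single-latch loop, assuming `iv` is the induction variable and `cond` is an
// i1 defined before `insertPt`:
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

static llvm::Value *buildLastTrueIndex(llvm::BasicBlock *header,
                                       llvm::BasicBlock *preheader,
                                       llvm::BasicBlock *latch,
                                       llvm::Value *iv, llvm::Value *cond,
                                       llvm::Instruction *insertPt) {
  using namespace llvm;
  IRBuilder<> B(header, header->begin());
  PHINode *Last = B.CreatePHI(iv->getType(), 2, "last.true.idx");
  Last->addIncoming(Constant::getNullValue(iv->getType()), preheader);
  B.SetInsertPoint(insertPt);
  // Keep the previous value unless the condition holds this iteration.
  Value *Upd = B.CreateSelect(cond, iv, Last, "last.true.idx.next");
  Last->addIncoming(Upd, latch);
  return Upd;
}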
*GradientUtils::getOriginalFromNew(const BasicBlock *newinst) const { - assert(newinst->getParent() == newFunc); - auto found = newToOriginalFn.find(newinst); - assert(found != newToOriginalFn.end()); - Value *res = found->second; - return cast(res); -} - -Value *GradientUtils::isOriginal(const Value *newinst) const { - if (isa(newinst) || isa(newinst)) - return const_cast(newinst); -#ifndef NDEBUG - if (auto arg = dyn_cast(newinst)) { - assert(arg->getParent() == newFunc); - } - if (auto inst = dyn_cast(newinst)) { - assert(inst->getParent()->getParent() == newFunc); - } -#endif - auto found = newToOriginalFn.find(newinst); - if (found == newToOriginalFn.end()) - return nullptr; - return found->second; -} - -Instruction *GradientUtils::isOriginal(const Instruction *newinst) const { - return cast_or_null(isOriginal((const Value *)newinst)); -} - -BasicBlock *GradientUtils::isOriginal(const BasicBlock *newinst) const { - return cast_or_null(isOriginal((const Value *)newinst)); -} - -Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, - const ValueToValueMapTy &available, - UnwrapMode unwrapMode, BasicBlock *scope, - bool permitCache) { - assert(val); - assert(val->getName() != ""); - assert(val->getType()); - - for (auto pair : available) { - assert(pair.first); - assert(pair.first->getType()); - if (pair.second) { - assert(pair.second->getType()); - assert(pair.first->getType() == pair.second->getType()); - } - } - - if (isa(val) && - cast(val)->getMetadata("enzyme_mustcache")) { - return val; - } - - if (available.count(val)) { - auto avail = available.lookup(val); - assert(avail->getType()); - if (avail->getType() != val->getType()) { - llvm::errs() << "val: " << *val << "\n"; - llvm::errs() << "available[val]: " << *available.lookup(val) << "\n"; - } - assert(available.lookup(val)->getType() == val->getType()); - return available.lookup(val); - } - - if (auto inst = dyn_cast(val)) { - if (inversionAllocs && inst->getParent() == inversionAllocs) { - return val; - } - // if (inst->getParent() == &newFunc->getEntryBlock()) { - // return inst; - //} - if (inst->getParent()->getParent() == newFunc && - isOriginalBlock(*BuilderM.GetInsertBlock())) { - if (BuilderM.GetInsertBlock()->size() && - BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { - if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) { - // llvm::errs() << "allowed " << *inst << "from domination\n"; - assert(inst->getType() == val->getType()); - return inst; - } - } else { - if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) { - // llvm::errs() << "allowed " << *inst << "from block domination\n"; - assert(inst->getType() == val->getType()); - return inst; - } - } - } - assert(!TapesToPreventRecomputation.count(inst)); - } - - std::pair idx = std::make_pair(val, scope); - // assert(!startsWith(val->getName(), "$tapeload")); - if (permitCache) { - auto found0 = unwrap_cache.find(BuilderM.GetInsertBlock()); - if (found0 != unwrap_cache.end()) { - auto found1 = found0->second.find(idx.first); - if (found1 != found0->second.end()) { - auto found2 = found1->second.find(idx.second); - if (found2 != found1->second.end()) { - - auto cachedValue = found2->second; - if (cachedValue == nullptr) { - found1->second.erase(idx.second); - if (found1->second.size() == 0) { - found0->second.erase(idx.first); - } - } else { - if (cachedValue->getType() != val->getType()) { - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "val: " << *val << "\n"; - llvm::errs() << "unwrap_cache[cidx]: " << 
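// Sketch for illustration (not from the deleted file): unwrapM above consults
// a two-level cache keyed first by the builder's insertion block and then by
// the (value, scope) pair, so the same value can be re-unwrapped differently
// in different blocks. A stripped-down, hypothetical version of that lookup
// using plain standard containers:
#include <map>
#include <utility>

namespace sketch {
struct Block;
struct Val;

using Key = std::pair<Val *, Block *>;        // (value, scope)
using PerBlockCache = std::map<Key, Val *>;
std::map<Block *, PerBlockCache> unwrapCache; // insertion block -> cache

Val *lookupUnwrap(Block *insertBlock, Val *v, Block *scope) {
  auto bIt = unwrapCache.find(insertBlock);
  if (bIt == unwrapCache.end())
    return nullptr;
  auto vIt = bIt->second.find({v, scope});
  // A null entry means a previously cached value was invalidated; treat it as
  // a miss, as the original code erases such stale entries.
  return vIt == bIt->second.end() ? nullptr : vIt->second;
}
} // namespace sketch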
*cachedValue << "\n"; - } - assert(cachedValue->getType() == val->getType()); - return cachedValue; - } - } - } - } - } - - if (this->mode == DerivativeMode::ReverseModeGradient || - this->mode == DerivativeMode::ForwardModeSplit || - this->mode == DerivativeMode::ReverseModeCombined) - if (auto inst = dyn_cast(val)) { - if (inst->getParent()->getParent() == newFunc) { - if (unwrapMode == UnwrapMode::LegalFullUnwrap && - this->mode != DerivativeMode::ReverseModeCombined) { - // TODO this isOriginal is a bottleneck, the new mapping of - // knownRecompute should be precomputed and maintained to lookup - // instead - Instruction *orig = isOriginal(inst); - // If a given value has been chosen to be cached, do not compute the - // operands to unwrap it, instead simply emit a placeholder to be - // replaced by the cache load later. This placeholder should only be - // returned when the original value would be recomputed (e.g. this - // function would not return null). Since this case assumes everything - // can be recomputed, simply return the placeholder. - if (orig && knownRecomputeHeuristic.find(orig) != - knownRecomputeHeuristic.end()) { - if (!knownRecomputeHeuristic[orig]) { - assert(inst->getParent()->getParent() == newFunc); - auto placeholder = BuilderM.CreatePHI( - val->getType(), 0, val->getName() + "_krcLFUreplacement"); - unwrappedLoads[placeholder] = inst; - SmallVector avail; - for (auto pair : available) - if (pair.second) - avail.push_back(MDNode::get( - placeholder->getContext(), - {ValueAsMetadata::get(const_cast(pair.first)), - ValueAsMetadata::get(pair.second)})); - placeholder->setMetadata( - "enzyme_available", - MDNode::get(placeholder->getContext(), avail)); - if (!permitCache) - return placeholder; - return unwrap_cache[BuilderM.GetInsertBlock()][idx.first] - [idx.second] = placeholder; - } - } - } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { - // TODO this isOriginal is a bottleneck, the new mapping of - // knownRecompute should be precomputed and maintained to lookup - // instead - Instruction *orig = isOriginal(inst); - // If a given value has been chosen to be cached, do not compute the - // operands to unwrap it, instead simply emit a placeholder to be - // replaced by the cache load later. This placeholder should only be - // returned when the original value would be recomputed (e.g. this - // function would not return null). See note below about the condition - // as applied to this case. - if (orig) { - auto found = knownRecomputeHeuristic.find(orig); - if (found != knownRecomputeHeuristic.end()) { - if (!found->second) { - if (mode == DerivativeMode::ReverseModeCombined) { - // Don't unnecessarily cache a value if the caching - // heuristic says we should preserve this precise (and not - // an lcssa wrapped) value - if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { - Value *nval = inst; - if (scope) - nval = fixLCSSA(inst, scope); - if (nval == inst) - goto endCheck; - } - } else { - // Note that this logic (original load must dominate or - // alternatively be in the reverse block) is only valid iff - // when applicable (here if in split mode), an overwritten - // load cannot be hoisted outside of a loop to be used as a - // loop limit. This optimization is currently done in the - // combined mode (e.g. if a load isn't modified between a - // prior insertion point and the actual load, it is legal to - // recompute). 
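// Sketch for illustration (not from the deleted file): when a value is known
// to be cached rather than recomputed, the code above emits a zero-incoming
// placeholder phi and records the (original, replacement) pairs on it as
// metadata so the later cache-materialization step can rewrite it. A minimal
// version of that placeholder construction, reusing the metadata kind name
// seen above:
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"

static llvm::PHINode *makeCachePlaceholder(
    llvm::IRBuilder<> &B, llvm::Value *orig,
    llvm::ArrayRef<std::pair<llvm::Value *, llvm::Value *>> available) {
  using namespace llvm;
  PHINode *Placeholder =
      B.CreatePHI(orig->getType(), 0, orig->getName() + "_placeholder");
  SmallVector<Metadata *, 4> Avail;
  for (auto &pair : available)
    Avail.push_back(MDNode::get(Placeholder->getContext(),
                                {ValueAsMetadata::get(pair.first),
                                 ValueAsMetadata::get(pair.second)}));
  Placeholder->setMetadata("enzyme_available",
                           MDNode::get(Placeholder->getContext(), Avail));
  return Placeholder;
}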
- if (!isOriginalBlock(*BuilderM.GetInsertBlock()) || - DT.dominates(inst, &*BuilderM.GetInsertPoint())) { - assert(inst->getParent()->getParent() == newFunc); - auto placeholder = BuilderM.CreatePHI( - val->getType(), 0, - val->getName() + "_krcAFUWLreplacement"); - unwrappedLoads[placeholder] = inst; - SmallVector avail; - for (auto pair : available) - if (pair.second) - avail.push_back( - MDNode::get(placeholder->getContext(), - {ValueAsMetadata::get( - const_cast(pair.first)), - ValueAsMetadata::get(pair.second)})); - placeholder->setMetadata( - "enzyme_available", - MDNode::get(placeholder->getContext(), avail)); - if (!permitCache) - return placeholder; - return unwrap_cache[BuilderM.GetInsertBlock()][idx.first] - [idx.second] = placeholder; - } - } - } - } - } - } else if (unwrapMode != UnwrapMode::LegalFullUnwrapNoTapeReplace && - mode != DerivativeMode::ReverseModeCombined) { - // TODO this isOriginal is a bottleneck, the new mapping of - // knownRecompute should be precomputed and maintained to lookup - // instead - - // If a given value has been chosen to be cached, do not compute the - // operands to unwrap it if it is not legal to do so. This prevents - // the creation of unused versions of the instruction's operand, which - // may be assumed to never be used and thus cause an error when they - // are inadvertantly cached. - Value *orig = isOriginal(val); - if (orig && knownRecomputeHeuristic.find(orig) != - knownRecomputeHeuristic.end()) { - if (!knownRecomputeHeuristic[orig]) { - return nullptr; - } - } - } - } - } - -#define getOpFullest(Builder, vtmp, frominst, lookupInst, check) \ - ({ \ - Value *v = vtmp; \ - BasicBlock *origParent = frominst; \ - Value *___res; \ - if (unwrapMode == UnwrapMode::LegalFullUnwrap || \ - unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace || \ - unwrapMode == UnwrapMode::AttemptFullUnwrap || \ - unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { \ - if (v == val) \ - ___res = nullptr; \ - else \ - ___res = unwrapM(v, Builder, available, unwrapMode, origParent, \ - permitCache); \ - if (!___res && unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { \ - bool noLookup = false; \ - auto found = available.find(v); \ - if (found != available.end() && !found->second) \ - noLookup = true; \ - if (auto opinst = dyn_cast(v)) \ - if (isOriginalBlock(*Builder.GetInsertBlock())) { \ - if (!DT.dominates(opinst, &*Builder.GetInsertPoint())) \ - noLookup = true; \ - } \ - origParent = lookupInst; \ - if (!noLookup) \ - ___res = lookupM(v, Builder, available, v != val, origParent); \ - } \ - if (___res) \ - assert(___res->getType() == v->getType() && "uw"); \ - } else { \ - origParent = lookupInst; \ - assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap); \ - auto found = available.find(v); \ - if (found != available.end() && !found->second) \ - ___res = nullptr; \ - else { \ - ___res = lookupM(v, Builder, available, v != val, origParent); \ - if (___res && ___res->getType() != v->getType()) { \ - llvm::errs() << *newFunc << "\n"; \ - llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; \ - } \ - if (___res) \ - assert(___res->getType() == v->getType() && "lu"); \ - } \ - } \ - ___res; \ - }) -#define getOpFull(Builder, vtmp, frominst) \ - ({ \ - BasicBlock *parent = scope; \ - if (parent == nullptr) \ - if (auto originst = dyn_cast(val)) \ - parent = originst->getParent(); \ - getOpFullest(Builder, vtmp, frominst, parent, true); \ - }) -#define getOpUnchecked(vtmp) \ - ({ \ - BasicBlock *parent = scope; \ - getOpFullest(BuilderM, 
vtmp, parent, parent, false); \ - }) -#define getOp(vtmp) \ - ({ \ - BasicBlock *parent = scope; \ - if (parent == nullptr) \ - if (auto originst = dyn_cast(val)) \ - parent = originst->getParent(); \ - getOpFullest(BuilderM, vtmp, parent, parent, true); \ - }) - - if (isa(val) || isa(val)) { - return val; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateFreeze(op0, op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), - op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getAggregateOperand()); - if (op0 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateExtractValue(op0, op->getIndices(), - op->getName() + "_unwrap"); - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - // Unwrapped Aggregate, Indices, parent - SmallVector, InsertValueInst *>, 1> - insertElements; - - Value *agg = op; - while (auto op1 = dyn_cast(agg)) { - if (Value *orig = isOriginal(op1)) { - if (knownRecomputeHeuristic.count(orig)) { - if (!knownRecomputeHeuristic[orig]) { - break; - } - } - } - Value *valOp = op1->getInsertedValueOperand(); - valOp = getOp(valOp); - if (valOp == nullptr) - goto endCheck; - insertElements.push_back({valOp, op1->getIndices(), op1}); - agg = op1->getAggregateOperand(); - } - - Value *toreturn = getOp(agg); - if (toreturn == nullptr) - goto endCheck; - for (auto &&[valOp, idcs, parent] : reverse(insertElements)) { - toreturn = BuilderM.CreateInsertValue(toreturn, valOp, idcs, - parent->getName() + "_unwrap"); - - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][parent][idx.second] = toreturn; - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(parent); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != parent->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - } - - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - auto toreturn = - BuilderM.CreateExtractElement(op0, op1, op->getName() + "_unwrap"); - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); 
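// Sketch for illustration (not from the deleted file): the getOp* statement-
// expression macros above all funnel into "re-resolve this operand under the
// current unwrap mode, or fall back to a lookup". Outside of a macro the same
// dispatch can be expressed as a small factory returning a lambda; this
// hypothetical version only distinguishes "full unwrap" from "single unwrap
// plus lookup", with the real unwrap/lookup routines passed in as callbacks:
#include <functional>

namespace sketch {
struct Val;
enum class Mode { FullUnwrap, SingleUnwrap };

std::function<Val *(Val *)> makeGetOp(Mode mode,
                                      std::function<Val *(Val *)> unwrap,
                                      std::function<Val *(Val *)> lookup) {
  return [=](Val *v) -> Val * {
    if (mode == Mode::FullUnwrap)
      if (Val *res = unwrap(v))
        return res;   // operand recomputed in place
    return lookup(v); // otherwise fall back to a cached lookup
  };
}
} // namespace sketch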
- unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - auto op2 = getOp(op->getOperand(2)); - if (op2 == nullptr) - goto endCheck; - auto toreturn = - BuilderM.CreateInsertElement(op0, op1, op2, op->getName() + "_unwrap"); - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateShuffleVector( - op0, op1, op->getShuffleMaskForBitcode(), op->getName() + "'_unwrap"); - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - if (op0->getType() != op1->getType()) { - llvm::errs() << " op: " << *op << " op0: " << *op0 << " op1: " << *op1 - << " p0: " << *op->getOperand(0) - << " p1: " << *op->getOperand(1) << "\n"; - } - assert(op0->getType() == op1->getType()); - auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1, - op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateICmp(op->getPredicate(), op0, op1, - op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateFCmp(op->getPredicate(), op0, op1, - op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; 
- if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (isa(val) && - cast(val)->getOpcode() == Instruction::FNeg) { - auto op = cast(val); - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto toreturn = BuilderM.CreateFNeg(op0, op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != - cast(val)->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - auto op0 = getOp(op->getOperand(0)); - if (op0 == nullptr) - goto endCheck; - auto op1 = getOp(op->getOperand(1)); - if (op1 == nullptr) - goto endCheck; - auto op2 = getOp(op->getOperand(2)); - if (op2 == nullptr) - goto endCheck; - auto toreturn = - BuilderM.CreateSelect(op0, op1, op2, op->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(op); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != op->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto inst = dyn_cast(val)) { - auto ptr = getOp(inst->getPointerOperand()); - if (ptr == nullptr) - goto endCheck; - SmallVector ind; - // llvm::errs() << "inst: " << *inst << "\n"; - for (unsigned i = 0; i < inst->getNumIndices(); ++i) { - Value *a = inst->getOperand(1 + i); - auto op = getOp(a); - if (op == nullptr) - goto endCheck; - ind.push_back(op); - } - auto toreturn = BuilderM.CreateGEP(inst->getSourceElementType(), ptr, ind, - inst->getName() + "_unwrap"); - if (isa(toreturn)) - cast(toreturn)->setIsInBounds(inst->isInBounds()); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(inst); - unwrappedLoads[newi] = val; - if (newi->getParent()->getParent() != inst->getParent()->getParent()) - newi->setDebugLoc(nullptr); - } - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto load = dyn_cast(val)) { - if (load->getMetadata("enzyme_noneedunwrap")) - return load; - - bool legalMove = unwrapMode == UnwrapMode::LegalFullUnwrap || - unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace; - if (!legalMove) { - BasicBlock *parent = nullptr; - if (isOriginalBlock(*BuilderM.GetInsertBlock())) - parent = BuilderM.GetInsertBlock(); - if (!parent || - LI.getLoopFor(parent) == LI.getLoopFor(load->getParent()) || - DT.dominates(load, parent)) { - legalMove = legalRecompute(load, available, &BuilderM); - } else { - legalMove = - legalRecompute(load, available, &BuilderM, /*reverse*/ false, - /*legalRecomputeCache*/ false); - } - } - if (!legalMove) { - auto &warnMap = UnwrappedWarnings[load]; - if (!warnMap.count(BuilderM.GetInsertBlock())) { - EmitWarning("UncacheableUnwrap", *load, "Load cannot be unwrapped ", - *load, " in ", BuilderM.GetInsertBlock()->getName(), " - ", - BuilderM.GetInsertBlock()->getParent()->getName(), " mode ", - unwrapMode); - 
warnMap.insert(BuilderM.GetInsertBlock()); - } - goto endCheck; - } - - Value *pidx = getOp(load->getOperand(0)); - - if (pidx == nullptr) { - goto endCheck; - } - - if (pidx->getType() != load->getOperand(0)->getType()) { - llvm::errs() << "load: " << *load << "\n"; - llvm::errs() << "load->getOperand(0): " << *load->getOperand(0) << "\n"; - llvm::errs() << "idx: " << *pidx << " unwrapping: " << *val - << " mode=" << unwrapMode << "\n"; - } - assert(pidx->getType() == load->getOperand(0)->getType()); - - auto toreturn = - BuilderM.CreateLoad(load->getType(), pidx, load->getName() + "_unwrap"); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - toreturn->copyMetadata(*load, ToCopy2); - toreturn->copyIRFlags(load); - if (load->getParent()->getParent() == newFunc) - if (auto orig = isOriginal(load)) { - SmallVector scopeMD = { - getDerivativeAliasScope(orig->getOperand(0), -1)}; - if (auto prev = orig->getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - scopeMD.push_back(M); - } - } - auto scope = MDNode::get(orig->getContext(), scopeMD); - toreturn->setMetadata(LLVMContext::MD_alias_scope, scope); - - SmallVector MDs; - for (size_t j = 0; j < getWidth(); j++) { - MDs.push_back(getDerivativeAliasScope(orig->getOperand(0), j)); - } - if (auto prev = orig->getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - MDs.push_back(M); - } - } - auto noscope = MDNode::get(orig->getContext(), MDs); - toreturn->setMetadata(LLVMContext::MD_noalias, noscope); - } - unwrappedLoads[toreturn] = load; - if (toreturn->getParent()->getParent() != load->getParent()->getParent()) - toreturn->setDebugLoc(nullptr); - else - toreturn->setDebugLoc(getNewFromOriginal(load->getDebugLoc())); - toreturn->setAlignment(load->getAlign()); - toreturn->setVolatile(load->isVolatile()); - toreturn->setOrdering(load->getOrdering()); - toreturn->setSyncScopeID(load->getSyncScopeID()); - if (toreturn->getParent()->getParent() != load->getParent()->getParent()) - toreturn->setDebugLoc(nullptr); - else - toreturn->setDebugLoc(getNewFromOriginal(load->getDebugLoc())); - toreturn->setMetadata(LLVMContext::MD_tbaa, - load->getMetadata(LLVMContext::MD_tbaa)); - auto invar_group = load->getMetadata(LLVMContext::MD_invariant_group); - if (!invar_group) { - bool legal = true; - if (load->getParent()->getParent() != newFunc) - legal = false; - else if (auto norig = isOriginal(load)) - for (const auto &pair : rematerializableAllocations) { - for (auto V : pair.second.loads) - if (V == norig) { - legal = false; - break; - } - if (!legal) - break; - } - if (legal) { - invar_group = MDNode::getDistinct(load->getContext(), {}); - load->setMetadata(LLVMContext::MD_invariant_group, invar_group); - } - } - toreturn->setMetadata(LLVMContext::MD_invariant_group, invar_group); - // TODO adding to cache only legal if no alias of any future writes - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } else if (auto op = dyn_cast(val)) { - - bool legalMove = unwrapMode == UnwrapMode::LegalFullUnwrap || - unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace; - if (!legalMove) { - legalMove = legalRecompute(op, available, &BuilderM); - } - if (!legalMove) - goto endCheck; - - SmallVector args; -#if LLVM_VERSION_MAJOR >= 14 - for (unsigned i = 0; i < op->arg_size(); ++i) -#else - for (unsigned 
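// Sketch for illustration (not from the deleted file): when the load above is
// recreated at a new insertion point, its properties and a whitelist of
// metadata kinds are copied over, and the debug location is dropped if the
// clone lands in a different function. A condensed version of that copy step:
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"

static llvm::LoadInst *recreateLoad(llvm::IRBuilder<> &B, llvm::LoadInst *LI,
                                    llvm::Value *NewPtr) {
  using namespace llvm;
  LoadInst *NewLI =
      B.CreateLoad(LI->getType(), NewPtr, LI->getName() + "_unwrap");
  NewLI->copyIRFlags(LI);
  NewLI->setAlignment(LI->getAlign());
  NewLI->setVolatile(LI->isVolatile());
  NewLI->setOrdering(LI->getOrdering());
  NewLI->setSyncScopeID(LI->getSyncScopeID());
  // Copy only metadata kinds that remain valid on a moved load.
  NewLI->copyMetadata(*LI, {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,
                            LLVMContext::MD_nonnull, LLVMContext::MD_range});
  if (NewLI->getFunction() != LI->getFunction())
    NewLI->setDebugLoc(DebugLoc()); // debug info is not portable across functions
  return NewLI;
}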
i = 0; i < op->getNumArgOperands(); ++i) -#endif - { - args.push_back(getOp(op->getArgOperand(i))); - if (args[i] == nullptr) - goto endCheck; - } - - Value *fn = getOp(op->getCalledOperand()); - if (fn == nullptr) - goto endCheck; - - auto toreturn = - cast(BuilderM.CreateCall(op->getFunctionType(), fn, args)); - toreturn->copyIRFlags(op); - toreturn->setAttributes(op->getAttributes()); - toreturn->setCallingConv(op->getCallingConv()); - toreturn->setTailCallKind(op->getTailCallKind()); - if (toreturn->getParent()->getParent() == op->getParent()->getParent()) - toreturn->setDebugLoc(getNewFromOriginal(op->getDebugLoc())); - else - toreturn->setDebugLoc(nullptr); - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; - unwrappedLoads[toreturn] = val; - return toreturn; - } else if (auto phi = dyn_cast(val)) { - if (phi->getNumIncomingValues() == 0) { - // This is a placeholder shadow for a load, rather than falling - // back to the uncached variant, use the proper procedure for - // an inverted load - if (auto dli = dyn_cast_or_null(hasUninverted(phi))) { - // Almost identical code to unwrap load (replacing use of shadow - // where appropriate) - if (dli->getMetadata("enzyme_noneedunwrap")) - return dli; - - bool legalMove = unwrapMode == UnwrapMode::LegalFullUnwrap || - unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace; - if (!legalMove) { - // TODO actually consider whether this is legal to move to the new - // location, rather than recomputable anywhere - legalMove = legalRecompute(dli, available, &BuilderM); - } - if (!legalMove) { - auto &warnMap = UnwrappedWarnings[phi]; - if (!warnMap.count(BuilderM.GetInsertBlock())) { - EmitWarning("UncacheableUnwrap", *dli, - "Differential Load cannot be unwrapped ", *dli, " in ", - BuilderM.GetInsertBlock()->getName(), " mode ", - unwrapMode); - warnMap.insert(BuilderM.GetInsertBlock()); - } - return nullptr; - } - - Value *pidx = nullptr; - - if (isOriginalBlock(*BuilderM.GetInsertBlock())) { - pidx = invertPointerM(dli->getOperand(0), BuilderM); - } else { - pidx = lookupM(invertPointerM(dli->getOperand(0), BuilderM), BuilderM, - available); - } - - if (pidx == nullptr) - goto endCheck; - - if (pidx->getType() != getShadowType(dli->getOperand(0)->getType())) { - llvm::errs() << "dli: " << *dli << "\n"; - llvm::errs() << "dli->getOperand(0): " << *dli->getOperand(0) << "\n"; - llvm::errs() << "pidx: " << *pidx << "\n"; - } - assert(pidx->getType() == getShadowType(dli->getOperand(0)->getType())); - - size_t s_idx = 0; - Value *toreturn = applyChainRule( - dli->getType(), BuilderM, - [&](Value *pidx) { - auto toreturn = BuilderM.CreateLoad(dli->getType(), pidx, - phi->getName() + "_unwrap"); - if (auto newi = dyn_cast(toreturn)) { - newi->copyIRFlags(dli); - unwrappedLoads[toreturn] = dli; - } - toreturn->setAlignment(dli->getAlign()); - toreturn->setVolatile(dli->isVolatile()); - toreturn->setOrdering(dli->getOrdering()); - toreturn->setSyncScopeID(dli->getSyncScopeID()); - llvm::SmallVector ToCopy2(MD_ToCopy); - toreturn->copyMetadata(*dli, ToCopy2); - SmallVector scopeMD = { - getDerivativeAliasScope(dli->getOperand(0), s_idx)}; - if (auto prev = dli->getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - scopeMD.push_back(M); - } - } - auto scope = MDNode::get(dli->getContext(), scopeMD); - toreturn->setMetadata(LLVMContext::MD_alias_scope, scope); - - SmallVector MDs; - for (ssize_t j = -1; j < getWidth(); j++) { - if (j != (ssize_t)s_idx) - 
MDs.push_back(getDerivativeAliasScope(dli->getOperand(0), j)); - } - if (auto prev = dli->getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - MDs.push_back(M); - } - } - if (MDs.size()) { - auto noscope = MDNode::get(dli->getContext(), MDs); - toreturn->setMetadata(LLVMContext::MD_noalias, noscope); - } - toreturn->setDebugLoc(getNewFromOriginal(dli->getDebugLoc())); - s_idx++; - return toreturn; - }, - pidx); - - // TODO adding to cache only legal if no alias of any future writes - if (permitCache) - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = - toreturn; - assert(val->getType() == toreturn->getType()); - return toreturn; - } - goto endCheck; - } - assert(phi->getNumIncomingValues() != 0); - - // If requesting loop bound and are requesting the total size. - // Rather than generating a new lcssa variable, use the existing loop exact - // bound var - BasicBlock *ivctx = scope; - if (!ivctx) - ivctx = BuilderM.GetInsertBlock(); - if (newFunc == ivctx->getParent() && !isOriginalBlock(*ivctx)) { - ivctx = originalForReverseBlock(*ivctx); - } - if ((ivctx == phi->getParent() || DT.dominates(phi, ivctx)) && - (!isOriginalBlock(*BuilderM.GetInsertBlock()) || - DT.dominates(phi, &*BuilderM.GetInsertPoint()))) { - LoopContext lc; - bool loopVar = false; - if (getContext(phi->getParent(), lc) && lc.var == phi) { - loopVar = true; - } else { - Value *V = nullptr; - bool legal = true; - for (auto &val : phi->incoming_values()) { - if (isa(val)) - continue; - if (V == nullptr) - V = val; - else if (V != val) { - legal = false; - break; - } - } - if (legal) { - if (auto I = dyn_cast_or_null(V)) { - if (getContext(I->getParent(), lc) && lc.var == I) { - loopVar = true; - } - } - } - } - if (loopVar) { - if (!lc.dynamic) { - Value *lim = getOp(lc.trueLimit); - if (lim) { - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = - lim; - return lim; - } - } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup && - reverseBlocks.size() > 0) { - // Must be in a reverse pass fashion for a lookup to index bound to be - // legal - assert(/*ReverseLimit*/ reverseBlocks.size() > 0); - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - lc.preheader); - Value *lim = lookupValueFromCache( - lc.var->getType(), - /*forwardPass*/ false, BuilderM, lctx, - getDynamicLoopLimit(LI.getLoopFor(lc.header)), - /*isi1*/ false, available); - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = lim; - return lim; - } - } - } - - auto parent = phi->getParent(); - - // Don't attempt to unroll a loop induction variable in other - // circumstances - auto &LLI = Logic.PPC.FAM.getResult(*parent->getParent()); - std::set prevIteration; - if (LLI.isLoopHeader(parent)) { - if (phi->getNumIncomingValues() != 2) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - auto L = LLI.getLoopFor(parent); - for (auto PH : predecessors(parent)) { - if (L->contains(PH)) - prevIteration.insert(PH); - } - if (prevIteration.size() && !legalRecompute(phi, available, &BuilderM)) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - } - for (auto &val : phi->incoming_values()) { - if (isPotentialLastLoopValue(val, parent, LLI)) { - if (unwrapMode == UnwrapMode::LegalFullUnwrap) { - llvm::errs() << " module: " << *newFunc->getParent() << "\n"; - llvm::errs() << " newFunc: " << *newFunc << "\n"; - llvm::errs() << " parent: " << *parent << "\n"; - llvm::errs() << " val: " << *val << "\n"; - } - assert(unwrapMode != 
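// Sketch for illustration (not from the deleted file): when the phi being
// unwrapped is the loop induction variable itself, the code above substitutes
// the loop's exit value: the statically known limit when the trip count is
// fixed, or a value reloaded from the tape/cache when it is dynamic. Enzyme's
// LoopContext and lookupValueFromCache are internal, so this sketch uses a
// hypothetical descriptor with just those two fields and takes the cache
// reload as a callback:
#include "llvm/IR/Value.h"
#include <functional>

namespace sketch {
struct LoopBound {
  bool dynamic;           // trip count only known at runtime?
  llvm::Value *trueLimit; // last induction value, if statically recoverable
};

llvm::Value *
reverseLoopIndexBound(const LoopBound &LB,
                      const std::function<llvm::Value *()> &loadCachedLimit) {
  // Static bound: reuse the forward-pass limit directly.
  if (!LB.dynamic)
    return LB.trueLimit;
  // Dynamic bound: the forward pass stored the observed trip count on the
  // tape; the reverse pass reads it back instead of re-running the loop.
  return loadCachedLimit();
}
} // namespace sketch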
UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - } - - if (phi->getNumIncomingValues() == 1) { - assert(phi->getIncomingValue(0) != phi); - auto toreturn = getOpUnchecked(phi->getIncomingValue(0)); - if (toreturn == nullptr || toreturn == phi) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - assert(val->getType() == toreturn->getType()); - return toreturn; - } - - std::set targetToPreds; - // Map of function edges to list of values possible - std::map, - std::set> - done; - { - std::deque, - BasicBlock *>> - Q; // newblock, target - - for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) { - Q.push_back( - std::make_pair(std::make_pair(phi->getIncomingBlock(i), parent), - phi->getIncomingBlock(i))); - targetToPreds.insert(phi->getIncomingBlock(i)); - } - - for (std::tuple< - std::pair, - BasicBlock *> - trace; - Q.size() > 0;) { - trace = Q.front(); - Q.pop_front(); - auto edge = std::get<0>(trace); - auto block = edge.first; - auto target = std::get<1>(trace); - - if (done[edge].count(target)) - continue; - done[edge].insert(target); - - if (DT.dominates(block, phi->getParent())) - continue; - - Loop *blockLoop = LI.getLoopFor(block); - - for (BasicBlock *Pred : predecessors(block)) { - // Don't go up the backedge as we can use the last value if desired - // via lcssa - if (blockLoop && blockLoop->getHeader() == block && - blockLoop == LI.getLoopFor(Pred)) - continue; - - Q.push_back( - std::tuple, BasicBlock *>( - std::make_pair(Pred, block), target)); - } - } - } - - std::set blocks; - for (auto pair : done) { - const auto &edge = pair.first; - blocks.insert(edge.first); - } - - BasicBlock *oldB = BuilderM.GetInsertBlock(); - if (BuilderM.GetInsertPoint() != oldB->end()) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - - BasicBlock *fwd = oldB; - bool inReverseBlocks = false; - if (!isOriginalBlock(*fwd)) { - auto found = reverseBlockToPrimal.find(oldB); - if (found == reverseBlockToPrimal.end()) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - fwd = found->second; - inReverseBlocks = - std::find(reverseBlocks[fwd].begin(), reverseBlocks[fwd].end(), - oldB) != reverseBlocks[fwd].end(); - } - - auto eraseBlocks = [&](ArrayRef blocks, BasicBlock *bret) { - SmallVector revtopo; - { - SmallPtrSet seen; - std::function dfs = [&](BasicBlock *B) { - if (seen.count(B)) - return; - seen.insert(B); - if (B->getTerminator()) - for (auto S : successors(B)) - if (!seen.count(S)) - dfs(S); - revtopo.push_back(B); - }; - for (auto B : blocks) - dfs(B); - if (!seen.count(bret)) - revtopo.insert(revtopo.begin(), bret); - } - - SmallVector toErase; - for (auto B : revtopo) { - if (B == bret) - continue; - for (auto &I : llvm::reverse(*B)) { - toErase.push_back(&I); - } - unwrap_cache.erase(B); - lookup_cache.erase(B); - if (reverseBlocks.size() > 0) { - auto tfwd = reverseBlockToPrimal[B]; - assert(tfwd); - auto rfound = reverseBlocks.find(tfwd); - assert(rfound != reverseBlocks.end()); - auto &tlst = rfound->second; - auto found = std::find(tlst.begin(), tlst.end(), B); - if (found != tlst.end()) - tlst.erase(found); - reverseBlockToPrimal.erase(B); - } - } - for (auto I : toErase) { - erase(I); - } - for (auto B : revtopo) - B->eraseFromParent(); - }; - - if (targetToPreds.size() == 3) { - for (auto block : blocks) { - if (!DT.dominates(block, phi->getParent())) - continue; - std::set foundtargets; - std::set uniqueTargets; - for (BasicBlock *succ : successors(block)) { - auto edge = std::make_pair(block, 
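// Sketch for illustration (not from the deleted file): the worklist above
// walks backwards from each phi predecessor and records, for every CFG edge,
// which incoming blocks of the phi are reachable through that edge; a later
// step then searches for a branch whose successors separate those incoming
// blocks. This sketch builds only the reachability map and omits the
// back-edge and domination pruning the real code performs:
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Instructions.h"
#include <deque>
#include <map>
#include <set>
#include <utility>

static std::map<std::pair<llvm::BasicBlock *, llvm::BasicBlock *>,
                std::set<llvm::BasicBlock *>>
edgeToReachablePhiPreds(llvm::PHINode *Phi) {
  using namespace llvm;
  using Edge = std::pair<BasicBlock *, BasicBlock *>;
  std::map<Edge, std::set<BasicBlock *>> Done;
  std::deque<std::pair<Edge, BasicBlock *>> Q; // (edge, phi incoming block)
  BasicBlock *Parent = Phi->getParent();
  for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
    BasicBlock *In = Phi->getIncomingBlock(i);
    Q.push_back({{In, Parent}, In});
  }
  while (!Q.empty()) {
    auto [EdgePair, Target] = Q.front();
    Q.pop_front();
    if (!Done[EdgePair].insert(Target).second)
      continue; // this (edge, target) pair was already recorded
    for (BasicBlock *Pred : predecessors(EdgePair.first))
      Q.push_back({{Pred, EdgePair.first}, Target});
  }
  return Done;
}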
succ); - for (BasicBlock *target : done[edge]) { - if (foundtargets.find(target) != foundtargets.end()) { - goto rnextpair; - } - foundtargets.insert(target); - if (done[edge].size() == 1) - uniqueTargets.insert(target); - } - } - if (foundtargets.size() != 3) - goto rnextpair; - if (uniqueTargets.size() != 1) - goto rnextpair; - - { - BasicBlock *subblock = nullptr; - for (auto block2 : blocks) { - { - // The second split block must not have a parent with an edge - // to a block other than to itself, which can reach any of its - // two targets. - // TODO verify this - for (auto P : predecessors(block2)) { - for (auto S : successors(P)) { - if (S == block2) - continue; - auto edge = std::make_pair(P, S); - if (done.find(edge) != done.end()) { - for (auto target : done[edge]) { - if (foundtargets.find(target) != foundtargets.end() && - uniqueTargets.find(target) == uniqueTargets.end()) - goto nextblock; - } - } - } - } - std::set seen2; - for (BasicBlock *succ : successors(block2)) { - auto edge = std::make_pair(block2, succ); - if (done[edge].size() != 1) { - // llvm::errs() << " -- failed from noonesize\n"; - goto nextblock; - } - for (BasicBlock *target : done[edge]) { - if (seen2.find(target) != seen2.end()) { - // llvm::errs() << " -- failed from not uniqueTargets\n"; - goto nextblock; - } - seen2.insert(target); - if (foundtargets.find(target) == foundtargets.end()) { - // llvm::errs() << " -- failed from not unknown target\n"; - goto nextblock; - } - if (uniqueTargets.find(target) != uniqueTargets.end()) { - // llvm::errs() << " -- failed from not same target\n"; - goto nextblock; - } - } - } - if (seen2.size() != 2) { - // llvm::errs() << " -- failed from not 2 seen\n"; - goto nextblock; - } - subblock = block2; - break; - } - nextblock:; - } - - if (subblock == nullptr) - goto rnextpair; - - { - auto bi1 = cast(block->getTerminator()); - - auto cond1 = getOp(bi1->getCondition()); - if (cond1 == nullptr) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - auto bi2 = cast(subblock->getTerminator()); - auto cond2 = getOp(bi2->getCondition()); - if (cond2 == nullptr) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - - SmallVector predBlocks = {bi2->getSuccessor(0), - bi2->getSuccessor(1)}; - for (int i = 0; i < 2; i++) { - auto edge = std::make_pair(block, bi1->getSuccessor(i)); - if (done[edge].size() == 1) { - predBlocks.push_back(bi1->getSuccessor(i)); - } - } - - SmallVector vals; - - SmallVector blocks; - SmallVector endingBlocks; - - BasicBlock *last = oldB; - - BasicBlock *bret = BasicBlock::Create( - val->getContext(), oldB->getName() + "_phimerge", newFunc); - - for (size_t i = 0; i < predBlocks.size(); i++) { - BasicBlock *valparent = (i < 2) ? 
subblock : block; - assert(done.find(std::make_pair(valparent, predBlocks[i])) != - done.end()); - assert(done[std::make_pair(valparent, predBlocks[i])].size() == - 1); - blocks.push_back(BasicBlock::Create( - val->getContext(), oldB->getName() + "_phirc", newFunc)); - blocks[i]->moveAfter(last); - last = blocks[i]; - if (inReverseBlocks) - reverseBlocks[fwd].push_back(blocks[i]); - reverseBlockToPrimal[blocks[i]] = fwd; - IRBuilder<> B(blocks[i]); - - for (auto pair : unwrap_cache[oldB]) - unwrap_cache[blocks[i]].insert(pair); - for (auto pair : lookup_cache[oldB]) - lookup_cache[blocks[i]].insert(pair); - auto PB = *done[std::make_pair(valparent, predBlocks[i])].begin(); - - if (auto inst = dyn_cast( - phi->getIncomingValueForBlock(PB))) { - // Recompute the phi computation with the conditional if: - // 1) the instruction may read from memory AND does not - // dominate the current insertion point (thereby - // potentially making such recomputation without the - // condition illegal) - // 2) the value is a call or load and option is set to not - // speculatively recompute values within a phi - // OR - // 3) the value comes from a previous iteration. - BasicBlock *nextScope = PB; - // if (inst->getParent() == nextScope) nextScope = - // phi->getParent(); - if (prevIteration.count(PB)) { - assert(0 && "tri block prev iteration unhandled"); - } else if (!DT.dominates(inst->getParent(), phi->getParent()) || - (!EnzymeSpeculatePHIs && - (isa(inst) || isa(inst)))) - vals.push_back(getOpFull(B, inst, nextScope)); - else - vals.push_back(getOpFull(BuilderM, inst, nextScope)); - } else - vals.push_back( - getOpFull(BuilderM, phi->getIncomingValueForBlock(PB), PB)); - - if (!vals[i]) { - eraseBlocks(blocks, bret); - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - assert(val->getType() == vals[i]->getType()); - B.CreateBr(bret); - endingBlocks.push_back(B.GetInsertBlock()); - } - - bret->moveAfter(last); - - BasicBlock *bsplit = BasicBlock::Create( - val->getContext(), oldB->getName() + "_phisplt", newFunc); - bsplit->moveAfter(oldB); - if (inReverseBlocks) - reverseBlocks[fwd].push_back(bsplit); - reverseBlockToPrimal[bsplit] = fwd; - BuilderM.CreateCondBr( - cond1, - (done[std::make_pair(block, bi1->getSuccessor(0))].size() == 1) - ? blocks[2] - : bsplit, - (done[std::make_pair(block, bi1->getSuccessor(1))].size() == 1) - ? 
blocks[2] - : bsplit); - - BuilderM.SetInsertPoint(bsplit); - BuilderM.CreateCondBr(cond2, blocks[0], blocks[1]); - - BuilderM.SetInsertPoint(bret); - if (inReverseBlocks) - reverseBlocks[fwd].push_back(bret); - reverseBlockToPrimal[bret] = fwd; - auto toret = BuilderM.CreatePHI(val->getType(), vals.size()); - for (size_t i = 0; i < vals.size(); i++) - toret->addIncoming(vals[i], endingBlocks[i]); - assert(val->getType() == toret->getType()); - if (permitCache) { - unwrap_cache[bret][idx.first][idx.second] = toret; - } - unwrappedLoads[toret] = val; - for (auto pair : unwrap_cache[oldB]) - unwrap_cache[bret].insert(pair); - for (auto pair : lookup_cache[oldB]) - lookup_cache[bret].insert(pair); - return toret; - } - } - rnextpair:; - } - } - - Instruction *equivalentTerminator = nullptr; - - if (prevIteration.size() == 1) { - if (phi->getNumIncomingValues() == 2) { - - ValueToValueMapTy prevAvailable; - for (const auto &pair : available) - prevAvailable.insert(pair); - LoopContext ctx; - getContext(parent, ctx); - Value *prevIdx; - if (prevAvailable.count(ctx.var)) - prevIdx = prevAvailable[ctx.var]; - else { - if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { - // If we are using the phi in the reverse pass of a block inside the - // loop itself the previous index variable (aka the previous inc) is - // equivalent to the current load of antivaralloc - if (LI.getLoopFor(ctx.header)->contains(fwd)) { - prevIdx = - BuilderM.CreateLoad(ctx.var->getType(), ctx.antivaralloc); - } else { - // However, if we are using the phi of the reverse pass of a block - // outside the loop we must be in the reverse pass of a block - // after the loop. In which case, the previous index variable (aka - // previous inc) is the total loop iteration count-1, aka the - // trueLimit. - Value *lim = nullptr; - if (ctx.dynamic) { - // Must be in a reverse pass fashion for a lookup to index bound - // to be legal - assert(/*ReverseLimit*/ reverseBlocks.size() > 0); - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - ctx.preheader); - lim = lookupValueFromCache( - ctx.var->getType(), - /*forwardPass*/ false, BuilderM, lctx, - getDynamicLoopLimit(LI.getLoopFor(ctx.header)), - /*isi1*/ false, /*available*/ prevAvailable); - } else { - lim = lookupM(ctx.trueLimit, BuilderM, prevAvailable); - } - prevIdx = lim; - } - } else { - prevIdx = ctx.var; - } - } - // Prevent recursive unroll. 
- prevAvailable[phi] = nullptr; - SmallVector vals; - - SmallVector blocks; - SmallVector endingBlocks; - BasicBlock *last = oldB; - - BasicBlock *bret = BasicBlock::Create( - val->getContext(), oldB->getName() + "_phimerge", newFunc); - - SmallVector preds(predecessors(phi->getParent())); - - for (auto tup : llvm::enumerate(preds)) { - auto i = tup.index(); - BasicBlock *PB = tup.value(); - blocks.push_back(BasicBlock::Create( - val->getContext(), oldB->getName() + "_phirc", newFunc)); - blocks[i]->moveAfter(last); - last = blocks[i]; - if (reverseBlocks.size() > 0) { - if (inReverseBlocks) - reverseBlocks[fwd].push_back(blocks[i]); - reverseBlockToPrimal[blocks[i]] = fwd; - } - IRBuilder<> B(blocks[i]); - - if (!prevIteration.count(PB)) { - for (auto pair : unwrap_cache[oldB]) - unwrap_cache[blocks[i]].insert(pair); - for (auto pair : lookup_cache[oldB]) - lookup_cache[blocks[i]].insert(pair); - } - - if (auto inst = - dyn_cast(phi->getIncomingValueForBlock(PB))) { - // Recompute the phi computation with the conditional if: - // 1) the instruction may read from memory AND does not dominate - // the current insertion point (thereby potentially making such - // recomputation without the condition illegal) - // 2) the value is a call or load and option is set to not - // speculatively recompute values within a phi - // OR - // 3) the value comes from a previous iteration. - BasicBlock *nextScope = PB; - // if (inst->getParent() == nextScope) nextScope = phi->getParent(); - if (prevIteration.count(PB)) { - prevAvailable[ctx.incvar] = prevIdx; - prevAvailable[ctx.var] = - B.CreateSub(prevIdx, ConstantInt::get(prevIdx->getType(), 1), - "", /*NUW*/ true, /*NSW*/ false); - Value *___res; - if (unwrapMode == UnwrapMode::LegalFullUnwrap || - unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace || - unwrapMode == UnwrapMode::AttemptFullUnwrap || - unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { - ___res = unwrapM(inst, B, prevAvailable, unwrapMode, nextScope, - /*permitCache*/ false); - if (!___res && - unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { - bool noLookup = false; - if (isOriginalBlock(*B.GetInsertBlock())) { - if (!DT.dominates(inst, &*B.GetInsertPoint())) - noLookup = true; - } - if (!noLookup) { - BasicBlock *nS2 = nextScope; - Value *v = inst; - ___res = lookupM(v, B, prevAvailable, v != val, nS2); - } - } - if (___res) - assert(___res->getType() == inst->getType() && "uw"); - } else { - BasicBlock *nS2 = nextScope; - Value *v = inst; - ___res = lookupM(v, B, prevAvailable, v != val, nS2); - if (___res && ___res->getType() != v->getType()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; - } - if (___res) - assert(___res->getType() == inst->getType() && "lu"); - } - vals.push_back(___res); - } else if (!DT.dominates(inst->getParent(), phi->getParent()) || - (!EnzymeSpeculatePHIs && - (isa(inst) || isa(inst)))) - vals.push_back(getOpFull(B, inst, nextScope)); - else - vals.push_back(getOpFull(BuilderM, inst, nextScope)); - } else - vals.push_back(phi->getIncomingValueForBlock(PB)); - - if (!vals[i]) { - eraseBlocks(blocks, bret); - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - assert(val->getType() == vals[i]->getType()); - B.CreateBr(bret); - endingBlocks.push_back(B.GetInsertBlock()); - } - - // Coming from a previous iteration is equivalent to the current - // iteration at zero. 
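          // Whichever order the predecessors appear in, the comparison below
          // is arranged so that iteration zero picks the loop-entry incoming
          // value and every later iteration picks the value rebuilt with the
          // induction variable set to prevIdx - 1.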
- Value *cond; - if (prevIteration.count(preds[0])) - cond = BuilderM.CreateICmpNE(prevIdx, - ConstantInt::get(prevIdx->getType(), 0)); - else - cond = BuilderM.CreateICmpEQ(prevIdx, - ConstantInt::get(prevIdx->getType(), 0)); - - if (blocks[0]->size() == 1 && blocks[1]->size() == 1) { - if (auto B1 = dyn_cast(blocks[0]->getTerminator())) - if (auto B2 = dyn_cast(blocks[1]->getTerminator())) - if (B1->isUnconditional() && B2->isUnconditional() && - B1->getSuccessor(0) == bret && B2->getSuccessor(0) == bret) { - eraseBlocks(blocks, bret); - Value *toret = BuilderM.CreateSelect( - cond, vals[0], vals[1], phi->getName() + "_unwrap"); - if (permitCache) { - unwrap_cache[BuilderM.GetInsertBlock()][idx.first] - [idx.second] = toret; - } - if (auto instRet = dyn_cast(toret)) { - unwrappedLoads[instRet] = val; - } - return toret; - } - } - - bret->moveAfter(last); - BuilderM.CreateCondBr(cond, blocks[0], blocks[1]); - - BuilderM.SetInsertPoint(bret); - if (inReverseBlocks) - reverseBlocks[fwd].push_back(bret); - reverseBlockToPrimal[bret] = fwd; - auto toret = BuilderM.CreatePHI(val->getType(), vals.size()); - for (size_t i = 0; i < vals.size(); i++) - toret->addIncoming(vals[i], endingBlocks[i]); - assert(val->getType() == toret->getType()); - if (permitCache) { - unwrap_cache[bret][idx.first][idx.second] = toret; - } - for (auto pair : unwrap_cache[oldB]) - unwrap_cache[bret].insert(pair); - for (auto pair : lookup_cache[oldB]) - lookup_cache[bret].insert(pair); - unwrappedLoads[toret] = val; - return toret; - } - } - if (prevIteration.size() != 0) { - llvm::errs() << "prev iteration: " << *phi << "\n"; - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - - for (auto block : blocks) { - if (!DT.dominates(block, phi->getParent())) - continue; - std::set foundtargets; - for (BasicBlock *succ : successors(block)) { - auto edge = std::make_pair(block, succ); - if (done[edge].size() != 1) { - goto nextpair; - } - BasicBlock *target = *done[edge].begin(); - if (foundtargets.find(target) != foundtargets.end()) { - goto nextpair; - } - foundtargets.insert(target); - } - if (foundtargets.size() != targetToPreds.size()) { - goto nextpair; - } - - if (DT.dominates(block, parent)) { - equivalentTerminator = block->getTerminator(); - goto fast; - } - nextpair:; - } - goto endCheck; - - fast:; - assert(equivalentTerminator); - - if (isa(equivalentTerminator) || - isa(equivalentTerminator)) { - BasicBlock *oldB = BuilderM.GetInsertBlock(); - - SmallVector predBlocks; - Value *cond = nullptr; - if (auto branch = dyn_cast(equivalentTerminator)) { - cond = branch->getCondition(); - predBlocks.push_back(branch->getSuccessor(0)); - predBlocks.push_back(branch->getSuccessor(1)); - } else { - auto SI = cast(equivalentTerminator); - cond = SI->getCondition(); - predBlocks.push_back(SI->getDefaultDest()); - for (auto scase : SI->cases()) { - predBlocks.push_back(scase.getCaseSuccessor()); - } - } - cond = getOp(cond); - if (!cond) { - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - - SmallVector vals; - - SmallVector blocks; - SmallVector endingBlocks; - - BasicBlock *last = oldB; - - assert(prevIteration.size() == 0); - - BasicBlock *bret = BasicBlock::Create( - val->getContext(), oldB->getName() + "_phimerge", newFunc); - - for (size_t i = 0; i < predBlocks.size(); i++) { - assert(done.find(std::make_pair(equivalentTerminator->getParent(), - predBlocks[i])) != done.end()); - assert(done[std::make_pair(equivalentTerminator->getParent(), - predBlocks[i])] - .size() == 
1); - BasicBlock *PB = *done[std::make_pair(equivalentTerminator->getParent(), - predBlocks[i])] - .begin(); - blocks.push_back(BasicBlock::Create( - val->getContext(), oldB->getName() + "_phirc", newFunc)); - blocks[i]->moveAfter(last); - last = blocks[i]; - if (reverseBlocks.size() > 0) { - if (inReverseBlocks) - reverseBlocks[fwd].push_back(blocks[i]); - reverseBlockToPrimal[blocks[i]] = fwd; - } - IRBuilder<> B(blocks[i]); - - for (auto pair : unwrap_cache[oldB]) - unwrap_cache[blocks[i]].insert(pair); - for (auto pair : lookup_cache[oldB]) - lookup_cache[blocks[i]].insert(pair); - - if (auto inst = - dyn_cast(phi->getIncomingValueForBlock(PB))) { - // Recompute the phi computation with the conditional if: - // 1) the instruction may reat from memory AND does not dominate - // the current insertion point (thereby potentially making such - // recomputation without the condition illegal) - // 2) the value is a call or load and option is set to not - // speculatively recompute values within a phi - // OR - // 3) the value comes from a previous iteration. - BasicBlock *nextScope = PB; - // if (inst->getParent() == nextScope) nextScope = phi->getParent(); - if (!DT.dominates(inst->getParent(), phi->getParent()) || - (!EnzymeSpeculatePHIs && - (isa(inst) || isa(inst)))) - vals.push_back(getOpFull(B, inst, nextScope)); - else - vals.push_back(getOpFull(BuilderM, inst, nextScope)); - } else - vals.push_back(phi->getIncomingValueForBlock(PB)); - - if (!vals[i]) { - eraseBlocks(blocks, bret); - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - assert(val->getType() == vals[i]->getType()); - B.CreateBr(bret); - endingBlocks.push_back(B.GetInsertBlock()); - } - - // Fast path to not make a split block if no additional instructions - // were made in the two blocks - if (isa(equivalentTerminator) && blocks[0]->size() == 1 && - blocks[1]->size() == 1) { - if (auto B1 = dyn_cast(blocks[0]->getTerminator())) - if (auto B2 = dyn_cast(blocks[1]->getTerminator())) - if (B1->isUnconditional() && B2->isUnconditional() && - B1->getSuccessor(0) == bret && B2->getSuccessor(0) == bret) { - eraseBlocks(blocks, bret); - Value *toret = BuilderM.CreateSelect(cond, vals[0], vals[1], - phi->getName() + "_unwrap"); - if (permitCache) { - unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = - toret; - } - if (auto instRet = dyn_cast(toret)) { - unwrappedLoads[instRet] = val; - } - return toret; - } - } - - if (BuilderM.GetInsertPoint() != oldB->end()) { - eraseBlocks(blocks, bret); - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - - bret->moveAfter(last); - if (isa(equivalentTerminator)) { - BuilderM.CreateCondBr(cond, blocks[0], blocks[1]); - } else { - auto SI = cast(equivalentTerminator); - auto NSI = BuilderM.CreateSwitch(cond, blocks[0], SI->getNumCases()); - size_t idx = 1; - for (auto scase : SI->cases()) { - NSI->addCase(scase.getCaseValue(), blocks[idx]); - idx++; - } - } - BuilderM.SetInsertPoint(bret); - if (inReverseBlocks) - reverseBlocks[fwd].push_back(bret); - reverseBlockToPrimal[bret] = fwd; - auto toret = BuilderM.CreatePHI(val->getType(), vals.size()); - for (size_t i = 0; i < vals.size(); i++) - toret->addIncoming(vals[i], endingBlocks[i]); - assert(val->getType() == toret->getType()); - if (permitCache) { - unwrap_cache[bret][idx.first][idx.second] = toret; - } - for (auto pair : unwrap_cache[oldB]) - unwrap_cache[bret].insert(pair); - for (auto pair : lookup_cache[oldB]) - lookup_cache[bret].insert(pair); - unwrappedLoads[toret] = val; 
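      // Illustrative shape of the rebuilt control flow for a two-way branch
      // (block and value names are simply the ones created above):
      //   oldB:           condbr cond, oldB_phirc, oldB_phirc1
      //   oldB_phirc:     <recompute incoming value 0>; br oldB_phimerge
      //   oldB_phirc1:    <recompute incoming value 1>; br oldB_phimerge
      //   oldB_phimerge:  toret = phi [vals[0], ...], [vals[1], ...]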
- return toret; - } - assert(unwrapMode != UnwrapMode::LegalFullUnwrap); - goto endCheck; - } - -endCheck: - assert(val); - if (unwrapMode == UnwrapMode::LegalFullUnwrap || - unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace || - unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { - assert(val->getName() != ""); - Value *nval = val; - if (auto opinst = dyn_cast(nval)) - if (isOriginalBlock(*BuilderM.GetInsertBlock())) { - if (!DT.dominates(opinst, &*BuilderM.GetInsertPoint())) { - if (unwrapMode != UnwrapMode::AttemptFullUnwrapWithLookup) { - llvm::errs() << " oldF: " << *oldFunc << "\n"; - llvm::errs() << " opParen: " << *opinst->getParent()->getParent() - << "\n"; - llvm::errs() << " newF: " << *newFunc << "\n"; - llvm::errs() << " - blk: " << *BuilderM.GetInsertBlock(); - llvm::errs() << " opInst: " << *opinst << " mode=" << unwrapMode - << "\n"; - } - assert(unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup); - return nullptr; - } - } - auto toreturn = lookupM(nval, BuilderM, available, - /*tryLegalRecomputeCheck*/ false, scope); - assert(val->getType() == toreturn->getType()); - return toreturn; - } - - if (auto inst = dyn_cast(val)) { - if (isOriginalBlock(*BuilderM.GetInsertBlock())) { - if (BuilderM.GetInsertBlock()->size() && - BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { - if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) { - assert(inst->getType() == val->getType()); - return inst; - } - } else { - if (DT.dominates(inst, BuilderM.GetInsertBlock())) { - assert(inst->getType() == val->getType()); - return inst; - } - } - } - assert(val->getName() != ""); - auto &warnMap = UnwrappedWarnings[inst]; - if (!warnMap.count(BuilderM.GetInsertBlock())) { - EmitWarning("NoUnwrap", *inst, "Cannot unwrap ", *val, " in ", - BuilderM.GetInsertBlock()->getName()); - warnMap.insert(BuilderM.GetInsertBlock()); - } - } - return nullptr; -} - -void GradientUtils::ensureLookupCached(Instruction *inst, bool shouldFree, - BasicBlock *scope, MDNode *TBAA) { - assert(inst); - if (scopeMap.find(inst) != scopeMap.end()) - return; - if (shouldFree) - assert(reverseBlocks.size()); - - if (scope == nullptr) - scope = inst->getParent(); - - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, scope); - - AllocaInst *cache = - createCacheForScope(lctx, inst->getType(), inst->getName(), shouldFree); - assert(cache); - Value *Val = inst; - insert_or_assign( - scopeMap, Val, - std::pair, LimitContext>(cache, lctx)); - storeInstructionInCache(lctx, inst, cache, TBAA); -} - -Value *GradientUtils::fixLCSSA(Instruction *inst, BasicBlock *forwardBlock, - bool legalInBlock) { - assert(inst->getName() != ""); - - if (auto lcssaPHI = dyn_cast(inst)) { - auto found = lcssaPHIToOrig.find(lcssaPHI); - if (found != lcssaPHIToOrig.end()) - inst = cast(found->second); - } - - if (inst->getParent() == inversionAllocs) - return inst; - - if (!isOriginalBlock(*forwardBlock)) { - forwardBlock = originalForReverseBlock(*forwardBlock); - } - - bool containsLastLoopValue = isPotentialLastLoopValue(inst, forwardBlock, LI); - - // If the instruction cannot represent a loop value, return the original - // instruction if it either is guaranteed to be available within the block, - // or it is not needed to guaranteed availability. 
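  // (For example, a value defined inside a loop but queried from a block
  //  after that loop may differ per iteration; the "!manual_lcssa" phi built
  //  below pins the value that actually reaches the querying block, with one
  //  incoming entry per predecessor.)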
- if (!containsLastLoopValue) { - if (!legalInBlock) - return inst; - if (forwardBlock == inst->getParent() || DT.dominates(inst, forwardBlock)) - return inst; - } - - // llvm::errs() << " inst: " << *inst << "\n"; - // llvm::errs() << " seen: " << *inst->getParent() << "\n"; - assert(inst->getParent() != inversionAllocs); - assert(isOriginalBlock(*inst->getParent())); - - if (lcssaFixes.find(inst) == lcssaFixes.end()) { - lcssaFixes[inst][inst->getParent()] = inst; - SmallPtrSet seen; - std::deque todo = {inst->getParent()}; - while (todo.size()) { - BasicBlock *cur = todo.front(); - todo.pop_front(); - if (seen.count(cur)) - continue; - seen.insert(cur); - for (auto Succ : successors(cur)) { - todo.push_back(Succ); - } - } - for (auto &BB : *inst->getParent()->getParent()) { - if (!seen.count(&BB) || - (inst->getParent() != &BB && DT.dominates(&BB, inst->getParent()))) { - // OrigPDT.dominates(isOriginal(inst->getParent()), - // isOriginal(&BB)))) { - lcssaFixes[inst][&BB] = UndefValue::get(inst->getType()); - } - } - } - - if (lcssaFixes[inst].find(forwardBlock) != lcssaFixes[inst].end()) { - return lcssaFixes[inst][forwardBlock]; - } - - // TODO replace forwardBlock with the first block dominated by inst, - // that dominates (or is) forwardBlock to ensuring maximum reuse - IRBuilder<> lcssa(&forwardBlock->front()); - auto lcssaPHI = - lcssa.CreatePHI(inst->getType(), 1, inst->getName() + "!manual_lcssa"); - lcssaFixes[inst][forwardBlock] = lcssaPHI; - lcssaPHIToOrig[lcssaPHI] = inst; - for (auto pred : predecessors(forwardBlock)) { - Value *val = nullptr; - if (inst->getParent() == pred || DT.dominates(inst, pred)) { - val = inst; - } - if (val == nullptr) { - val = fixLCSSA(inst, pred, /*legalInBlock*/ true); - assert(val->getType() == inst->getType()); - } - assert(val->getType() == inst->getType()); - lcssaPHI->addIncoming(val, pred); - } - - SmallPtrSet vals; - SmallVector todo(lcssaPHI->incoming_values().begin(), - lcssaPHI->incoming_values().end()); - while (todo.size()) { - Value *v = todo.back(); - todo.pop_back(); - if (v == lcssaPHI) - continue; - vals.insert(v); - } - assert(vals.size() > 0); - - if (vals.size() > 1) { - todo.append(vals.begin(), vals.end()); - vals.clear(); - while (todo.size()) { - Value *v = todo.back(); - todo.pop_back(); - - if (auto PN = dyn_cast(v)) - if (lcssaPHIToOrig.find(PN) != lcssaPHIToOrig.end()) { - v = lcssaPHIToOrig[PN]; - } - vals.insert(v); - } - } - assert(vals.size() > 0); - Value *val = nullptr; - if (vals.size() == 1) - val = *vals.begin(); - - if (val && (!legalInBlock || !isa(val) || - DT.dominates(cast(val), lcssaPHI))) { - - if (!isPotentialLastLoopValue(val, forwardBlock, LI)) { - bool nonSelfUse = false; - for (auto u : lcssaPHI->users()) { - if (u != lcssaPHI) { - nonSelfUse = true; - break; - } - } - if (!nonSelfUse) { - lcssaFixes[inst].erase(forwardBlock); - while (lcssaPHI->getNumOperands()) - lcssaPHI->removeIncomingValue(lcssaPHI->getNumOperands() - 1, false); - lcssaPHIToOrig.erase(lcssaPHI); - lcssaPHI->eraseFromParent(); - } - return val; - } - } - return lcssaPHI; -} - -Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, - int idx, bool replace) { - assert(malloc); - assert(BuilderQ.GetInsertBlock()->getParent() == newFunc); - assert(isOriginalBlock(*BuilderQ.GetInsertBlock())); - assert(!hasNoCache(malloc)); - if (mode == DerivativeMode::ReverseModeCombined) { - assert(!tape); - return malloc; - } - -#ifndef NDEBUG - if (auto CI = dyn_cast(malloc)) { - if (auto F = CI->getCalledFunction()) { - 
assert(F->getName() != "omp_get_thread_num"); - } - } -#endif - - if (malloc->getType()->isTokenTy()) { - llvm::errs() << " oldFunc: " << *oldFunc << "\n"; - llvm::errs() << " newFunc: " << *newFunc << "\n"; - llvm::errs() << " malloc: " << *malloc << "\n"; - } - assert(!malloc->getType()->isTokenTy()); - { - CountTrackedPointers T(malloc->getType()); - if (T.derived) { - llvm::errs() << " oldFunc: " << *oldFunc << "\n"; - llvm::errs() << " newFunc: " << *newFunc << "\n"; - llvm::errs() << " malloc: " << *malloc << "\n"; - } - assert(!T.derived); - } - - if (tape) { - if (idx == IndexMappingError) { - assert(malloc); - return UndefValue::get(malloc->getType()); - } - if (idx >= 0 && !tape->getType()->isStructTy()) { - llvm::errs() << "cacheForReverse incorrect tape type: " << *tape - << " idx: " << idx << "\n"; - } - assert(idx < 0 || tape->getType()->isStructTy()); - if (idx >= 0 && - (unsigned)idx >= cast(tape->getType())->getNumElements()) { - llvm::errs() << "oldFunc: " << *oldFunc << "\n"; - llvm::errs() << "newFunc: " << *newFunc << "\n"; - if (malloc) - llvm::errs() << "malloc: " << *malloc << "\n"; - llvm::errs() << "tape: " << *tape << "\n"; - llvm::errs() << "idx: " << idx << "\n"; - } - assert(idx < 0 || - (unsigned)idx < cast(tape->getType())->getNumElements()); - Value *ret = - (idx < 0) ? tape : BuilderQ.CreateExtractValue(tape, {(unsigned)idx}); - - if (ret->getType()->isEmptyTy()) { - if (auto inst = dyn_cast_or_null(malloc)) { - if (inst->getType() != ret->getType()) { - llvm::errs() << "oldFunc: " << *oldFunc << "\n"; - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "inst==malloc: " << *inst << "\n"; - llvm::errs() << "ret: " << *ret << "\n"; - } - assert(inst->getType() == ret->getType()); - if (replace) { - inst->replaceAllUsesWith(UndefValue::get(ret->getType())); - erase(inst); - } - } - Type *retType = ret->getType(); - if (replace) - if (auto ri = dyn_cast(ret)) - erase(ri); - return UndefValue::get(retType); - } - - LimitContext ctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - BuilderQ.GetInsertBlock()); - if (auto inst = dyn_cast(malloc)) - ctx = LimitContext(/*ReverseLimit*/ reverseBlocks.size() > 0, - inst->getParent()); - if (auto found = findInMap(scopeMap, malloc)) { - ctx = found->second; - } - assert(isOriginalBlock(*ctx.Block)); - - bool inLoop; - if (ctx.ForceSingleIteration) { - inLoop = true; - ctx.ForceSingleIteration = false; - } else { - LoopContext lc; - inLoop = getContext(ctx.Block, lc); - } - - if (!inLoop) { - ret->setName(malloc->getName() + "_fromtape"); - if (omp) { - Value *tid = ompThreadId(); - Value *tPtr = BuilderQ.CreateInBoundsGEP(malloc->getType(), ret, - ArrayRef(tid)); - ret = BuilderQ.CreateLoad(malloc->getType(), tPtr); - } - } else { - if (idx >= 0) - erase(cast(ret)); - IRBuilder<> entryBuilder(inversionAllocs); - entryBuilder.setFastMathFlags(getFast()); - ret = (idx < 0) ? 
tape - : entryBuilder.CreateExtractValue(tape, {(unsigned)idx}); - - assert(malloc); - - Type *innerType = nullptr; - -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (ret->getContext().supportsTypedPointers()) { -#endif - innerType = ret->getType(); - for (size_t i = 0, - limit = getSubLimits( - /*inForwardPass*/ true, nullptr, - LimitContext( - /*ReverseLimit*/ reverseBlocks.size() > 0, - BuilderQ.GetInsertBlock())) - .size(); - i < limit; ++i) { - if (!isa(innerType)) { - llvm::errs() << "mod: " - << *BuilderQ.GetInsertBlock()->getParent()->getParent() - << "\n"; - llvm::errs() << "fn: " << *BuilderQ.GetInsertBlock()->getParent() - << "\n"; - llvm::errs() << "bq insertblock: " << *BuilderQ.GetInsertBlock() - << "\n"; - llvm::errs() << "ret: " << *ret << " type: " << *ret->getType() - << "\n"; - llvm::errs() << "innerType: " << *innerType << "\n"; - if (malloc) - llvm::errs() << " malloc: " << *malloc << " i=" << i - << " / lim = " << limit << "\n"; - } - assert(isa(innerType)); - innerType = innerType->getPointerElementType(); - } -#if LLVM_VERSION_MAJOR >= 15 - } else { - if (EfficientBoolCache && malloc->getType()->isIntegerTy() && - cast(malloc->getType())->getBitWidth() == 1) - innerType = Type::getInt8Ty(malloc->getContext()); - else - innerType = malloc->getType(); - } -#endif -#else - if (EfficientBoolCache && malloc->getType()->isIntegerTy() && - cast(malloc->getType())->getBitWidth() == 1) - innerType = Type::getInt8Ty(malloc->getContext()); - else - innerType = malloc->getType(); -#endif - - if (EfficientBoolCache && malloc->getType()->isIntegerTy() && - cast(malloc->getType())->getBitWidth() == 1 && - innerType != ret->getType()) { - assert(innerType == Type::getInt8Ty(malloc->getContext())); - } else { - if (innerType != malloc->getType()) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << "innerType: " << *innerType << "\n"; - llvm::errs() << "malloc->getType(): " << *malloc->getType() << "\n"; - llvm::errs() << "ret: " << *ret << " - " << *ret->getType() << "\n"; - llvm::errs() << "malloc: " << *malloc << "\n"; - assert(0 && "illegal loop cache type"); - llvm_unreachable("illegal loop cache type"); - } - } - - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - BuilderQ.GetInsertBlock()); - AllocaInst *cache = - createCacheForScope(lctx, innerType, "mdyncache_fromtape", - ((DiffeGradientUtils *)this)->FreeMemory, false); - assert(malloc); - bool isi1 = malloc->getType()->isIntegerTy() && - cast(malloc->getType())->getBitWidth() == 1; - assert(isa(cache->getType())); -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (cache->getContext().supportsTypedPointers()) { -#endif - assert(cache->getType()->getPointerElementType() == ret->getType()); -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - entryBuilder.CreateStore(ret, cache); - - auto v = - lookupValueFromCache(innerType, /*forwardPass*/ true, BuilderQ, lctx, - cache, isi1, /*available*/ ValueToValueMapTy()); - if (malloc) { - assert(v->getType() == malloc->getType()); - } - insert_or_assign(scopeMap, v, - std::make_pair(AssertingVH(cache), ctx)); - ret = cast(v); - } - - if (malloc && !isa(malloc)) { - if (malloc->getType() != ret->getType()) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *malloc << "\n"; - llvm::errs() << *ret << "\n"; - } - assert(malloc->getType() == ret->getType()); - - if (replace) { - auto found = newToOriginalFn.find(malloc); - if (found != newToOriginalFn.end()) { - Value 
*orig = found->second; - originalToNewFn[orig] = ret; - newToOriginalFn.erase(malloc); - newToOriginalFn[ret] = orig; - } - } - - if (auto found = findInMap(scopeMap, malloc)) { - // There already exists an alloaction for this, we should fully remove - // it - if (!inLoop) { - - // Remove stores into - SmallVector stores( - scopeInstructions[found->first].begin(), - scopeInstructions[found->first].end()); - scopeInstructions.erase(found->first); - for (int i = stores.size() - 1; i >= 0; i--) { - erase(stores[i]); - } - - SmallVector users; - for (auto u : found->first->users()) { - users.push_back(u); - } - for (auto u : users) { - if (auto li = dyn_cast(u)) { - IRBuilder<> lb(li); - if (replace) { - - Value *replacewith = - (idx < 0) ? tape - : lb.CreateExtractValue(tape, {(unsigned)idx}); - if (!inLoop && omp) { - Value *tid = ompThreadId(); - Value *tPtr = lb.CreateInBoundsGEP(li->getType(), replacewith, - ArrayRef(tid)); - replacewith = lb.CreateLoad(li->getType(), tPtr); - } - if (li->getType() != replacewith->getType()) { - llvm::errs() << " oldFunc: " << *oldFunc << "\n"; - llvm::errs() << " newFunc: " << *newFunc << "\n"; - llvm::errs() << " malloc: " << *malloc << "\n"; - llvm::errs() << " li: " << *li << "\n"; - llvm::errs() << " u: " << *u << "\n"; - llvm::errs() << " replacewith: " << *replacewith - << " idx=" << idx << " - tape=" << *tape << "\n"; - } - assert(li->getType() == replacewith->getType()); - li->replaceAllUsesWith(replacewith); - } else { - auto phi = - lb.CreatePHI(li->getType(), 0, li->getName() + "_cfrphi"); - unwrappedLoads[phi] = malloc; - li->replaceAllUsesWith(phi); - } - erase(li); - } else { - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "malloc: " << *malloc << "\n"; - llvm::errs() << "scopeMap[malloc]: " << *found->first << "\n"; - llvm::errs() << "u: " << *u << "\n"; - assert(0 && "illegal use for out of loop scopeMap1"); - } - } - - { - AllocaInst *preerase = found->first; - scopeMap.erase(malloc); - erase(preerase); - } - } else { - // Remove allocations for scopealloc since it is already allocated - // by the augmented forward pass - // Remove stores into - SmallVector stores( - scopeInstructions[found->first].begin(), - scopeInstructions[found->first].end()); - scopeInstructions.erase(found->first); - scopeAllocs.erase(found->first); - for (int i = stores.size() - 1; i >= 0; i--) { - erase(stores[i]); - } - - // Remove frees - SmallVector tofree(scopeFrees[found->first].begin(), - scopeFrees[found->first].end()); - scopeFrees.erase(found->first); - for (auto freeinst : tofree) { - // This deque contains a list of operations - // we can erasing upon erasing the free (and so on). - // Since multiple operations can have the same operand, - // this deque can contain the same value multiple times. - // To remedy this we use a tracking value handle which will - // be set to null when erased. 
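        // (Erasing the free can leave its pointer argument dead, which in
        //  turn can leave the instructions that computed that pointer dead,
        //  so the loop below walks operands transitively and erases anything
        //  left with no remaining uses.)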
- std::deque ops = {freeinst->getArgOperand(0)}; - erase(freeinst); - - while (ops.size()) { - auto z = dyn_cast_or_null(ops[0]); - ops.pop_front(); - if (z && z->getNumUses() == 0 && !z->isUsedByMetadata()) { - for (unsigned i = 0; i < z->getNumOperands(); ++i) { - ops.push_back(z->getOperand(i)); - } - erase(z); - } - } - } - - // uses of the alloc - SmallVector users; - for (auto u : found->first->users()) { - users.push_back(u); - } - for (auto u : users) { - if (auto li = dyn_cast(u)) { - // even with replace off, this can be replaced - // as since we're in a loop this load is a load of cache - // not of the final value (thereby overwriting the new - // inst - IRBuilder<> lb(li); - auto replacewith = - (idx < 0) ? tape - : lb.CreateExtractValue(tape, {(unsigned)idx}); - li->replaceAllUsesWith(replacewith); - erase(li); - } else { - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "malloc: " << *malloc << "\n"; - llvm::errs() << "scopeMap[malloc]: " << *found->first << "\n"; - llvm::errs() << "u: " << *u << "\n"; - assert(0 && "illegal use for out of loop scopeMap2"); - } - } - - AllocaInst *preerase = found->first; - scopeMap.erase(malloc); - if (replace) - erase(preerase); - } - } - if (replace) - cast(malloc)->replaceAllUsesWith(ret); - ret->takeName(malloc); - if (replace) { - auto malloci = cast(malloc); - if (malloci == &*BuilderQ.GetInsertPoint()) { - BuilderQ.SetInsertPoint(malloci->getNextNode()); - } - erase(malloci); - } - } - return ret; - } else { - assert(malloc); - - assert(idx >= 0 && (unsigned)idx == addedTapeVals.size()); - - if (isa(malloc)) { - addedTapeVals.push_back(malloc); - return malloc; - } - - LimitContext ctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - BuilderQ.GetInsertBlock()); - if (auto inst = dyn_cast(malloc)) - ctx = LimitContext(/*ReverseLimit*/ reverseBlocks.size() > 0, - inst->getParent()); - if (auto found = findInMap(scopeMap, malloc)) { - ctx = found->second; - } - - bool inLoop; - - if (ctx.ForceSingleIteration) { - inLoop = true; - ctx.ForceSingleIteration = false; - } else { - LoopContext lc; - inLoop = getContext(ctx.Block, lc); - } - - if (!inLoop) { - Value *toStoreInTape = malloc; - if (omp) { - Value *numThreads = ompNumThreads(); - Value *tid = ompThreadId(); - IRBuilder<> entryBuilder(inversionAllocs); - - auto firstallocation = - CreateAllocation(entryBuilder, malloc->getType(), numThreads, - malloc->getName() + "_malloccache"); - Value *tPtr = entryBuilder.CreateInBoundsGEP( - malloc->getType(), firstallocation, ArrayRef(tid)); - if (auto inst = dyn_cast(malloc)) { - entryBuilder.SetInsertPoint(inst->getNextNode()); - } - entryBuilder.CreateStore(malloc, tPtr); - toStoreInTape = firstallocation; - } - addedTapeVals.push_back(toStoreInTape); - return malloc; - } - - ensureLookupCached( - cast(malloc), - /*shouldFree=*/reverseBlocks.size() > 0, - /*scope*/ nullptr, - cast(malloc)->getMetadata(LLVMContext::MD_tbaa)); - auto found2 = scopeMap.find(malloc); - assert(found2 != scopeMap.end()); - assert(found2->second.first); - - Value *toadd; - toadd = scopeAllocs[found2->second.first][0]; - for (auto u : toadd->users()) { - if (auto ci = dyn_cast(u)) { - toadd = ci; - break; - } - } - - // llvm::errs() << " malloc: " << *malloc << "\n"; - // llvm::errs() << " toadd: " << *toadd << "\n"; -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (toadd->getContext().supportsTypedPointers()) { -#endif - Type *innerType = toadd->getType(); - for (size_t i = 0, - limit = getSubLimits( - /*inForwardPass*/ true, nullptr, 
- LimitContext( - /*ReverseLimit*/ reverseBlocks.size() > 0, - BuilderQ.GetInsertBlock())) - .size(); - i < limit; ++i) { - innerType = innerType->getPointerElementType(); - } - if (EfficientBoolCache && malloc->getType()->isIntegerTy() && - toadd->getType() != innerType && - cast(malloc->getType())->getBitWidth() == 1) { - assert(innerType == Type::getInt8Ty(toadd->getContext())); - } else { - if (innerType != malloc->getType()) { - llvm::errs() << "oldFunc:" << *oldFunc << "\n"; - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << " toadd: " << *toadd << "\n"; - llvm::errs() << "innerType: " << *innerType << "\n"; - llvm::errs() << "malloc: " << *malloc << "\n"; - } - assert(innerType == malloc->getType()); - } -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - addedTapeVals.push_back(toadd); - return malloc; - } - llvm::errs() - << "Fell through on cacheForReverse. This should never happen.\n"; - assert(false); -} - -BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { - auto header = lc.header; - SmallPtrSet loopRematerializations; - SmallPtrSet loopReallocations; - SmallPtrSet loopShadowReallocations; - SmallPtrSet loopShadowZeroInits; - SmallPtrSet loopShadowRematerializations; - Loop *origLI = nullptr; - for (auto pair : rematerializableAllocations) { - if (pair.second.LI && - getNewFromOriginal(pair.second.LI->getHeader()) == header) { - bool rematerialized = false; - std::map Seen; - for (auto pair : knownRecomputeHeuristic) - if (!pair.second) - Seen[UsageKey(pair.first, QueryType::Primal)] = false; - - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(this, pair.first, mode, Seen, - notForAnalysis)) { - rematerialized = true; - } - if (rematerialized) { - if (auto inst = dyn_cast(pair.first)) - if (pair.second.LI->contains(inst->getParent())) { - loopReallocations.insert(inst); - } - for (auto I : pair.second.stores) - loopRematerializations.insert(I); - origLI = pair.second.LI; - } - } - } - for (auto pair : backwardsOnlyShadows) { - if (pair.second.LI && - getNewFromOriginal(pair.second.LI->getHeader()) == header) { - if (auto inst = dyn_cast(pair.first)) { - bool restoreStores = false; - if (pair.second.LI->contains(inst->getParent())) { - // TODO later make it so primalInitialize can be restored - // rather than cached from primal - if (!pair.second.primalInitialize) { - loopShadowReallocations.insert(inst); - restoreStores = true; - } - } else { - // if (pair.second.primalInitialize) { - // loopShadowZeroInits.insert(inst); - //} - restoreStores = true; - } - if (restoreStores) { - for (auto I : pair.second.stores) { - loopShadowRematerializations.insert(I); - } - } - origLI = pair.second.LI; - } - } - } - if (loopRematerializations.size() != 0 || loopReallocations.size() != 0 || - loopShadowRematerializations.size() != 0 || - loopShadowReallocations.size() != 0 || loopShadowZeroInits.size() != 0) { - auto found = rematerializedLoops_cache.find(header); - if (found != rematerializedLoops_cache.end()) { - return found->second; - } - - BasicBlock *enterB = - BasicBlock::Create(header->getContext(), "remat_enter", newFunc); - rematerializedLoops_cache[header] = enterB; - std::map origToNewForward; - for (auto B : origLI->getBlocks()) { - BasicBlock *newB = BasicBlock::Create( - B->getContext(), "remat_" + header->getName() + "_" + B->getName(), - newFunc); - origToNewForward[B] = newB; - reverseBlockToPrimal[newB] = getNewFromOriginal(B); - if (B == origLI->getHeader()) { - IRBuilder<> NB(newB); - for (auto inst 
: loopShadowZeroInits) { - auto anti = lookupM(invertPointerM(inst, NB), NB); - StringRef funcName; - SmallVector args; - if (auto orig = dyn_cast(inst)) { -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : orig->args()) -#else - for (auto &arg : orig->arg_operands()) -#endif - { - args.push_back(lookupM(getNewFromOriginal(arg), NB)); - } - funcName = getFuncNameFromCall(orig); - } else if (auto AI = dyn_cast(inst)) { - funcName = "malloc"; - Value *sz = lookupM(getNewFromOriginal(AI->getArraySize()), NB); - - auto ci = ConstantInt::get( - sz->getType(), - B->getParent() - ->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(AI->getAllocatedType()) / - 8); - sz = NB.CreateMul(sz, ci); - args.push_back(sz); - } - assert(funcName.size()); - - applyChainRule( - NB, - [&](Value *anti) { - zeroKnownAllocation(NB, anti, args, funcName, TLI, - dyn_cast(inst)); - }, - anti); - } - } - } - - ValueToValueMapTy available; - - { - IRBuilder<> NB(enterB); - NB.CreateBr(origToNewForward[origLI->getHeader()]); - } - - std::function handleLoop = [&](Loop *OL, bool subLoop) { - if (subLoop) { - auto Header = OL->getHeader(); - IRBuilder<> NB(origToNewForward[Header]); - LoopContext flc; - getContext(getNewFromOriginal(Header), flc); - - auto iv = NB.CreatePHI(flc.var->getType(), 2, "fiv"); - auto inc = NB.CreateAdd(iv, ConstantInt::get(iv->getType(), 1)); - - for (auto PH : predecessors(Header)) { - if (notForAnalysis.count(PH)) - continue; - - if (OL->contains(PH)) - iv->addIncoming(inc, origToNewForward[PH]); - else - iv->addIncoming(ConstantInt::get(iv->getType(), 0), - origToNewForward[PH]); - } - available[flc.var] = iv; - available[flc.incvar] = inc; - } - for (auto SL : OL->getSubLoops()) - handleLoop(SL, /*subLoop*/ true); - }; - handleLoop(origLI, /*subLoop*/ false); - - for (auto B : origLI->getBlocks()) { - auto newB = origToNewForward[B]; - IRBuilder<> NB(newB); - - // TODO fill available with relevant IV's surrounding and - // IV's of inner loop phi's - - for (auto &I : *B) { - // Only handle store, memset, and julia.write_barrier - if (loopRematerializations.count(&I)) { - if (auto SI = dyn_cast(&I)) { - auto ts = NB.CreateStore( - lookupM(getNewFromOriginal(SI->getValueOperand()), NB, - available), - lookupM(getNewFromOriginal(SI->getPointerOperand()), NB, - available)); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - ts->copyMetadata(*SI, ToCopy2); - ts->setAlignment(SI->getAlign()); - ts->setVolatile(SI->isVolatile()); - ts->setOrdering(SI->getOrdering()); - ts->setSyncScopeID(SI->getSyncScopeID()); - ts->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } else if (auto CI = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(CI); - if (funcName == "enzyme_zerotype") - continue; - if (funcName == "julia.write_barrier" || - funcName == "julia.write_barrier_binding" || - isa(&I) || isa(&I)) { - - // TODO - SmallVector args; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) -#else - for (auto &arg : CI->arg_operands()) -#endif - args.push_back(lookupM(getNewFromOriginal(arg), NB, available)); - - SmallVector BundleTypes(args.size(), - ValueType::Primal); - - auto Defs = getInvertedBundles(CI, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = NB.CreateCall(CI->getFunctionType(), - CI->getCalledOperand(), args, Defs); - cal->setAttributes(CI->getAttributes()); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } else { 
- assert(isDeallocationFunction(funcName, TLI)); - continue; - } - } else { - assert(0 && "unhandlable loop rematerialization instruction"); - } - } else if (loopReallocations.count(&I)) { - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - &newFunc->getEntryBlock()); - - auto inst = getNewFromOriginal((Value *)&I); - - auto found = scopeMap.find(inst); - if (found == scopeMap.end()) { - AllocaInst *cache = createCacheForScope( - lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); - assert(cache); - found = insert_or_assign( - scopeMap, inst, - std::pair, LimitContext>(cache, lctx)); - } - auto cache = found->second.first; - if (auto MD = hasMetadata(&I, "enzyme_fromstack")) { - auto replacement = NB.CreateAlloca( - Type::getInt8Ty(I.getContext()), - lookupM(getNewFromOriginal(I.getOperand(0)), NB, available)); - for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", - "enzymejl_allocart"}) - if (auto M = I.getMetadata(MD)) - replacement->setMetadata(MD, M); - auto Alignment = - cast( - cast(MD->getOperand(0))->getValue()) - ->getLimitedValue(); - replacement->setAlignment(Align(Alignment)); - replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - storeInstructionInCache(lctx, NB, replacement, cache); - } else if (auto CI = dyn_cast(&I)) { - SmallVector args; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) -#else - for (auto &arg : CI->arg_operands()) -#endif - args.push_back(lookupM(getNewFromOriginal(arg), NB, available)); - - SmallVector BundleTypes(args.size(), - ValueType::Primal); - - auto Defs = getInvertedBundles(CI, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = NB.CreateCall(CI->getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - cal->copyMetadata(*CI, ToCopy2); - cal->setName("remat_" + CI->getName()); - cal->setAttributes(CI->getAttributes()); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - storeInstructionInCache(lctx, NB, cal, cache); - } else { - llvm::errs() << " realloc: " << I << "\n"; - llvm_unreachable("Unknown loop reallocation"); - } - } - if (loopShadowRematerializations.count(&I)) { - if (auto SI = dyn_cast(&I)) { - Value *orig_ptr = SI->getPointerOperand(); - Value *orig_val = SI->getValueOperand(); - Type *valType = orig_val->getType(); - assert(!isConstantValue(orig_ptr)); - - auto &DL = newFunc->getParent()->getDataLayout(); - - bool constantval = isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); - - // TODO allow recognition of other types that could contain - // pointers [e.g. {void*, void*} or <2 x i64> ] - auto storeSize = DL.getTypeSizeInBits(valType) / 8; - - //! 
Storing a floating point value - Type *FT = nullptr; - if (valType->isFPOrFPVectorTy()) { - FT = valType->getScalarType(); - } else if (!valType->isPointerTy()) { - if (looseTypeAnalysis) { - auto fp = TR.firstPointer(storeSize, orig_ptr, &I, - /*errifnotfound*/ false, - /*pointerIntSame*/ true); - if (fp.isKnown()) { - FT = fp.isFloat(); - llvm::errs() << "assuming type as " << *FT - << " for store: " << I << "\n"; - } else if (isa(orig_val) || - valType->isIntOrIntVectorTy()) { - llvm::errs() - << "assuming type as integral for store: " << I << "\n"; - FT = nullptr; - } else { - TR.firstPointer(storeSize, orig_ptr, &I, - /*errifnotfound*/ true, - /*pointerIntSame*/ true); - llvm::errs() << "cannot deduce type of store " << I << "\n"; - assert(0 && "cannot deduce"); - } - } else { - FT = TR.firstPointer(storeSize, orig_ptr, &I, - /*errifnotfound*/ true, - /*pointerIntSame*/ true) - .isFloat(); - } - } - if (!FT) { - Value *valueop = nullptr; - if (constantval) { - Value *val = - lookupM(getNewFromOriginal(orig_val), NB, available); - valueop = val; - if (getWidth() > 1) { - Value *array = UndefValue::get(getShadowType(val->getType())); - for (unsigned i = 0; i < getWidth(); ++i) { - array = NB.CreateInsertValue(array, val, {i}); - } - valueop = array; - } - } else { - valueop = lookupM(invertPointerM(orig_val, NB), NB, available); - } - SmallVector prevScopes; - if (auto prev = SI->getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - prevScopes.push_back(M); - } - } - SmallVector prevNoAlias; - if (auto prev = SI->getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - prevNoAlias.push_back(M); - } - } - auto align = SI->getAlign(); - setPtrDiffe(SI, orig_ptr, valueop, NB, align, 0, storeSize, - SI->isVolatile(), SI->getOrdering(), - SI->getSyncScopeID(), - /*mask*/ nullptr, prevNoAlias, prevScopes); - } - // TODO shadow memtransfer - } else if (auto MS = dyn_cast(&I)) { - if (!isConstantValue(MS->getArgOperand(0))) { - Value *args[4] = { - lookupM(invertPointerM(MS->getArgOperand(0), NB), NB, - available), - lookupM(getNewFromOriginal(MS->getArgOperand(1)), NB, - available), - lookupM(getNewFromOriginal(MS->getArgOperand(2)), NB, - available), - lookupM(getNewFromOriginal(MS->getArgOperand(3)), NB, - available)}; - - ValueType BundleTypes[4] = {ValueType::Shadow, ValueType::Primal, - ValueType::Primal, ValueType::Primal}; - auto Defs = getInvertedBundles(MS, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = NB.CreateCall(MS->getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - cal->copyMetadata(*MS, ToCopy2); - cal->setAttributes(MS->getAttributes()); - cal->setCallingConv(MS->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } - } else if (auto CI = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(CI); - if (funcName == "julia.write_barrier" || - funcName == "julia.write_barrier_binding") { - - // TODO - SmallVector args; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) -#else - for (auto &arg : CI->arg_operands()) -#endif - if (!isConstantValue(arg)) - args.push_back( - lookupM(invertPointerM(arg, NB), NB, available)); - - if (args.size()) { - SmallVector BundleTypes(args.size(), - ValueType::Primal); - - auto Defs = getInvertedBundles(CI, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = NB.CreateCall(CI->getFunctionType(), - 
CI->getCalledOperand(), args, Defs); - cal->setAttributes(CI->getAttributes()); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } - } else { - assert(isDeallocationFunction(funcName, TLI)); - continue; - } - } else { - assert(0 && - "unhandlable loop shadow rematerialization instruction"); - } - } else if (loopShadowReallocations.count(&I)) { - - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - &newFunc->getEntryBlock()); - auto ipfound = invertedPointers.find(&I); - PHINode *placeholder = cast(&*ipfound->second); - - auto found = scopeMap.find(placeholder); - if (found == scopeMap.end()) { - AllocaInst *cache = createCacheForScope( - lctx, placeholder->getType(), placeholder->getName(), - /*shouldFree*/ true); - assert(cache); - Value *placeholder_tmp = placeholder; - found = insert_or_assign( - scopeMap, placeholder_tmp, - std::pair, LimitContext>(cache, lctx)); - } - auto cache = found->second.first; - Value *anti = nullptr; - - if (auto orig = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(orig); - assert(funcName.size()); - - auto dbgLoc = getNewFromOriginal(orig)->getDebugLoc(); - - SmallVector args; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : orig->args()) -#else - for (auto &arg : orig->arg_operands()) -#endif - { - args.push_back(lookupM(getNewFromOriginal(arg), NB)); - } - - placeholder->setName(""); - if (shadowHandlers.find(funcName) != shadowHandlers.end()) { - - anti = shadowHandlers[funcName](NB, orig, args, this); - } else { - auto rule = [&]() { - Value *anti = NB.CreateCall(orig->getFunctionType(), - orig->getCalledOperand(), args, - orig->getName() + "'mi"); - cast(anti)->setAttributes(orig->getAttributes()); - cast(anti)->setCallingConv(orig->getCallingConv()); - cast(anti)->setDebugLoc( - getNewFromOriginal(I.getDebugLoc())); - - cast(anti)->addAttribute(AttributeList::ReturnIndex, - Attribute::NoAlias); - cast(anti)->addAttribute(AttributeList::ReturnIndex, - Attribute::NonNull); - return anti; - }; - - anti = applyChainRule(orig->getType(), NB, rule); - - if (auto MD = hasMetadata(orig, "enzyme_fromstack")) { - auto rule = [&](Value *anti) { - AllocaInst *replacement = NB.CreateAlloca( - Type::getInt8Ty(orig->getContext()), args[0]); - for (auto MD : {"enzyme_active", "enzyme_inactive", - "enzyme_type", "enzymejl_allocart"}) - if (auto M = I.getMetadata(MD)) - replacement->setMetadata(MD, M); - replacement->takeName(anti); - auto Alignment = cast(cast( - MD->getOperand(0)) - ->getValue()) - ->getLimitedValue(); - replacement->setAlignment(Align(Alignment)); - replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - return replacement; - }; - - Value *replacement = applyChainRule( - Type::getInt8Ty(orig->getContext()), NB, rule, anti); - - replaceAWithB(cast(anti), replacement); - erase(cast(anti)); - anti = replacement; - } - - applyChainRule( - NB, - [&](Value *anti) { - zeroKnownAllocation(NB, anti, args, funcName, TLI, orig); - }, - anti); - } - } else { - llvm_unreachable("Unknown shadow rematerialization value"); - } - assert(anti); - storeInstructionInCache(lctx, NB, anti, cache); - } - } - - llvm::SmallPtrSet origExitBlocks; - getExitBlocks(origLI, origExitBlocks); - // Remap a branch to the header to enter the incremented - // reverse of that block. - auto remap = [&](BasicBlock *rB) { - // Remap of an exit branch is to go to the reverse - // exiting block. 
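      // (Three cases below: a successor that leaves the loop, or a back-edge
      //  to the header, both hand control to the reverse block of the current
      //  block, since the forward replay for this entry is done; any other
      //  successor continues in its cloned "remat_" counterpart.)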
- if (origExitBlocks.count(rB)) { - return reverseBlocks[getNewFromOriginal(B)].front(); - } - // Reverse of an incrementing branch is go to the - // reverse of the branching block. - if (rB == origLI->getHeader()) - return reverseBlocks[getNewFromOriginal(B)].front(); - auto found = origToNewForward.find(rB); - if (found == origToNewForward.end()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *origLI << "\n"; - llvm::errs() << *rB << "\n"; - } - assert(found != origToNewForward.end()); - return found->second; - }; - - // TODO clone terminator - auto TI = B->getTerminator(); - assert(TI); - if (notForAnalysis.count(B)) { - NB.CreateUnreachable(); - } else if (auto BI = dyn_cast(TI)) { - if (BI->isUnconditional()) { - if (notForAnalysis.count(BI->getSuccessor(0))) - NB.CreateUnreachable(); - else - NB.CreateBr(remap(BI->getSuccessor(0))); - } else { - if (notForAnalysis.count(BI->getSuccessor(0))) { - if (notForAnalysis.count(BI->getSuccessor(1))) { - NB.CreateUnreachable(); - } else { - NB.CreateBr(remap(BI->getSuccessor(1))); - } - } else if (notForAnalysis.count(BI->getSuccessor(1))) { - NB.CreateBr(remap(BI->getSuccessor(0))); - } else { - NB.CreateCondBr( - lookupM(getNewFromOriginal(BI->getCondition()), NB, available), - remap(BI->getSuccessor(0)), remap(BI->getSuccessor(1))); - } - } - } else if (auto SI = dyn_cast(TI)) { - BasicBlock *newDest = nullptr; - if (!notForAnalysis.count(SI->getDefaultDest())) - newDest = remap(SI->getDefaultDest()); - else { - for (auto cas : SI->cases()) { - if (!notForAnalysis.count(cas.getCaseSuccessor())) - newDest = remap(cas.getCaseSuccessor()); - break; - } - } - if (!newDest) { - NB.CreateUnreachable(); - } else { - auto NSI = NB.CreateSwitch( - lookupM(getNewFromOriginal(SI->getCondition()), NB, available), - newDest); - for (auto cas : SI->cases()) { - if (!notForAnalysis.count(cas.getCaseSuccessor())) - NSI->addCase(cas.getCaseValue(), remap(cas.getCaseSuccessor())); - } - } - } else { - assert(isa(TI)); - NB.CreateUnreachable(); - } - // Fixup phi nodes that may have their predecessors now changed by - // the phi unwrapping - if (!notForAnalysis.count(B) && - NB.GetInsertBlock() != origToNewForward[B]) { - for (auto S0 : successors(B)) { - if (!origToNewForward.count(S0)) - continue; - auto S = origToNewForward[S0]; - assert(S); - for (auto I = S->begin(), E = S->end(); I != E; ++I) { - PHINode *orig = dyn_cast(&*I); - if (orig == nullptr) - break; - for (unsigned Op = 0, NumOps = orig->getNumOperands(); Op != NumOps; - ++Op) - if (orig->getIncomingBlock(Op) == origToNewForward[B]) - orig->setIncomingBlock(Op, NB.GetInsertBlock()); - } - } - } - } - return enterB; - } - return nullptr; -} - -/// Given an edge from BB to branchingBlock get the corresponding block to -/// branch to in the reverse pass -BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB, - BasicBlock *branchingBlock) { - assert(BB); - // BB should be a forward pass block, assert that - if (reverseBlocks.find(BB) == reverseBlocks.end()) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << "BB: " << *BB << "\n"; - llvm::errs() << "branchingBlock: " << *branchingBlock << "\n"; - } - assert(reverseBlocks.find(BB) != reverseBlocks.end()); - assert(reverseBlocks.find(branchingBlock) != reverseBlocks.end()); - LoopContext lc; - bool inLoop = getContext(BB, lc); - - LoopContext branchingContext; - bool inLoopContext = getContext(branchingBlock, branchingContext); - - if (!inLoop) - return reverseBlocks[BB].front(); - - auto 
tup = std::make_tuple(BB, branchingBlock); - if (newBlocksForLoop_cache.find(tup) != newBlocksForLoop_cache.end()) - return newBlocksForLoop_cache[tup]; - - // If we're reversing a latch edge. - bool incEntering = inLoopContext && branchingBlock == lc.header && - lc.header == branchingContext.header; - - auto L = LI.getLoopFor(BB); - auto latches = getLatches(L, lc.exitBlocks); - // If we're reverseing a loop exit. - bool exitEntering = - std::find(latches.begin(), latches.end(), BB) != latches.end() && - std::find(lc.exitBlocks.begin(), lc.exitBlocks.end(), branchingBlock) != - lc.exitBlocks.end(); - - // It is illegal to be both an increment into a loop, and exiting the loop. - assert(!(incEntering && exitEntering)); - - // If we're re-entering a loop, prepare a loop-level forward pass to - // rematerialize any loop-scope rematerialization. - - if (incEntering) { - BasicBlock *resumeblock = reverseBlocks[BB].front(); - auto tmp_resumeblock = prepRematerializedLoopEntry(lc); - if (tmp_resumeblock) - resumeblock = tmp_resumeblock; - BasicBlock *incB = BasicBlock::Create( - BB->getContext(), "inc" + reverseBlocks[lc.header].front()->getName(), - BB->getParent()); - incB->moveAfter(reverseBlocks[lc.header].back()); - - IRBuilder<> tbuild(incB); - - Value *av = tbuild.CreateLoad(lc.var->getType(), lc.antivaralloc); - Value *sub = tbuild.CreateAdd(av, ConstantInt::get(av->getType(), -1), "", - /*NUW*/ false, /*NSW*/ true); - tbuild.CreateStore(sub, lc.antivaralloc); - tbuild.CreateBr(resumeblock); - return newBlocksForLoop_cache[tup] = incB; - } - - if (exitEntering) { - SmallVector exitingContexts = {lc}; - - auto L2 = L; - while ((L2 = L2->getParentLoop())) { - LoopContext lc2; - bool inLoop = getContext(L2->getHeader(), lc2); - if (!inLoop) - break; - - auto latches2 = getLatches(L2, lc2.exitBlocks); - - // If we're reverseing a loop exit. 
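      // (This walk over parent loops collects every enclosing loop that the
      //  same edge exits, so that each of their reversed induction allocas is
      //  reseeded from its loop limit below before the reverse blocks resume.)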
- bool exitEntering2 = - std::find(latches2.begin(), latches2.end(), BB) != latches2.end() && - std::find(lc2.exitBlocks.begin(), lc2.exitBlocks.end(), - branchingBlock) != lc2.exitBlocks.end(); - if (exitEntering2) { - exitingContexts.push_back(lc2); - } else - break; - } - - BasicBlock *resumeblock = reverseBlocks[BB].front(); - BasicBlock *prevBlock = reverseBlocks[branchingBlock].back(); - - BasicBlock *outerMerge = nullptr; - - BasicBlock *incB = BasicBlock::Create( - BB->getContext(), - "merge" + reverseBlocks[lc.header].front()->getName() + "_" + - branchingBlock->getName(), - BB->getParent()); - if (!outerMerge) - outerMerge = incB; - incB->moveAfter(prevBlock); - - IRBuilder<> tbuild(prevBlock); - - SmallVector, 1> lims; - ValueToValueMapTy available; - for (auto I = exitingContexts.rbegin(), E = exitingContexts.rend(); I != E; - I++) { - auto &lc = *I; - auto L = LI.getLoopFor(lc.header); - Value *lim = nullptr; - if (lc.dynamic && assumeDynamicLoopOfSizeOne(L)) { - lim = ConstantInt::get(lc.var->getType(), 0); - } else if (lc.dynamic) { - // Must be in a reverse pass fashion for a lookup to index bound to be - // legal - assert(/*ReverseLimit*/ reverseBlocks.size() > 0); - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - lc.preheader); - lim = lookupValueFromCache(lc.var->getType(), - /*forwardPass*/ false, tbuild, lctx, - getDynamicLoopLimit(L), - /*isi1*/ false, available); - } else { - lim = lookupM(lc.trueLimit, tbuild, available); - } - available[lc.var] = lim; - lims.push_back(std::make_pair(lim, (Value *)lc.antivaralloc)); - } - - tbuild.SetInsertPoint(incB); - for (auto &pair : lims) { - tbuild.CreateStore(pair.first, pair.second); - } - - auto tmp_resumeblock = prepRematerializedLoopEntry(exitingContexts.back()); - if (tmp_resumeblock) - resumeblock = tmp_resumeblock; - - tbuild.CreateBr(resumeblock); - - return newBlocksForLoop_cache[tup] = incB; - } - - return newBlocksForLoop_cache[tup] = reverseBlocks[BB].front(); -} - -void GradientUtils::forceContexts() { - for (auto BB : originalBlocks) { - LoopContext lc; - getContext(BB, lc); - } -} - -bool GradientUtils::legalRecompute(const Value *val, - const ValueToValueMapTy &available, - IRBuilder<> *BuilderM, bool reverse, - bool legalRecomputeCache) const { - { - auto found = available.find(val); - if (found != available.end()) { - if (found->second) - return true; - else { - return false; - } - } - } - - if (isa(val)) - return false; - - if (auto phi = dyn_cast(val)) { - if (auto uiv = hasUninverted(val)) { - if (auto dli = dyn_cast_or_null(uiv)) { - return legalRecompute( - dli, available, BuilderM, - reverse); // TODO ADD && !TR.intType(getOriginal(dli), - // /*mustfind*/false).isPossibleFloat(); - } - if (auto ci = dyn_cast(uiv)) { - auto called = ci->getCalledFunction(); - if (ci->hasFnAttr("enzyme_shouldrecompute") || - (called && called->hasFnAttribute("enzyme_shouldrecompute"))) - return true; - } - if (phi->getNumIncomingValues() == 0) { - return false; - } - } - - auto found = fictiousPHIs.find(const_cast(phi)); - if (found != fictiousPHIs.end()) { - auto orig = found->second; - if (isa(orig)) - return false; - } - - if (phi->getNumIncomingValues() == 0) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *phi << "\n"; - } - assert(phi->getNumIncomingValues() != 0); - auto parent = phi->getParent(); - struct { - Function *func; - const LoopInfo &FLI; - } options[2] = {{newFunc, LI}, {oldFunc, *OrigLI}}; - for (const auto &tup : options) { - if 
(parent->getParent() == tup.func) { - for (auto &val : phi->incoming_values()) { - if (isPotentialLastLoopValue(val, parent, tup.FLI)) { - return false; - } - } - if (tup.FLI.isLoopHeader(parent)) { - // Currently can only recompute header - // with two incoming values - if (phi->getNumIncomingValues() != 2) - return false; - auto L = tup.FLI.getLoopFor(parent); - - // Only recomputable if non recursive. - SmallPtrSet seen; - SmallVector todo; - for (auto PH : predecessors(parent)) { - // Prior iterations must be recomputable without - // this value. - if (L->contains(PH)) { - if (auto I = - dyn_cast(phi->getIncomingValueForBlock(PH))) - if (L->contains(I->getParent())) - todo.push_back(I); - } - } - - while (todo.size()) { - auto cur = todo.back(); - todo.pop_back(); - if (seen.count(cur)) - continue; - seen.insert(cur); - if (cur == phi) - return false; - for (auto &op : cur->operands()) { - if (auto I = dyn_cast(op)) { - if (L->contains(I->getParent())) - todo.push_back(I); - } - } - } - } - return true; - } - } - return false; - } - - if (isa(val) && - cast(val)->getMetadata("enzyme_mustcache")) { - return false; - } - - // If this is a load from cache already, dont force a cache of this - if (legalRecomputeCache && isa(val) && - CacheLookups.count(cast(val))) { - return true; - } - - // TODO consider callinst here - - if (auto li = dyn_cast(val)) { - - const IntrinsicInst *II; - if (isa(li) || isNVLoad(li) || - ((II = dyn_cast(li)) && - (II->getIntrinsicID() == Intrinsic::masked_load))) { - // If this is an already unwrapped value, legal to recompute again. - if (unwrappedLoads.find(li) != unwrappedLoads.end()) - return legalRecompute(unwrappedLoads.find(li)->second, available, - BuilderM, reverse); - - const Instruction *orig = nullptr; - if (li->getParent()->getParent() == oldFunc) { - orig = li; - } else if (li->getParent()->getParent() == newFunc) { - orig = isOriginal(li); - // todo consider when we pass non original queries - if (orig && !isa(orig)) { - return legalRecompute(orig, available, BuilderM, reverse, - legalRecomputeCache); - } - } else { - llvm::errs() << " newFunc: " << *newFunc << "\n"; - llvm::errs() << " parent: " << *li->getParent()->getParent() << "\n"; - llvm::errs() << " li: " << *li << "\n"; - assert(0 && "illegal load legalRecopmute query"); - } - - if (orig) { - assert(can_modref_map); - auto found = can_modref_map->find(const_cast(orig)); - if (found == can_modref_map->end()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << "can_modref_map:\n"; - for (auto &pair : *can_modref_map) { - llvm::errs() << " + " << *pair.first << ": " << pair.second - << " of func " - << pair.first->getParent()->getParent()->getName() - << "\n"; - } - llvm::errs() << "couldn't find in can_modref_map: " << *li << " - " - << *orig << " in fn: " - << orig->getParent()->getParent()->getName(); - } - assert(found != can_modref_map->end()); - if (!found->second) - return true; - // if insertion block of this function: - BasicBlock *fwdBlockIfReverse = nullptr; - if (BuilderM) { - fwdBlockIfReverse = BuilderM->GetInsertBlock(); - if (!reverse) { - auto found = reverseBlockToPrimal.find(BuilderM->GetInsertBlock()); - if (found != reverseBlockToPrimal.end()) { - fwdBlockIfReverse = found->second; - reverse = true; - } - } - if (fwdBlockIfReverse->getParent() != oldFunc) - fwdBlockIfReverse = - cast_or_null(isOriginal(fwdBlockIfReverse)); - } - if (mode == DerivativeMode::ReverseModeCombined && fwdBlockIfReverse) { - if (reverse) { - bool failed 
= false; - allFollowersOf( - const_cast(orig), [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy( - &TR, *OrigAA, TLI, - /*maybeReader*/ const_cast(orig), - /*maybeWriter*/ I)) { - failed = true; - EmitWarning( - "UncacheableLoad", *orig, "Load must be recomputed ", - *orig, " in reverse_", - BuilderM->GetInsertBlock()->getName(), " due to ", *I); - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (!failed) - return true; - } else { - Instruction *origStart = &*BuilderM->GetInsertPoint(); - do { - if (Instruction *og = isOriginal(origStart)) { - origStart = og; - break; - } - origStart = origStart->getNextNode(); - } while (true); - if (OrigDT->dominates(origStart, const_cast(orig))) { - bool failed = false; - - allInstructionsBetween( - const_cast(this)->LI, origStart, - const_cast(orig), [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy( - &TR, *OrigAA, TLI, - /*maybeReader*/ const_cast(orig), - /*maybeWriter*/ I)) { - failed = true; - EmitWarning("UncacheableLoad", *orig, - "Load must be recomputed ", *orig, " in ", - BuilderM->GetInsertBlock()->getName(), - " due to ", *I); - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (!failed) - return true; - } - } - } - return false; - } else { - if (auto dli = dyn_cast_or_null(hasUninverted(li))) { - return legalRecompute(dli, available, BuilderM, reverse); - } - - // TODO mark all the explicitly legal nodes (caches, etc) - return true; - llvm::errs() << *li << " orig: " << orig - << " parent: " << li->getParent()->getParent()->getName() - << "\n"; - llvm_unreachable("unknown load to redo!"); - } - } - } - - if (auto ci = dyn_cast(val)) { - auto n = getFuncNameFromCall(const_cast(ci)); - auto called = ci->getCalledFunction(); - Intrinsic::ID ID = Intrinsic::not_intrinsic; - - if (ci->hasFnAttr("enzyme_shouldrecompute") || - (called && called->hasFnAttribute("enzyme_shouldrecompute")) || - isMemFreeLibMFunction(n, &ID) || n == "lgamma_r" || n == "lgammaf_r" || - n == "lgammal_r" || n == "__lgamma_r_finite" || - n == "__lgammaf_r_finite" || n == "__lgammal_r_finite" || n == "tanh" || - n == "tanhf" || n == "__pow_finite" || - n == "julia.pointer_from_objref" || startsWith(n, "enzyme_wrapmpi$$") || - n == "omp_get_thread_num" || n == "omp_get_max_threads") { - return true; - } -#if LLVM_VERSION_MAJOR >= 14 - if (ci->doesNotAccessMemory()) -#else - if (ci->hasFnAttr(Attribute::ReadNone) || - (called && called->hasFnAttribute(Attribute::ReadNone))) -#endif - return true; - if (isPointerArithmeticInst(ci)) - return true; - } - - if (auto inst = dyn_cast(val)) { - if (inst->mayReadOrWriteMemory()) { - return false; - } - } - - return true; -} - -//! Given the option to recompute a value or re-use an old one, return true if -//! it is faster to recompute this value from scratch -bool GradientUtils::shouldRecompute(const Value *val, - const ValueToValueMapTy &available, - IRBuilder<> *BuilderM) { - if (available.count(val)) - return true; - // TODO: remake such that this returns whether a load to a cache is more - // expensive than redoing the computation. 
- - // If this is a load from cache already, just reload this - if (isa(val) && - cast(val)->getMetadata("enzyme_fromcache")) - return true; - - if (!isa(val)) - return true; - - const Instruction *inst = cast(val); - - if (TapesToPreventRecomputation.count(inst)) - return false; - - if (knownRecomputeHeuristic.find(inst) != knownRecomputeHeuristic.end()) { - return knownRecomputeHeuristic[inst]; - } - if (auto OrigInst = isOriginal(inst)) { - if (knownRecomputeHeuristic.find(OrigInst) != - knownRecomputeHeuristic.end()) { - return knownRecomputeHeuristic[OrigInst]; - } - } - - if (isa(val) || isa(val)) - return true; - - if (EnzymeNewCache && !EnzymeMinCutCache) { - // if this has operands that need to be loaded and haven't already been - // loaded - // TODO, just cache this - for (auto &op : inst->operands()) { - if (!legalRecompute(op, available, BuilderM)) { - - // If this is a load from cache already, dont force a cache of this - if (isa(op) && CacheLookups.count(cast(op))) - continue; - - // If a previously cached this operand, don't let it trigger the - // heuristic for caching this value instead. - if (scopeMap.find(op) != scopeMap.end()) - continue; - - // If the actually overwritten operand is in a different loop scope - // don't cache this value instead as it may require more memory - LoopContext lc1; - LoopContext lc2; - bool inLoop1 = - getContext(const_cast(inst)->getParent(), lc1); - bool inLoop2 = getContext(cast(op)->getParent(), lc2); - if (inLoop1 != inLoop2 || (inLoop1 && (lc1.header != lc2.header))) { - continue; - } - - // If a placeholder phi for inversion (and we know from above not - // recomputable) - if (!isa(op) && - dyn_cast_or_null(hasUninverted(op))) { - goto forceCache; - } - - // Even if cannot recompute (say a phi node), don't force a reload if it - // is possible to just use this instruction from forward pass without - // issue - if (auto i2 = dyn_cast(op)) { - if (!i2->mayReadOrWriteMemory()) { - LoopContext lc; - bool inLoop = const_cast(this)->getContext( - i2->getParent(), lc); - if (!inLoop) { - // TODO upgrade this to be all returns that this could enter from - BasicBlock *orig = isOriginal(i2->getParent()); - assert(orig); - bool legal = BlocksDominatingAllReturns.count(orig); - if (legal) { - continue; - } - } - } - } - forceCache:; - EmitWarning("ChosenCache", *inst, "Choosing to cache use ", *inst, - " due to ", *op); - return false; - } - } - } - - if (auto op = dyn_cast(val)) { - if (!op->mayReadOrWriteMemory() || isReadNone(op) || isNVLoad(op)) - return true; - switch (op->getIntrinsicID()) { - case Intrinsic::sin: - case Intrinsic::cos: - case Intrinsic::exp: -#if LLVM_VERSION_MAJOR >= 19 - case Intrinsic::tanh: - case Intrinsic::cosh: - case Intrinsic::sinh: -#endif - case Intrinsic::log: - return true; - default: - return false; - } - } - - if (auto ci = dyn_cast(val)) { - auto called = ci->getCalledFunction(); - auto n = getFuncNameFromCall(const_cast(ci)); - Intrinsic::ID ID = Intrinsic::not_intrinsic; - if ((called && called->hasFnAttribute("enzyme_shouldrecompute")) || - isMemFreeLibMFunction(n, &ID) || n == "lgamma_r" || n == "lgammaf_r" || - n == "lgammal_r" || n == "__lgamma_r_finite" || - n == "__lgammaf_r_finite" || n == "__lgammal_r_finite" || n == "tanh" || - n == "tanhf" || n == "__pow_finite" || - n == "julia.pointer_from_objref" || startsWith(n, "enzyme_wrapmpi$$") || - n == "omp_get_thread_num" || n == "omp_get_max_threads" || - startsWith(n, "_ZN4libm4math3log")) { - return true; - } - if (isPointerArithmeticInst(ci)) - 
return true; - } - - // cache a call, assuming its longer to run that - if (isa(val)) { - llvm::errs() << " caching call: " << *val << "\n"; - // cast(val)->getCalledFunction()->dump(); - return false; - } - - return true; -} - -MDNode *GradientUtils::getDerivativeAliasScope(const Value *origptr, - ssize_t newptr) { - origptr = getBaseObject(origptr); - - auto found = differentialAliasScopeDomains.find(origptr); - if (found == differentialAliasScopeDomains.end()) { - MDBuilder MDB(oldFunc->getContext()); - MDNode *scope = MDB.createAnonymousAliasScopeDomain( - (" diff: %" + origptr->getName()).str()); - // vec.first = scope; - // found = differentialAliasScope.find(origptr); - found = differentialAliasScopeDomains.insert(std::make_pair(origptr, scope)) - .first; - } - auto &mp = differentialAliasScope[origptr]; - auto found2 = mp.find(newptr); - if (found2 == mp.end()) { - MDBuilder MDB(oldFunc->getContext()); - std::string name; - if (newptr == -1) - name = "primal"; - else - name = "shadow_" + std::to_string(newptr); - found2 = mp.insert(std::make_pair(newptr, MDB.createAnonymousAliasScope( - found->second, name))) - .first; - } - return found2->second; -} - -GradientUtils *GradientUtils::CreateFromClone( - EnzymeLogic &Logic, bool runtimeActivity, unsigned width, Function *todiff, - TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, - DIFFE_TYPE retType, ArrayRef constant_args, bool returnUsed, - bool shadowReturnUsed, std::map &returnMapping, - bool omp) { - Function *oldFunc = todiff; - - // Since this is forward pass this should always return the tape (at index 0) - returnMapping[AugmentedStruct::Tape] = 0; - - int returnCount = 0; - - if (returnUsed) { - assert(!todiff->getReturnType()->isEmptyTy()); - assert(!todiff->getReturnType()->isVoidTy()); - returnMapping[AugmentedStruct::Return] = returnCount + 1; - ++returnCount; - } - - // We don't need to differentially return something that we know is not a - // pointer (or somehow needed for shadow analysis) - if (shadowReturnUsed) { - assert(!todiff->getReturnType()->isEmptyTy()); - assert(!todiff->getReturnType()->isVoidTy()); - returnMapping[AugmentedStruct::DifferentialReturn] = returnCount + 1; - ++returnCount; - } - - ReturnType returnValue; - if (returnCount == 0) - returnValue = ReturnType::Tape; - else if (returnCount == 1) - returnValue = ReturnType::TapeAndReturn; - else if (returnCount == 2) - returnValue = ReturnType::TapeAndTwoReturns; - else - llvm_unreachable("illegal number of elements in augmented return struct"); - - ValueToValueMapTy invertedPointers; - SmallPtrSet constants; - SmallPtrSet nonconstant; - SmallPtrSet returnvals; - llvm::ValueMap originalToNew; - - SmallPtrSet constant_values; - SmallPtrSet nonconstant_values; - - std::string prefix = "fakeaugmented"; - if (width > 1) - prefix += std::to_string(width); - prefix += "_"; - prefix += todiff->getName().str(); - - auto newFunc = Logic.PPC.CloneFunctionWithReturns( - DerivativeMode::ReverseModePrimal, width, oldFunc, invertedPointers, - constant_args, constant_values, nonconstant_values, returnvals, - /*returnValue*/ returnValue, retType, prefix, &originalToNew, - /*diffeReturnArg*/ false, /*additionalArg*/ nullptr); - - // Convert overwritten args from the input function to the preprocessed - // function - - FnTypeInfo typeInfo(oldFunc); - { - auto toarg = todiff->arg_begin(); - auto olarg = oldFunc->arg_begin(); - for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { - - { - auto fd = oldTypeInfo.Arguments.find(toarg); - assert(fd != 
oldTypeInfo.Arguments.end()); - typeInfo.Arguments.insert( - std::pair(olarg, fd->second)); - } - - { - auto cfd = oldTypeInfo.KnownValues.find(toarg); - assert(cfd != oldTypeInfo.KnownValues.end()); - typeInfo.KnownValues.insert( - std::pair>(olarg, cfd->second)); - } - } - typeInfo.Return = oldTypeInfo.Return; - } - - TypeResults TR = TA.analyzeFunction(typeInfo); - if (!oldFunc->empty()) - assert(TR.getFunction() == oldFunc); - - auto res = new GradientUtils( - Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, - nonconstant_values, retType, shadowReturnUsed, constant_args, - originalToNew, DerivativeMode::ReverseModePrimal, runtimeActivity, width, - omp); - return res; -} - -DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::Value *orig, - bool *primalReturnUsedP, - bool *shadowReturnUsedP) const { - return getReturnDiffeType(orig, primalReturnUsedP, shadowReturnUsedP, mode); -} - -DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::Value *orig, - bool *primalReturnUsedP, - bool *shadowReturnUsedP, - DerivativeMode cmode) const { - bool shadowReturnUsed = false; - - DIFFE_TYPE subretType; - if (isConstantValue(orig)) { - subretType = DIFFE_TYPE::CONSTANT; - } else { - if (cmode == DerivativeMode::ForwardMode || - cmode == DerivativeMode::ForwardModeError || - cmode == DerivativeMode::ForwardModeSplit) { - subretType = DIFFE_TYPE::DUP_ARG; - shadowReturnUsed = true; - } else { - if (!orig->getType()->isFPOrFPVectorTy() && TR.anyPointer(orig)) { - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Shadow>(this, orig, cmode, notForAnalysis)) { - subretType = DIFFE_TYPE::DUP_ARG; - shadowReturnUsed = true; - } else - subretType = DIFFE_TYPE::CONSTANT; - } else { - subretType = DIFFE_TYPE::OUT_DIFF; - } - } - } - - if (primalReturnUsedP) { - bool subretused = !unnecessaryValuesP || unnecessaryValuesP->find(orig) == - unnecessaryValuesP->end(); - auto found = knownRecomputeHeuristic.find(orig); - if (found != knownRecomputeHeuristic.end()) { - if (!found->second) { - subretused = true; - } - } - *primalReturnUsedP = subretused; - } - - if (shadowReturnUsedP) - *shadowReturnUsedP = shadowReturnUsed; - return subretType; -} - -DIFFE_TYPE GradientUtils::getDiffeType(Value *v, bool foreignFunction) const { - if (isConstantValue(v) && !foreignFunction) { - return DIFFE_TYPE::CONSTANT; - } - - auto argType = v->getType(); - - if (!argType->isFPOrFPVectorTy() && (TR.anyPointer(v) || foreignFunction)) { - if (argType->isPointerTy()) { - auto at = getBaseObject(v); - if (auto arg = dyn_cast(at)) { - if (ArgDiffeTypes[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) { - return DIFFE_TYPE::DUP_NONEED; - } - } else if (isa(at) || isAllocationCall(at, TLI)) { - assert(unnecessaryValuesP); - if (unnecessaryValuesP->count(at)) - return DIFFE_TYPE::DUP_NONEED; - } - } - return DIFFE_TYPE::DUP_ARG; - } else { - if (foreignFunction) - assert(!argType->isIntOrIntVectorTy()); - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError || - mode == DerivativeMode::ForwardModeSplit) - return DIFFE_TYPE::DUP_ARG; - else - return DIFFE_TYPE::OUT_DIFF; - } -} - -Constant *GradientUtils::GetOrCreateShadowConstant( - RequestContext context, EnzymeLogic &Logic, TargetLibraryInfo &TLI, - TypeAnalysis &TA, Constant *oval, DerivativeMode mode, bool runtimeActivity, - unsigned width, bool AtomicAdd) { - if (isa(oval)) { - return oval; - } else if (isa(oval)) { - return oval; - } else if (isa(oval)) { - return oval; - } else if (auto CD = dyn_cast(oval)) { - 
SmallVector Vals; - for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - Vals.push_back(GetOrCreateShadowConstant( - context, Logic, TLI, TA, CD->getElementAsConstant(i), mode, - runtimeActivity, width, AtomicAdd)); - } - return ConstantArray::get(CD->getType(), Vals); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - Vals.push_back( - GetOrCreateShadowConstant(context, Logic, TLI, TA, CD->getOperand(i), - mode, runtimeActivity, width, AtomicAdd)); - } - return ConstantArray::get(CD->getType(), Vals); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - Vals.push_back( - GetOrCreateShadowConstant(context, Logic, TLI, TA, CD->getOperand(i), - mode, runtimeActivity, width, AtomicAdd)); - } - return ConstantStruct::get(CD->getType(), Vals); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - Vals.push_back( - GetOrCreateShadowConstant(context, Logic, TLI, TA, CD->getOperand(i), - mode, runtimeActivity, width, AtomicAdd)); - } - return ConstantVector::get(Vals); - } else if (auto F = dyn_cast(oval)) { - return GetOrCreateShadowFunction(context, Logic, TLI, TA, F, mode, - runtimeActivity, width, AtomicAdd); - } else if (auto arg = dyn_cast(oval)) { - auto C = - GetOrCreateShadowConstant(context, Logic, TLI, TA, arg->getOperand(0), - mode, runtimeActivity, width, AtomicAdd); - if (arg->isCast() || arg->getOpcode() == Instruction::GetElementPtr || - arg->getOpcode() == Instruction::Add) { - SmallVector NewOps; - for (unsigned i = 0, e = arg->getNumOperands(); i != e; ++i) - NewOps.push_back(i == 0 ? C : arg->getOperand(i)); - return arg->getWithOperands(NewOps); - } - } else if (auto arg = dyn_cast(oval)) { - return GetOrCreateShadowConstant(context, Logic, TLI, TA, arg->getAliasee(), - mode, runtimeActivity, width, AtomicAdd); - } else if (auto arg = dyn_cast(oval)) { - if (arg->getName() == "_ZTVN10__cxxabiv120__si_class_type_infoE" || - arg->getName() == "_ZTVN10__cxxabiv117__class_type_infoE" || - arg->getName() == "_ZTVN10__cxxabiv121__vmi_class_type_infoE" || - startsWith(arg->getName(), "??_R")) // any of the MS RTTI manglings - return arg; - - if (hasMetadata(arg, "enzyme_shadow")) { - auto md = arg->getMetadata("enzyme_shadow"); - if (!isa(md)) { - llvm::errs() << *arg << "\n"; - llvm::errs() << *md << "\n"; - assert(0 && "cannot compute with global variable that doesn't have " - "marked shadow global"); - report_fatal_error( - "cannot compute with global variable that doesn't " - "have marked shadow global (metadata incorrect type)"); - } - auto md2 = cast(md); - assert(md2->getNumOperands() == 1); - auto gvemd = cast(md2->getOperand(0)); - return gvemd->getValue(); - } - - auto Arch = llvm::Triple(arg->getParent()->getTargetTriple()).getArch(); - int SharedAddrSpace = Arch == Triple::amdgcn - ? 
(int)AMDGPU::HSAMD::AddressSpaceQualifier::Local - : 3; - int AddrSpace = cast(arg->getType())->getAddressSpace(); - if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 || - Arch == Triple::amdgcn) && - AddrSpace == SharedAddrSpace) { - assert(0 && "shared memory not handled in meta global"); - } - - // Create global variable locally if not externally visible - if (arg->isConstant() || arg->hasInternalLinkage() || - arg->hasPrivateLinkage() || - (arg->hasExternalLinkage() && arg->hasInitializer())) { - Type *type = arg->getValueType(); - auto shadow = new GlobalVariable( - *arg->getParent(), type, arg->isConstant(), arg->getLinkage(), - Constant::getNullValue(type), arg->getName() + "_shadow", arg, - arg->getThreadLocalMode(), arg->getType()->getAddressSpace(), - arg->isExternallyInitialized()); - arg->setMetadata("enzyme_shadow", - MDTuple::get(shadow->getContext(), - {ConstantAsMetadata::get(shadow)})); - shadow->setAlignment(arg->getAlign()); - shadow->setUnnamedAddr(arg->getUnnamedAddr()); - if (arg->hasInitializer()) - shadow->setInitializer(GetOrCreateShadowConstant( - context, Logic, TLI, TA, cast(arg->getOperand(0)), mode, - runtimeActivity, width, AtomicAdd)); - return shadow; - } - } - llvm::errs() << " unknown constant to create shadow of: " << *oval << "\n"; - llvm_unreachable("unknown constant to create shadow of"); -} - -Constant *GradientUtils::GetOrCreateShadowFunction( - RequestContext context, EnzymeLogic &Logic, TargetLibraryInfo &TLI, - TypeAnalysis &TA, Function *fn, DerivativeMode mode, bool runtimeActivity, - unsigned width, bool AtomicAdd) { - //! Todo allow tape propagation - // Note that specifically this should _not_ be called with topLevel=true - // (since it may not be valid to always assume we can recompute the - // augmented primal) However, in the absence of a way to pass tape data - // from an indirect augmented (and also since we dont presently allow - // indirect augmented calls), topLevel MUST be true otherwise subcalls will - // not be able to lookup the augmenteddata/subdata (triggering an assertion - // failure, among much worse) - bool isRealloc = false; - if (fn->empty()) { - if (hasMetadata(fn, "enzyme_callwrapper")) { - auto md = fn->getMetadata("enzyme_callwrapper"); - if (!isa(md)) { - llvm::errs() << *fn << "\n"; - llvm::errs() << *md << "\n"; - assert(0 && "callwrapper of incorrect type"); - report_fatal_error("callwrapper of incorrect type"); - } - auto md2 = cast(md); - assert(md2->getNumOperands() == 1); - auto gvemd = cast(md2->getOperand(0)); - fn = cast(gvemd->getValue()); - } else { - auto oldfn = fn; - fn = Function::Create(oldfn->getFunctionType(), Function::InternalLinkage, - "callwrap_" + oldfn->getName(), oldfn->getParent()); - BasicBlock *entry = BasicBlock::Create(fn->getContext(), "entry", fn); - IRBuilder<> B(entry); - SmallVector args; - for (auto &a : fn->args()) - args.push_back(&a); - auto res = B.CreateCall(oldfn, args); - if (fn->getReturnType()->isVoidTy()) - B.CreateRetVoid(); - else - B.CreateRet(res); - oldfn->setMetadata( - "enzyme_callwrapper", - MDTuple::get(oldfn->getContext(), {ConstantAsMetadata::get(fn)})); - if (oldfn->getName() == "realloc") - isRealloc = true; - } - } - - bool subsequent_calls_may_write = mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError; - std::vector overwritten_args; - FnTypeInfo type_args(fn); - if (isRealloc) { - llvm::errs() << "warning: assuming realloc only creates pointers\n"; - type_args.Return.insert({-1, -1}, BaseType::Pointer); - } - - // 
conservatively assume that we can only cache existing floating types - // (i.e. that all args are overwritten) - std::vector types; - for (auto &a : fn->args()) { - overwritten_args.push_back(!a.getType()->isFPOrFPVectorTy()); - TypeTree TT; - if (a.getType()->isFPOrFPVectorTy()) - TT.insert({-1}, ConcreteType(a.getType()->getScalarType())); - type_args.Arguments.insert(std::pair(&a, TT)); - type_args.KnownValues.insert( - std::pair>(&a, {})); - DIFFE_TYPE typ; - if (a.getType()->isFPOrFPVectorTy()) { - typ = (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError) - ? DIFFE_TYPE::DUP_ARG - : DIFFE_TYPE::OUT_DIFF; - } else if (a.getType()->isIntegerTy() && - cast(a.getType())->getBitWidth() < 16) { - typ = DIFFE_TYPE::CONSTANT; - } else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) { - typ = DIFFE_TYPE::CONSTANT; - } else { - typ = DIFFE_TYPE::DUP_ARG; - } - types.push_back(typ); - } - - DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() && - mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError - ? DIFFE_TYPE::OUT_DIFF - : DIFFE_TYPE::DUP_ARG; - - if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() || - (fn->getReturnType()->isIntegerTy() && - cast(fn->getReturnType())->getBitWidth() < 16)) - retType = DIFFE_TYPE::CONSTANT; - - if (mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError && - retType == DIFFE_TYPE::DUP_ARG) { - if (auto ST = dyn_cast(fn->getReturnType())) { - size_t numflt = 0; - - for (unsigned i = 0; i < ST->getNumElements(); ++i) { - auto midTy = ST->getElementType(i); - if (midTy->isFPOrFPVectorTy()) - numflt++; - } - if (numflt == ST->getNumElements()) - retType = DIFFE_TYPE::OUT_DIFF; - } - } - - switch (mode) { - case DerivativeMode::ForwardModeError: - case DerivativeMode::ForwardMode: { - Constant *newf = Logic.CreateForwardDiff( - context, fn, retType, types, TA, false, mode, /*freeMemory*/ true, - runtimeActivity, width, nullptr, type_args, subsequent_calls_may_write, - overwritten_args, - /*augmented*/ nullptr); - - assert(newf); - - std::string prefix = (mode == DerivativeMode::ForwardMode) - ? 
"_enzyme_forward" - : "_enzyme_forwarderror"; - - if (width > 1) { - prefix += std::to_string(width); - } - - std::string globalname = (prefix + "_" + fn->getName() + "'").str(); - auto GV = fn->getParent()->getNamedValue(globalname); - - if (GV == nullptr) { - GV = new GlobalVariable(*fn->getParent(), newf->getType(), true, - GlobalValue::LinkageTypes::InternalLinkage, newf, - globalname); - } - - return ConstantExpr::getPointerCast(GV, fn->getType()); - } - case DerivativeMode::ForwardModeSplit: { - auto &augdata = Logic.CreateAugmentedPrimal( - context, fn, retType, /*constant_args*/ types, TA, - /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && - !fn->getReturnType()->isVoidTy(), - /*shadowReturnUsed*/ false, type_args, subsequent_calls_may_write, - overwritten_args, - /*forceAnonymousTape*/ true, runtimeActivity, width, AtomicAdd); - Constant *newf = Logic.CreateForwardDiff( - context, fn, retType, types, TA, false, mode, /*freeMemory*/ true, - runtimeActivity, width, nullptr, type_args, subsequent_calls_may_write, - overwritten_args, - /*augmented*/ &augdata); - - assert(newf); - - std::string prefix = "_enzyme_forwardsplit"; - - if (width > 1) { - prefix += std::to_string(width); - } - - auto cdata = ConstantStruct::get( - StructType::get(newf->getContext(), - {augdata.fn->getType(), newf->getType()}), - {augdata.fn, newf}); - - std::string globalname = (prefix + "_" + fn->getName() + "'").str(); - auto GV = fn->getParent()->getNamedValue(globalname); - - if (GV == nullptr) { - GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, - GlobalValue::LinkageTypes::InternalLinkage, cdata, - globalname); - } - - return ConstantExpr::getPointerCast(GV, fn->getType()); - } - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: - case DerivativeMode::ReverseModePrimal: { - // TODO re atomic add consider forcing it to be atomic always as fallback if - // used in a parallel context - bool returnUsed = - !fn->getReturnType()->isEmptyTy() && !fn->getReturnType()->isVoidTy(); - bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || - retType == DIFFE_TYPE::DUP_NONEED); - auto &augdata = Logic.CreateAugmentedPrimal( - context, fn, retType, /*constant_args*/ types, TA, returnUsed, - shadowReturnUsed, type_args, subsequent_calls_may_write, - overwritten_args, - /*forceAnonymousTape*/ true, runtimeActivity, width, AtomicAdd); - Constant *newf = Logic.CreatePrimalAndGradient( - context, - (ReverseCacheKey){.todiff = fn, - .retType = retType, - .constant_args = types, - .subsequent_calls_may_write = - subsequent_calls_may_write, - .overwritten_args = overwritten_args, - .returnUsed = false, - .shadowReturnUsed = false, - .mode = DerivativeMode::ReverseModeGradient, - .width = width, - .freeMemory = true, - .AtomicAdd = AtomicAdd, - .additionalType = getInt8PtrTy(fn->getContext()), - .forceAnonymousTape = true, - .typeInfo = type_args, - .runtimeActivity = runtimeActivity}, - TA, - /*map*/ &augdata); - assert(newf); - auto cdata = ConstantStruct::get( - StructType::get(newf->getContext(), - {augdata.fn->getType(), newf->getType()}), - {augdata.fn, newf}); - std::string globalname = ("_enzyme_reverse_" + fn->getName() + "'").str(); - auto GV = fn->getParent()->getNamedValue(globalname); - - if (GV == nullptr) { - GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, - GlobalValue::LinkageTypes::InternalLinkage, cdata, - globalname); - } - return ConstantExpr::getPointerCast(GV, fn->getType()); - } - } - llvm_unreachable("Illegal 
state: unknown mode for GetOrCreateShadowFunction"); -} - -void GradientUtils::getReverseBuilder(IRBuilder<> &Builder2, bool original) { - assert(reverseBlocks.size()); - BasicBlock *BB = Builder2.GetInsertBlock(); - if (original) - BB = getNewFromOriginal(BB); - assert(reverseBlocks.find(BB) != reverseBlocks.end()); - BasicBlock *BB2 = reverseBlocks[BB].back(); - if (!BB2) { - llvm::errs() << "oldFunc: " << oldFunc << "\n"; - llvm::errs() << "newFunc: " << newFunc << "\n"; - llvm::errs() << "could not invert " << *BB; - } - assert(BB2); - - if (BB2->getTerminator()) - Builder2.SetInsertPoint(BB2->getTerminator()); - else - Builder2.SetInsertPoint(BB2); - Builder2.SetCurrentDebugLocation( - getNewFromOriginal(Builder2.getCurrentDebugLocation())); - Builder2.setFastMathFlags(getFast()); -} - -void GradientUtils::getForwardBuilder(IRBuilder<> &Builder2) { - Instruction *insert = &*Builder2.GetInsertPoint(); - Instruction *nInsert = getNewFromOriginal(insert); - - assert(nInsert); - - Builder2.SetInsertPoint(getNextNonDebugInstruction(nInsert)); - Builder2.SetCurrentDebugLocation( - getNewFromOriginal(Builder2.getCurrentDebugLocation())); - Builder2.setFastMathFlags(getFast()); -} - -void GradientUtils::setPtrDiffe(Instruction *orig, Value *ptr, Value *newval, - IRBuilder<> &BuilderM, MaybeAlign align, - unsigned start, unsigned size, bool isVolatile, - AtomicOrdering ordering, - SyncScope::ID syncScope, Value *mask, - ArrayRef noAlias, - ArrayRef scopes) { -#ifndef NDEBUG - if (auto inst = dyn_cast(ptr)) { - assert(inst->getParent()->getParent() == oldFunc); - } - if (auto arg = dyn_cast(ptr)) { - assert(arg->getParent() == oldFunc); - } -#endif - - Value *origptr = ptr; - - ptr = invertPointerM(ptr, BuilderM); - if (!isOriginalBlock(*BuilderM.GetInsertBlock()) && - mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError) - ptr = lookupM(ptr, BuilderM); - - if (mask && !isOriginalBlock(*BuilderM.GetInsertBlock()) && - mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError) - mask = lookupM(mask, BuilderM); - - size_t idx = 0; - - auto &DL = oldFunc->getParent()->getDataLayout(); - - auto rule = [&](Value *ptr, Value *newval) { - auto storeSize = (DL.getTypeSizeInBits(newval->getType()) + 7) / 8; - if (!mask) { - - if (size != storeSize) { - IRBuilder<> A(inversionAllocs); - Value *valptr = A.CreateAlloca(newval->getType()); - BuilderM.CreateStore(newval, valptr); - - auto i8 = Type::getInt8Ty(ptr->getContext()); - - if (start != 0) { - ptr = BuilderM.CreatePointerCast( - ptr, - PointerType::get( - i8, cast(ptr->getType())->getAddressSpace())); - auto off = - ConstantInt::get(Type::getInt64Ty(ptr->getContext()), start); - ptr = BuilderM.CreateInBoundsGEP(i8, ptr, off); - - valptr = BuilderM.CreatePointerCast( - valptr, - PointerType::get( - i8, cast(valptr->getType())->getAddressSpace())); - valptr = BuilderM.CreateInBoundsGEP(i8, valptr, off); - } - - Type *ty = nullptr; - - if (size == 8) - ty = BuilderM.getInt64Ty(); - else if (size % 8 == 0) - ty = ArrayType::get(BuilderM.getInt64Ty(), size); - else if (size == 4) - ty = BuilderM.getInt32Ty(); - else if (size % 4 == 0) - ty = ArrayType::get(BuilderM.getInt32Ty(), size); - else - ty = ArrayType::get(i8, size); - - ptr = BuilderM.CreatePointerCast( - ptr, PointerType::get( - ty, cast(ptr->getType())->getAddressSpace())); - valptr = BuilderM.CreatePointerCast( - valptr, - PointerType::get( - ty, cast(valptr->getType())->getAddressSpace())); - newval = BuilderM.CreateLoad(ty, valptr); - } - 
- auto ts = BuilderM.CreateStore(newval, ptr); - if (align) - ts->setAlignment(*align); - - ts->setVolatile(isVolatile); - ts->setOrdering(ordering); - ts->setSyncScopeID(syncScope); - SmallVector scopeMD = { - getDerivativeAliasScope(origptr, idx)}; - for (auto M : scopes) - scopeMD.push_back(M); - auto scope = MDNode::get(ts->getContext(), scopeMD); - ts->setMetadata(LLVMContext::MD_alias_scope, scope); - - if (start == 0 && size == storeSize) { - ts->setMetadata(LLVMContext::MD_tbaa, - orig->getMetadata(LLVMContext::MD_tbaa)); - ts->setMetadata(LLVMContext::MD_tbaa_struct, - orig->getMetadata(LLVMContext::MD_tbaa_struct)); - } - ts->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); - - SmallVector MDs; - for (ssize_t j = -1; j < getWidth(); j++) { - if (j != (ssize_t)idx) - MDs.push_back(getDerivativeAliasScope(origptr, j)); - } - for (auto M : noAlias) - MDs.push_back(M); - if (MDs.size()) { - auto noscope = MDNode::get(ptr->getContext(), MDs); - ts->setMetadata(LLVMContext::MD_noalias, noscope); - } - } else { - assert(start == 0 && size == storeSize); - Type *tys[] = {newval->getType(), ptr->getType()}; - auto F = getIntrinsicDeclaration(oldFunc->getParent(), - Intrinsic::masked_store, tys); - assert(align); - Value *alignv = - ConstantInt::get(Type::getInt32Ty(ptr->getContext()), align->value()); - Value *args[] = {newval, ptr, alignv, mask}; - auto ts = BuilderM.CreateCall(F, args); - ts->setCallingConv(F->getCallingConv()); - ts->setMetadata(LLVMContext::MD_tbaa, - orig->getMetadata(LLVMContext::MD_tbaa)); - ts->setMetadata(LLVMContext::MD_tbaa_struct, - orig->getMetadata(LLVMContext::MD_tbaa_struct)); - ts->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); - } - idx++; - }; - - applyChainRule(BuilderM, rule, ptr, newval); -} - -Type *GradientUtils::getShadowType(Type *ty, unsigned width) { - if (width > 1) { - if (ty->isVoidTy()) - return ty; - return ArrayType::get(ty, width); - } else { - return ty; - } -} - -Type *GradientUtils::getShadowType(Type *ty) { - return getShadowType(ty, width); -} - -Type *GradientUtils::extractMeta(Type *T, ArrayRef off) { - for (auto idx : off) { - if (auto AT = dyn_cast(T)) { - T = AT->getElementType(); - continue; - } - if (auto ST = dyn_cast(T)) { - T = ST->getElementType(idx); - continue; - } - assert(false && "could not sub index into type"); - } - return T; -} - -Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg, - unsigned off, const Twine &name) { - return extractMeta(Builder, Agg, ArrayRef({off}), name); -} - -Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg, - ArrayRef off_init, - const Twine &name, bool fallback) { - std::vector off(off_init.begin(), off_init.end()); - while (off.size() != 0) { - if (auto Ins = dyn_cast(Agg)) { - size_t until = Ins->getNumIndices(); - if (off.size() < until) - until = off.size(); - bool subset = true; - for (size_t i = 0; i < until; i++) { - if (Ins->getIndices()[i] != off[i]) { - subset = false; - break; - } - } - if (!subset) { - Agg = Ins->getAggregateOperand(); - continue; - } else if (until < Ins->getNumIndices()) { - break; - } else { - off.erase(off.begin(), off.begin() + until); - Agg = Ins->getInsertedValueOperand(); - continue; - } - } - if (auto ext = dyn_cast(Agg)) { - off.insert(off.begin(), ext->getIndices().begin(), - ext->getIndices().end()); - Agg = ext->getAggregateOperand(); - continue; - } - if (auto CA = dyn_cast(Agg)) { - Agg = CA->getElementValue(off[0]); - off.erase(off.begin(), off.begin() + 1); - } - break; - } - if (off.size() == 0) - 
return Agg; - - if (!fallback) - return nullptr; - - if (Agg->getType()->isVectorTy() && off.size() == 1) - return Builder.CreateExtractElement(Agg, off[0], name); - - return Builder.CreateExtractValue(Agg, off, name); -} - -llvm::Value *GradientUtils::recursiveFAdd(llvm::IRBuilder<> &B, - llvm::Value *lhs, llvm::Value *rhs, - llvm::ArrayRef lhs_off, - llvm::ArrayRef rhs_off, - llvm::Value *prev, bool vectorLayer) { - llvm::Type *lhs_ty = lhs->getType(); - if (!vectorLayer) { - for (auto idx : lhs_off) - lhs_ty = getSubType(lhs_ty, idx); - llvm::Type *rhs_ty = rhs->getType(); - for (auto idx : rhs_off) - rhs_ty = getSubType(rhs_ty, idx); - assert(lhs_ty == rhs_ty); - } - if (lhs_ty->isFPOrFPVectorTy()) { - if (lhs_off.size()) - lhs = extractMeta(B, lhs, lhs_off); - if (rhs_off.size()) - rhs = extractMeta(B, rhs, rhs_off); - llvm::Value *res = nullptr; - if (auto fp = llvm::dyn_cast(lhs)) { - if (fp->isZero()) - res = rhs; - } - if (auto fp = llvm::dyn_cast(rhs)) { - if (fp->isZero()) - res = lhs; - } - if (!res) { - if (auto *FPMO = dyn_cast(rhs)) - if (FPMO->getOpcode() == Instruction::FNeg) { - res = B.CreateFSub(lhs, FPMO->getOperand(0)); - } - } - if (!res) { - if (auto *S = dyn_cast(rhs)) { - if (S->getOpcode() == Instruction::FSub) { - if (auto C = dyn_cast(S->getOperand(0))) - if (C->isZero()) - res = B.CreateFSub(lhs, S->getOperand(1)); - } - } - } - if (!res) { - res = B.CreateFAdd(lhs, rhs); - } - if (lhs_off.size()) { - assert(prev); - res = B.CreateInsertValue(prev, res, lhs_off); - } - return res; - } else if (isa(lhs_ty) || isa(lhs_ty)) { - if (prev == nullptr) - prev = llvm::UndefValue::get(lhs_ty); - - size_t size; - if (auto AT = dyn_cast(lhs_ty)) - size = AT->getNumElements(); - else - size = cast(lhs_ty)->getNumElements(); - - for (size_t i = 0; i < size; ++i) { - llvm::SmallVector nlhs_off(lhs_off.begin(), lhs_off.end()); - if (vectorLayer) - nlhs_off.insert(nlhs_off.begin(), i); - else - nlhs_off.push_back(i); - llvm::SmallVector nrhs_off(rhs_off.begin(), rhs_off.end()); - if (vectorLayer) - nrhs_off.insert(nrhs_off.begin(), i); - else - nrhs_off.push_back(i); - prev = recursiveFAdd(B, lhs, rhs, nlhs_off, nrhs_off, prev); - } - return prev; - } - llvm_unreachable("Unknown type to recursively accumulate"); -} - -Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, - bool nullShadow) { - assert(oval); -#ifndef NDEBUG - if (auto inst = dyn_cast(oval)) { - assert(inst->getParent()->getParent() == oldFunc); - } - if (auto arg = dyn_cast(oval)) { - assert(arg->getParent() == oldFunc); - } -#endif - - if (isa(oval)) { - return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); - } else if (isa(oval)) { - if (nullShadow) - return Constant::getNullValue(getShadowType(oval->getType())); - return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); - } else if (isa(oval)) { - if (nullShadow) - return Constant::getNullValue(getShadowType(oval->getType())); - return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - Value *val = - invertPointerM(CD->getElementAsConstant(i), BuilderM, nullShadow); - Vals.push_back(cast(val)); - } - auto rule = [&CD](ArrayRef Vals) { - return ConstantArray::get(CD->getType(), Vals); - }; - return applyChainRule(CD->getType(), Vals, BuilderM, rule); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = 
CD->getNumOperands(); i < len; i++) { - Value *val = invertPointerM(CD->getOperand(i), BuilderM, nullShadow); - Vals.push_back(cast(val)); - } - - auto rule = [&CD](ArrayRef Vals) { - return ConstantArray::get(CD->getType(), Vals); - }; - - return applyChainRule(CD->getType(), Vals, BuilderM, rule); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - Vals.push_back(cast( - invertPointerM(CD->getOperand(i), BuilderM, nullShadow))); - } - - auto rule = [&CD](ArrayRef Vals) { - return ConstantStruct::get(CD->getType(), Vals); - }; - return applyChainRule(CD->getType(), Vals, BuilderM, rule); - } else if (auto CD = dyn_cast(oval)) { - SmallVector Vals; - for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - Vals.push_back(cast( - invertPointerM(CD->getOperand(i), BuilderM, nullShadow))); - } - - auto rule = [](ArrayRef Vals) { - return ConstantVector::get(Vals); - }; - - return applyChainRule(CD->getType(), Vals, BuilderM, rule); - } else if (isa(oval) && nullShadow) { - auto rule = [&oval]() { return Constant::getNullValue(oval->getType()); }; - - return applyChainRule(oval->getType(), BuilderM, rule); - } - - bool shouldNullShadow = isConstantValue(oval); - if (shouldNullShadow) { - if (isa(oval) || isa(oval) || - isa(oval) || isa(oval)) { - shouldNullShadow = false; - auto orig = cast(oval); - if (knownRecomputeHeuristic.count(orig)) { - if (!knownRecomputeHeuristic[orig]) { - shouldNullShadow = true; - } - } - } - } - - if (shouldNullShadow) { - // NOTE, this is legal and the correct resolution, however, our activity - // analysis honeypot no longer exists - - // Nulling the shadow for a constant is only necessary if any of the data - // could contain a float (e.g. should not be applied to pointers). - if (nullShadow) { - auto ty = TR.query(oval); - auto &dl = newFunc->getParent()->getDataLayout(); - size_t size = (dl.getTypeSizeInBits(oval->getType()) + 7) / 8; - auto CT = ty[{-1}]; - bool couldContainFloat = CT.isFloat(); - bool allFloat = CT.isFloat(); - if (!CT.isKnown()) { - size_t i = 0; - for (; i < size;) { - auto CT2 = ty[{(int)i}]; - if (CT2.isFloat() || !CT2.isKnown()) { - couldContainFloat = true; - break; - } - if (CT2 == BaseType::Pointer) { - i += dl.getPointerSizeInBits() / 8; - continue; - } - i++; - } - } - if (couldContainFloat) { - if (allFloat) - return Constant::getNullValue(getShadowType(oval->getType())); - else { - IRBuilder<> bb(inversionAllocs); - if (auto arg = dyn_cast(oval)) { - arg = getNewFromOriginal(arg); - // Go one after since otherwise we won't be able - // to use in the store. 
- arg = arg->getNextNode(); - while (auto PN = dyn_cast(arg)) { - if (PN->getNumIncomingValues() == 0) - break; - arg = PN->getNextNode(); - } - bb.SetInsertPoint(arg); - } - auto alloc = bb.CreateAlloca(oval->getType()); - auto AT = ArrayType::get(bb.getInt8Ty(), size); - bb.CreateStore(getNewFromOriginal(oval), alloc); - Value *cur = bb.CreatePointerCast(alloc, PointerType::getUnqual(AT)); - size_t i = 0; - assert(size > 0); - for (; i < size;) { - auto CT2 = ty[{(int)i}]; - if (CT2 == BaseType::Pointer) { - i += dl.getPointerSizeInBits() / 8; - continue; - } else if (auto flt = CT2.isFloat()) { - auto ptr = bb.CreateConstInBoundsGEP2_32(AT, cur, 0, i); - ptr = bb.CreatePointerCast(ptr, PointerType::getUnqual(flt)); - bb.CreateStore(Constant::getNullValue(flt), ptr); - size_t chunk = dl.getTypeSizeInBits(flt) / 8; - i += chunk; - } else if (CT2 != BaseType::Integer) { - auto ptr = bb.CreateConstInBoundsGEP2_32(AT, cur, 0, i); - bb.CreateStore(Constant::getNullValue(bb.getInt8Ty()), ptr); - i++; - } else { - i++; - } - } - auto res = bb.CreateLoad(oval->getType(), alloc); - auto rule = [&res]() { return res; }; - auto res2 = applyChainRule(oval->getType(), BuilderM, rule); - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, res2))); - return res2; - } - } - } - - if (isa(oval) || isa(oval)) { - auto rule = [&oval]() { return oval; }; - return applyChainRule(oval->getType(), BuilderM, rule); - } - - Value *newval = getNewFromOriginal(oval); - - auto rule = [&]() { return newval; }; - - return applyChainRule(oval->getType(), BuilderM, rule); - } - - auto M = oldFunc->getParent(); - assert(oval); - - { - auto ifound = invertedPointers.find(oval); - if (ifound != invertedPointers.end()) { - return &*ifound->second; - } - } - - if (mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ForwardModeError && - mode != DerivativeMode::ForwardModeSplit && nullShadow) { - auto CT = TR.query(oval)[{-1}]; - if (CT.isFloat()) { - return Constant::getNullValue(getShadowType(oval->getType())); - } - } - - if (isa(oval) && !TR.anyPointer(oval)) { - return Constant::getNullValue(getShadowType(oval->getType())); - } else if (isa(oval) && cast(oval)->hasByValAttr()) { - IRBuilder<> bb(inversionAllocs); - - Type *subType = nullptr; - auto attr = cast(oval)->getAttribute(Attribute::ByVal); - subType = attr.getValueAsType(); - - auto rule1 = [&]() { - AllocaInst *antialloca = bb.CreateAlloca( - subType, cast(oval->getType())->getPointerAddressSpace(), - nullptr, oval->getName() + "'ipa"); - - auto dst_arg = - bb.CreateBitCast(antialloca, getInt8PtrTy(oval->getContext())); - auto val_arg = ConstantInt::get(Type::getInt8Ty(oval->getContext()), 0); - auto len_arg = ConstantInt::get( - Type::getInt64Ty(oval->getContext()), - M->getDataLayout().getTypeAllocSizeInBits(subType) / 8); - auto volatile_arg = ConstantInt::getFalse(oval->getContext()); - - Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg}; - Type *tys[] = {dst_arg->getType(), len_arg->getType()}; - bb.CreateCall(getIntrinsicDeclaration(M, Intrinsic::memset, tys), args); - - return antialloca; - }; - - Value *antialloca = applyChainRule(oval->getType(), bb, rule1); - - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, antialloca))); - - return antialloca; - } else if (auto arg = dyn_cast(oval)) { - Value *aliasTarget = arg->getAliasee(); - return invertPointerM(aliasTarget, BuilderM, nullShadow); - } else if (auto arg = dyn_cast(oval)) { - if (!hasMetadata(arg, 
"enzyme_shadow")) { - - if ((mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeError) && - arg->getType()->getPointerAddressSpace() == 0) { - auto CT = TR.query(arg)[{-1, -1}]; - // Can only localy replace a global variable if it is - // known not to contain a pointer, which may be initialized - // outside of this function to contain other memory which - // will not have a shadow within the current function. - if (CT.isKnown() && CT != BaseType::Pointer) { - bool seen = false; - MemoryLocation -#if LLVM_VERSION_MAJOR >= 12 - Loc = MemoryLocation(oval, LocationSize::beforeOrAfterPointer()); -#else - Loc = MemoryLocation(oval, LocationSize::unknown()); -#endif - for (CallInst *CI : originalCalls) { - if (isa(CI)) - continue; - if (!isConstantInstruction(CI)) { - auto F = getFunctionFromCall(CI); - if (F && isMemFreeLibMFunction(F->getName())) { - continue; - } - if (llvm::isModOrRefSet(OrigAA->getModRefInfo(CI, Loc))) { - seen = true; - llvm::errs() << " cannot shadow-inline global " << *oval - << " due to " << *CI << "\n"; - goto endCheck; - } - } - } - endCheck:; - if (!seen) { - IRBuilder<> bb(inversionAllocs); - Type *allocaTy = arg->getValueType(); - - auto rule1 = [&]() { - AllocaInst *antialloca = bb.CreateAlloca( - allocaTy, arg->getType()->getPointerAddressSpace(), nullptr, - arg->getName() + "'ipa"); - if (arg->getAlignment()) { - antialloca->setAlignment(Align(arg->getAlignment())); - } - return antialloca; - }; - - Value *antialloca = applyChainRule(arg->getType(), bb, rule1); - - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, antialloca))); - - auto rule2 = [&](Value *antialloca) { - auto dst_arg = - bb.CreateBitCast(antialloca, getInt8PtrTy(arg->getContext())); - auto val_arg = - ConstantInt::get(Type::getInt8Ty(arg->getContext()), 0); - auto len_arg = - ConstantInt::get(Type::getInt64Ty(arg->getContext()), - M->getDataLayout().getTypeAllocSizeInBits( - arg->getValueType()) / - 8); - auto volatile_arg = ConstantInt::getFalse(oval->getContext()); - - Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg}; - Type *tys[] = {dst_arg->getType(), len_arg->getType()}; - auto memset = cast(bb.CreateCall( - getIntrinsicDeclaration(M, Intrinsic::memset, tys), args)); - if (arg->getAlignment()) { - memset->addParamAttr( - 0, Attribute::getWithAlignment(arg->getContext(), - Align(arg->getAlignment()))); - } - memset->addParamAttr(0, Attribute::NonNull); - assert((width > 1 && antialloca->getType() == - ArrayType::get(arg->getType(), width)) || - antialloca->getType() == arg->getType()); - return antialloca; - }; - - return applyChainRule(arg->getType(), bb, rule2, antialloca); - } - } - } - - auto Arch = - llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); - int SharedAddrSpace = - Arch == Triple::amdgcn - ? 
(int)AMDGPU::HSAMD::AddressSpaceQualifier::Local - : 3; - int AddrSpace = cast(arg->getType())->getAddressSpace(); - if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 || - Arch == Triple::amdgcn) && - AddrSpace == SharedAddrSpace) { - llvm::errs() << "warning found shared memory\n"; - Type *type = arg->getValueType(); - // TODO this needs initialization by entry - auto shadow = new GlobalVariable( - *arg->getParent(), type, arg->isConstant(), arg->getLinkage(), - UndefValue::get(type), arg->getName() + "_shadow", arg, - arg->getThreadLocalMode(), arg->getType()->getAddressSpace(), - arg->isExternallyInitialized()); - arg->setMetadata("enzyme_shadow", - MDTuple::get(shadow->getContext(), - {ConstantAsMetadata::get(shadow)})); - shadow->setMetadata("enzyme_internalshadowglobal", - MDTuple::get(shadow->getContext(), {})); - shadow->setAlignment(arg->getAlign()); - shadow->setUnnamedAddr(arg->getUnnamedAddr()); - return shadow; - } - - // Create global variable locally if not externally visible - // If a variable is constant, for forward mode it will also - // only be read, so invert initializing is fine. - // For reverse mode, any floats will be +='d into, but never - // read, and any pointers will be used as expected. The never - // read means even if two globals for floats, that's fine. - // As long as the pointers point to equivalent places (which - // they should from the same initialization), it is also ok. - if (arg->hasInternalLinkage() || arg->hasPrivateLinkage() || - (arg->hasExternalLinkage() && arg->hasInitializer()) || - arg->isConstant()) { - Type *elemTy = arg->getValueType(); - IRBuilder<> B(inversionAllocs); - - auto rule = [&]() { - auto shadow = new GlobalVariable( - *arg->getParent(), elemTy, arg->isConstant(), arg->getLinkage(), - Constant::getNullValue(elemTy), arg->getName() + "_shadow", arg, - arg->getThreadLocalMode(), arg->getType()->getAddressSpace(), - arg->isExternallyInitialized()); - shadow->setAlignment(arg->getAlign()); - shadow->setUnnamedAddr(arg->getUnnamedAddr()); - - return shadow; - }; - - Value *shadow = applyChainRule(oval->getType(), BuilderM, rule); - arg->setMetadata( - "enzyme_shadow", - MDTuple::get(shadow->getContext(), - {ConstantAsMetadata::get(cast(shadow))})); - if (getWidth() != 1) { - BuilderM.Insert(InsertValueInst::Create(shadow, arg, {0}), "tmp"); - } - - if (arg->hasInitializer()) { - applyChainRule( - BuilderM, - [&](Value *shadow, Value *ip) { - cast(shadow)->setInitializer( - cast(ip)); - }, - shadow, - invertPointerM(arg->getInitializer(), B, /*nullShadow*/ true)); - } - - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } - - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot compute with global variable that doesn't have marked " - "shadow global\n"; - ss << *arg << "\n"; - if (CustomErrorHandler) { - return unwrap(CustomErrorHandler(ss.str().c_str(), wrap(arg), - ErrorType::NoShadow, this, nullptr, - wrap(&BuilderM))); - } else { - EmitFailure("InvertGlobal", BuilderM.getCurrentDebugLocation(), oldFunc, - ss.str()); - } - return UndefValue::get(getShadowType(arg->getType())); - } - auto md = arg->getMetadata("enzyme_shadow"); - if (!isa(md)) { - llvm::errs() << *arg << "\n"; - llvm::errs() << *md << "\n"; - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot compute with global variable that doesn't have marked " - "shadow global as mdtuple\n"; - ss << *arg << "\n"; - ss << " md: " << *md << "\n"; - if (CustomErrorHandler) { - return 
unwrap(CustomErrorHandler(ss.str().c_str(), wrap(arg), - ErrorType::NoShadow, this, nullptr, - wrap(&BuilderM))); - } else { - EmitFailure("InvertGlobal", BuilderM.getCurrentDebugLocation(), oldFunc, - ss.str()); - } - return UndefValue::get(getShadowType(arg->getType())); - } - auto md2 = cast(md); - assert(md2->getNumOperands() == 1); - auto gvemd = cast(md2->getOperand(0)); - auto cs = cast(gvemd->getValue()); - - if (width > 1) { - SmallVector Vals; - for (unsigned i = 0; i < width; ++i) { - - Constant *idxs[] = { - ConstantInt::get(Type::getInt32Ty(cs->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(cs->getContext()), i)}; - Constant *elem = ConstantExpr::getInBoundsGetElementPtr( - getShadowType(arg->getValueType()), cs, idxs); - Vals.push_back(elem); - } - - auto agg = ConstantArray::get( - cast(getShadowType(arg->getType())), Vals); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, agg))); - return agg; - } else { - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, cs))); - return cs; - } - } else if (auto fn = dyn_cast(oval)) { - Constant *shadow = GetOrCreateShadowFunction( - RequestContext(nullptr, &BuilderM), Logic, TLI, TA, fn, mode, - runtimeActivity, width, AtomicAdd); - if (width > 1) { - SmallVector arr; - for (unsigned i = 0; i < width; ++i) { - arr.push_back(shadow); - } - ArrayType *arrTy = ArrayType::get(shadow->getType(), width); - shadow = ConstantArray::get(arrTy, arr); - } - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - Value *invertOp = invertPointerM(arg->getOperand(0), bb, nullShadow); - Type *shadowTy = arg->getDestTy(); - - auto rule = [&](Value *invertOp) { - return bb.CreateCast(arg->getOpcode(), invertOp, shadowTy, - arg->getName() + "'ipc"); - }; - - Value *shadow = applyChainRule(shadowTy, bb, rule, invertOp); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - Value *invertOp = invertPointerM(arg->getOperand(0), bb, nullShadow); - Type *shadowTy = arg->getType(); - - if (mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ReverseModePrimal || - mode == DerivativeMode::ReverseModeGradient) { - if (TR.query(arg)[{-1}].isFloat()) { - return Constant::getNullValue(getShadowType(oval->getType())); - } - } - assert(!arg->getType()->isDoubleTy()); - - auto rule = [&](Value *invertOp) { - return bb.CreateFreeze(invertOp, arg->getName() + "'ipf"); - }; - - Value *shadow = applyChainRule(shadowTy, bb, rule, invertOp); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(inversionAllocs); - if (arg->getOpcode() == Instruction::Add) { - if (isa(arg->getOperand(0))) { - auto rule = [&arg](Value *ip) { - Constant *invops[2] = {arg->getOperand(0), cast(ip)}; - return arg->getWithOperands(invops); - }; - auto ip = invertPointerM(arg->getOperand(1), bb, nullShadow); - return applyChainRule(arg->getType(), bb, rule, ip); - } - if (isa(arg->getOperand(1))) { - auto rule = [&arg](Value *ip) { - Constant *invops[2] = {cast(ip), arg->getOperand(1)}; - return arg->getWithOperands(invops); - }; - auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow); - return applyChainRule(arg->getType(), bb, rule, ip); - } - } - auto ip = 
invertPointerM(arg->getOperand(0), bb, nullShadow); - - if (arg->isCast()) { -#if LLVM_VERSION_MAJOR < 17 - if (auto PT = dyn_cast(arg->getType())) { - if (isConstantValue(arg->getOperand(0)) && - PT->getPointerElementType()->isFunctionTy()) { - goto end; - } - } -#endif - if (isa(ip)) { - auto rule = [&arg](Value *ip) { - return ConstantExpr::getCast(arg->getOpcode(), cast(ip), - arg->getType()); - }; - - return applyChainRule(arg->getType(), bb, rule, ip); - - } else { - auto rule = [&](Value *ip) { - return bb.CreateCast((Instruction::CastOps)arg->getOpcode(), ip, - arg->getType(), arg->getName() + "'ipc"); - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); - - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, shadow))); - - return shadow; - } - } else if (arg->getOpcode() == Instruction::GetElementPtr) { - if (auto C = dyn_cast(ip)) { - auto rule = [&arg, &C]() { - SmallVector NewOps; - for (unsigned i = 0, e = arg->getNumOperands(); i != e; ++i) - NewOps.push_back(i == 0 ? C : arg->getOperand(i)); - return cast(arg->getWithOperands(NewOps)); - }; - - return applyChainRule(arg->getType(), bb, rule); - } else { - SmallVector invertargs; - for (unsigned i = 0; i < arg->getNumOperands() - 1; ++i) { - Value *b = getNewFromOriginal(arg->getOperand(1 + i)); - invertargs.push_back(b); - } - - auto rule = [&bb, &arg, &invertargs](Value *ip) { - // TODO mark this the same inbounds as the original - return bb.CreateGEP(cast(ip)->getSourceElementType(), ip, - invertargs, arg->getName() + "'ipg"); - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); - - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } - } else { - llvm::errs() << *arg << "\n"; - assert(0 && "unhandled"); - } - goto end; - } else if (auto arg = dyn_cast(oval)) { - auto newi = getNewFromOriginal(arg); - IRBuilder<> bb(newi->getNextNode()); - auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow); - - auto rule = [&bb, &arg, &newi, this](Value *ip) -> llvm::Value * { - if (ip == getNewFromOriginal(arg->getOperand(0))) - return newi; - return bb.CreateExtractValue(ip, arg->getIndices(), - arg->getName() + "'ipev"); - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - Value *ivops[2] = {nullptr, nullptr}; - for (int i = 0; i < 2; i++) { - auto op = arg->getOperand(i); - bool subnull = nullShadow; - auto vd = TR.query(op); - if (!TR.anyFloat(op)) - subnull = false; - if (!runtimeActivity && !isa(op)) { - if (isConstantValue(op)) { - if (TR.anyPointer(op) && vd[{-1, -1}] != BaseType::Integer) { - if (!isa(op) && !isa(op)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *arg - << " const val: " << *op; - if (CustomErrorHandler) - ivops[i] = unwrap(CustomErrorHandler( - str.c_str(), wrap(arg), ErrorType::MixedActivityError, this, - wrap(op), wrap(&bb))); - else - EmitWarning("MixedActivityError", *arg, ss.str()); - } - } - } - } - if (!ivops[i]) { - ivops[i] = invertPointerM(op, bb, subnull); - } - } - - auto rule = [&bb, &arg](Value *ip0, Value *ip1) { - return bb.CreateInsertValue(ip0, ip1, arg->getIndices(), - arg->getName() + "'ipiv"); - }; - - Value *shadow = - applyChainRule(arg->getType(), bb, rule, ivops[0], ivops[1]); - - 
invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - auto ip = invertPointerM(arg->getVectorOperand(), bb, nullShadow); - - auto rule = [&](Value *ip) { - return bb.CreateExtractElement(ip, - getNewFromOriginal(arg->getIndexOperand()), - arg->getName() + "'ipee"); - ; - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - Value *op0 = arg->getOperand(0); - Value *op1 = arg->getOperand(1); - Value *op2 = arg->getOperand(2); - auto ip0 = invertPointerM(op0, bb, nullShadow); - auto ip1 = invertPointerM(op1, bb, nullShadow); - - auto rule = [&](Value *ip0, Value *ip1) { - return bb.CreateInsertElement(ip0, ip1, getNewFromOriginal(op2), - arg->getName() + "'ipie"); - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip0, ip1); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - Value *op0 = arg->getOperand(0); - Value *op1 = arg->getOperand(1); - auto ip0 = invertPointerM(op0, bb, nullShadow); - auto ip1 = invertPointerM(op1, bb, nullShadow); - - auto rule = [&bb, &arg](Value *ip0, Value *ip1) { - return bb.CreateShuffleVector(ip0, ip1, arg->getShuffleMaskForBitcode(), - arg->getName() + "'ipsv"); - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip0, ip1); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - bb.setFastMathFlags(getFast()); - - Value *itval = nullptr; - { - auto tval = arg->getTrueValue(); - if (!runtimeActivity && TR.query(arg)[{-1}].isPossiblePointer() && - !isa(tval) && !isa(tval) && - isConstantValue(tval)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *arg << " const val: " << *tval; - if (CustomErrorHandler) - itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(tval), wrap(&bb))); - else - EmitWarning("MixedActivityError", *arg, ss.str()); - } - if (!itval) { - itval = invertPointerM(tval, bb, nullShadow); - } - } - Value *ifval = nullptr; - { - auto fval = arg->getFalseValue(); - if (!runtimeActivity && TR.query(arg)[{-1}].isPossiblePointer() && - !isa(fval) && !isa(fval) && - isConstantValue(fval)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *arg << " const val: " << *fval; - if (CustomErrorHandler) - ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(fval), wrap(&bb))); - else - EmitWarning("MixedActivityError", *arg, ss.str()); - } - if (!ifval) { - ifval = invertPointerM(fval, bb, nullShadow); - } - } - - Value *shadow = applyChainRule( - arg->getType(), bb, - [&](Value *tv, Value *fv) { - return bb.CreateSelect(getNewFromOriginal(arg->getCondition()), tv, - fv, arg->getName() + "'ipse"); - }, - itval, ifval); - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> 
bb(getNewFromOriginal(arg)); - Value *op0 = arg->getOperand(0); - Value *ip = invertPointerM(op0, bb); - - SmallVector prevScopes; - if (auto prev = arg->getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - prevScopes.push_back(M); - } - } - SmallVector prevNoAlias; - if (auto prev = arg->getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - prevNoAlias.push_back(M); - } - } - size_t idx = 0; - auto rule = [&](Value *ip) { - auto li = bb.CreateLoad(arg->getType(), ip, arg->getName() + "'ipl"); - llvm::SmallVector ToCopy2(MD_ToCopy); - li->copyMetadata(*arg, ToCopy2); - li->copyIRFlags(arg); - - SmallVector scopeMD = {getDerivativeAliasScope(op0, idx)}; - for (auto M : prevScopes) - scopeMD.push_back(M); - auto scope = MDNode::get(li->getContext(), scopeMD); - li->setMetadata(LLVMContext::MD_alias_scope, scope); - - SmallVector MDs; - for (ssize_t j = -1; j < getWidth(); j++) { - if (j != (ssize_t)idx) - MDs.push_back(getDerivativeAliasScope(op0, j)); - } - for (auto M : prevNoAlias) - MDs.push_back(M); - if (MDs.size()) { - auto noscope = MDNode::get(li->getContext(), MDs); - li->setMetadata(LLVMContext::MD_noalias, noscope); - } - - li->setAlignment(arg->getAlign()); - li->setDebugLoc(getNewFromOriginal(arg->getDebugLoc())); - li->setVolatile(arg->isVolatile()); - li->setOrdering(arg->getOrdering()); - li->setSyncScopeID(arg->getSyncScopeID()); - idx++; - return li; - }; - - Value *li = applyChainRule(arg->getType(), bb, rule, ip); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, li))); - return li; - - } else if (auto arg = dyn_cast(oval)) { - switch (mode) { - case DerivativeMode::ReverseModePrimal: - case DerivativeMode::ReverseModeCombined: - case DerivativeMode::ReverseModeGradient: - if (TR.query(arg)[{-1}].isFloat()) { - return Constant::getNullValue(getShadowType(arg->getType())); - } - break; - default: - break; - } - - if (!arg->getType()->isIntOrIntVectorTy()) { - llvm::errs() << *oval << "\n"; - } - assert(arg->getType()->isIntOrIntVectorTy()); - IRBuilder<> bb(getNewFromOriginal(arg)); - Value *val0 = nullptr; - Value *val1 = nullptr; - - val0 = invertPointerM(arg->getOperand(0), bb); - val1 = invertPointerM(arg->getOperand(1), bb); - assert(val0->getType() == val1->getType()); - - auto rule = [&bb, &arg](Value *val0, Value *val1) { - auto li = bb.CreateBinOp(arg->getOpcode(), val0, val1, arg->getName()); - if (auto BI = dyn_cast(li)) - BI->copyIRFlags(arg); - return li; - }; - - Value *li = applyChainRule(arg->getType(), bb, rule, val0, val1); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, li))); - return li; - } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); - SmallVector invertargs; - for (unsigned i = 0; i < arg->getNumIndices(); ++i) { - Value *b = getNewFromOriginal(arg->getOperand(1 + i)); - invertargs.push_back(b); - } - Value *ip = invertPointerM(arg->getPointerOperand(), bb); - - auto rule = [&](Value *ip) { - auto shadow = bb.CreateGEP(arg->getSourceElementType(), ip, invertargs, - arg->getName() + "'ipg"); - - if (auto gep = dyn_cast(shadow)) - gep->setIsInBounds(arg->isInBounds()); - - return shadow; - }; - - Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } else if (auto inst = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(inst)); - Value 
*asize = getNewFromOriginal(inst->getArraySize()); - - auto rule1 = [&]() { - AllocaInst *antialloca = bb.CreateAlloca( - inst->getAllocatedType(), inst->getType()->getPointerAddressSpace(), - asize, inst->getName() + "'ipa"); - antialloca->setAlignment(inst->getAlign()); - return antialloca; - }; - - Value *antialloca = applyChainRule(oval->getType(), bb, rule1); - - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, antialloca))); - - if (auto ci = dyn_cast(asize)) { - if (ci->isOne()) { - - auto rule = [&](Value *antialloca) { - StoreInst *st = bb.CreateStore( - Constant::getNullValue(inst->getAllocatedType()), antialloca); - cast(st)->setAlignment(inst->getAlign()); - }; - - applyChainRule(bb, rule, antialloca); - - return antialloca; - } else { - // TODO handle alloca of size > 1 - } - } - - auto rule2 = [&](Value *antialloca) { - auto dst_arg = - bb.CreateBitCast(antialloca, getInt8PtrTy(oval->getContext())); - auto val_arg = ConstantInt::get(Type::getInt8Ty(oval->getContext()), 0); - auto len_arg = bb.CreateMul( - bb.CreateZExtOrTrunc(asize, Type::getInt64Ty(oval->getContext())), - ConstantInt::get(Type::getInt64Ty(oval->getContext()), - M->getDataLayout().getTypeAllocSizeInBits( - inst->getAllocatedType()) / - 8), - "", true, true); - auto volatile_arg = ConstantInt::getFalse(oval->getContext()); - - Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg}; - Type *tys[] = {dst_arg->getType(), len_arg->getType()}; - auto memset = cast(bb.CreateCall( - getIntrinsicDeclaration(M, Intrinsic::memset, tys), args)); - memset->addParamAttr( - 0, Attribute::getWithAlignment(inst->getContext(), inst->getAlign())); - memset->addParamAttr(0, Attribute::NonNull); - }; - - applyChainRule(bb, rule2, antialloca); - - return antialloca; - } else if (auto II = dyn_cast(oval)) { - if (isIntelSubscriptIntrinsic(*II)) { - IRBuilder<> bb(getNewFromOriginal(II)); - - const std::array idxArgsIndices{{0, 1, 2, 4}}; - const size_t ptrArgIndex = 3; - - SmallVector invertArgs(5); - for (auto i : idxArgsIndices) { - Value *idx = getNewFromOriginal(II->getOperand(i)); - invertArgs[i] = idx; - } - Value *invertPtrArg = invertPointerM(II->getOperand(ptrArgIndex), bb); - invertArgs[ptrArgIndex] = invertPtrArg; - - auto rule = [&](Value *ip) { - auto shadow = bb.CreateCall(II->getCalledFunction(), invertArgs); - assert(isa(shadow)); -#if LLVM_VERSION_MAJOR >= 13 - auto CI = cast(shadow); - // Must copy the elementtype attribute as it is needed by the intrinsic - CI->addParamAttr( - ptrArgIndex, - II->getParamAttr(ptrArgIndex, Attribute::AttrKind::ElementType)); -#endif - return shadow; - }; - - Value *shadow = applyChainRule(II->getType(), bb, rule, invertPtrArg); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); - return shadow; - } - - IRBuilder<> bb(getNewFromOriginal(II)); - bb.setFastMathFlags(getFast()); - switch (II->getIntrinsicID()) { - default: - goto end; -#if LLVM_VERSION_MAJOR < 20 - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: -#endif - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: { - return applyChainRule( - II->getType(), bb, - [&](Value *ptr) { - Value *args[] = {ptr, getNewFromOriginal(II->getArgOperand(1))}; - auto li = bb.CreateCall(II->getCalledFunction(), args); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - li->copyMetadata(*II, ToCopy2); 
- li->setDebugLoc(getNewFromOriginal(II->getDebugLoc())); - return li; - }, - invertPointerM(II->getArgOperand(0), bb)); - case Intrinsic::masked_load: - return applyChainRule( - II->getType(), bb, - [&](Value *ptr, Value *defaultV) { - Value *args[] = {ptr, getNewFromOriginal(II->getArgOperand(1)), - getNewFromOriginal(II->getArgOperand(2)), - defaultV}; - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - auto li = bb.CreateCall(II->getCalledFunction(), args); - li->copyMetadata(*II, ToCopy2); - li->setDebugLoc(getNewFromOriginal(II->getDebugLoc())); - return li; - }, - invertPointerM(II->getArgOperand(0), bb), - invertPointerM(II->getArgOperand(3), bb, nullShadow)); - } - } - } else if (auto phi = dyn_cast(oval)) { - - if (phi->getNumIncomingValues() == 0) { - dumpMap(invertedPointers); - assert(0 && "illegal iv of phi"); - } - std::map> mapped; - for (unsigned int i = 0; i < phi->getNumIncomingValues(); ++i) { - mapped[phi->getIncomingValue(i)].insert(phi->getIncomingBlock(i)); - } - - if (false && mapped.size() == 1) { - return invertPointerM(phi->getIncomingValue(0), BuilderM, nullShadow); - } -#if 0 - else if (false && mapped.size() == 2) { - IRBuilder <> bb(phi); - auto which = bb.CreatePHI(Type::getInt1Ty(phi->getContext()), phi->getNumIncomingValues()); - //TODO this is not recursive - - int cnt = 0; - Value* vals[2]; - for(auto v : mapped) { - assert( cnt <= 1 ); - vals[cnt] = v.first; - for (auto b : v.second) { - which->addIncoming(ConstantInt::get(which->getType(), cnt), b); - } - ++cnt; - } - auto result = BuilderM.CreateSelect(which, invertPointerM(vals[1], BuilderM), invertPointerM(vals[0], BuilderM)); - return result; - } -#endif - - else { - auto NewV = getNewFromOriginal(phi); - IRBuilder<> bb(NewV); - bb.setFastMathFlags(getFast()); - // Note if the original phi node get's scev'd in NewF, it may - // no longer be a phi and we need a new place to insert this phi - // Note that if scev'd this can still be a phi with 0 incoming indicating - // an unnecessary value to be replaced - // TODO consider allowing the inverted pointer to become a scev - if (!isa(NewV) || - cast(NewV)->getNumIncomingValues() == 0) { - bb.SetInsertPoint(bb.GetInsertBlock(), bb.GetInsertBlock()->begin()); - } - - if (EnzymeVectorSplitPhi && width > 1) { - IRBuilder<> postPhi(NewV->getParent()->getFirstNonPHI()); - Type *shadowTy = getShadowType(phi->getType()); - PHINode *tmp = bb.CreatePHI(shadowTy, phi->getNumIncomingValues()); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, tmp))); - - Type *wrappedType = ArrayType::get(phi->getType(), width); - Value *res = UndefValue::get(wrappedType); - - SmallVector invertedVals; - for (unsigned int j = 0; j < phi->getNumIncomingValues(); ++j) { - IRBuilder<> pre( - cast(getNewFromOriginal(phi->getIncomingBlock(j))) - ->getTerminator()); - Value *preval = phi->getIncomingValue(j); - - Value *val = nullptr; - if (!runtimeActivity && TR.query(phi)[{-1}].isPossiblePointer() && - !isa(preval) && !isa(preval) && - isConstantValue(preval)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *phi - << " const val: " << *preval; - if (CustomErrorHandler) - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, - this, wrap(preval), wrap(&pre))); - else - EmitWarning("MixedActivityError", *phi, ss.str()); - } - if (!val) { - val = invertPointerM(preval, pre, nullShadow); - } - invertedVals.push_back(val); - } - - for 
(unsigned int i = 0; i < getWidth(); ++i) { - PHINode *which = - bb.CreatePHI(phi->getType(), phi->getNumIncomingValues()); - which->setDebugLoc(getNewFromOriginal(phi->getDebugLoc())); - - for (unsigned int j = 0; j < phi->getNumIncomingValues(); ++j) { - IRBuilder<> pre( - cast(getNewFromOriginal(phi->getIncomingBlock(j))) - ->getTerminator()); - Value *val = invertedVals[j]; - auto extracted_diff = extractMeta(pre, val, i); - which->addIncoming( - extracted_diff, - cast(getNewFromOriginal(phi->getIncomingBlock(j)))); - } - - res = postPhi.CreateInsertValue(res, which, {i}); - } - invertedPointers.erase((const Value *)oval); - replaceAWithB(tmp, res); - erase(tmp); - - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, res))); - - return res; - } else { - Type *shadowTy = getShadowType(phi->getType()); - PHINode *which = bb.CreatePHI(shadowTy, phi->getNumIncomingValues()); - which->setDebugLoc(getNewFromOriginal(phi->getDebugLoc())); - - invertedPointers.insert(std::make_pair((const Value *)oval, - InvertedPointerVH(this, which))); - - for (unsigned int i = 0; i < phi->getNumIncomingValues(); ++i) { - IRBuilder<> pre( - cast(getNewFromOriginal(phi->getIncomingBlock(i))) - ->getTerminator()); - - Value *preval = phi->getIncomingValue(i); - - Value *val = nullptr; - if (!runtimeActivity && TR.query(phi)[{-1}].isPossiblePointer() && - !isa(preval) && !isa(preval) && - isConstantValue(preval)) { - std::string str; - raw_string_ostream ss(str); - ss << "Mismatched activity for: " << *phi - << " const val: " << *preval; - if (CustomErrorHandler) - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, - this, wrap(preval), wrap(&pre))); - else - EmitWarning("MixedActivityError", *phi, ss.str()); - } - if (!val) { - val = invertPointerM(preval, pre, nullShadow); - } - - which->addIncoming(val, cast(getNewFromOriginal( - phi->getIncomingBlock(i)))); - } - return which; - } - } - } - -end:; - assert(BuilderM.GetInsertBlock()); - assert(BuilderM.GetInsertBlock()->getParent()); - assert(oval); - - if (isa(oval) && TR.query(oval)[{-1}].isFloat()) { - return Constant::getNullValue(getShadowType(oval->getType())); - } - - if (CustomErrorHandler) { - std::string str; - raw_string_ostream ss(str); - ss << "cannot find shadow for " << *oval; - auto iv = - unwrap(CustomErrorHandler(str.c_str(), wrap(oval), ErrorType::NoShadow, - this, nullptr, wrap(&BuilderM))); - if (iv) { - invertedPointers.insert( - std::make_pair((const Value *)oval, InvertedPointerVH(this, iv))); - return iv; - } - } - - llvm::errs() << *newFunc->getParent() << "\n"; - llvm::errs() << "fn:" << *newFunc << "\noval=" << *oval - << " icv=" << isConstantValue(oval) << "\n"; - for (auto z : invertedPointers) { - llvm::errs() << "available inversion for " << *z.first << " of " - << *z.second << "\n"; - } - assert(0 && "cannot find deal with ptr that isnt arg"); - report_fatal_error("cannot find deal with ptr that isnt arg"); -} - -Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, - const ValueToValueMapTy &incoming_available, - bool tryLegalRecomputeCheck, BasicBlock *scope) { - - assert(mode == DerivativeMode::ReverseModePrimal || - mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined); - - assert(val->getName() != ""); - { - auto found = incoming_available.find(val); - if (found != incoming_available.end()) - return found->second; - } - if (isa(val)) { - return val; - } - if (isa(val)) { - return val; - } - if 
(isa(val)) { - return val; - } - if (isa(val)) { - return val; - } - if (isa(val)) { - return val; - } - if (isa(val)) { - return val; - } - if (isa(val)) { - return val; - } - - if (!isa(val)) { - llvm::errs() << *val << "\n"; - } - - auto inst = cast(val); - if (inversionAllocs && inst->getParent() == inversionAllocs) { - return val; - } - assert(inst->getParent()->getParent() == newFunc); - assert(BuilderM.GetInsertBlock()->getParent() == newFunc); - if (scope == nullptr) - scope = BuilderM.GetInsertBlock(); - assert(scope->getParent() == newFunc); - - bool reduceRegister = false; - - if (EnzymeRegisterReduce) { - if (isNVLoad(inst)) { - reduceRegister = true; - } - if (auto LI = dyn_cast(inst)) { - auto Arch = - llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); - unsigned int SharedAddrSpace = - Arch == Triple::amdgcn - ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local - : 3; - if (cast(LI->getPointerOperand()->getType()) - ->getAddressSpace() == SharedAddrSpace) { - reduceRegister |= tryLegalRecomputeCheck && - legalRecompute(LI, incoming_available, &BuilderM) && - shouldRecompute(LI, incoming_available, &BuilderM); - } - } - if (!inst->mayReadOrWriteMemory()) { - reduceRegister |= tryLegalRecomputeCheck && - legalRecompute(inst, incoming_available, &BuilderM) && - shouldRecompute(inst, incoming_available, &BuilderM); - } - if (this->isOriginalBlock(*BuilderM.GetInsertBlock())) - reduceRegister = false; - } - - if (!reduceRegister) { - if (isOriginalBlock(*BuilderM.GetInsertBlock())) { - if (BuilderM.GetInsertBlock()->size() && - BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { - Instruction *use = &*BuilderM.GetInsertPoint(); - while (isa(use)) - use = use->getNextNode(); - if (DT.dominates(inst, use)) { - return inst; - } else { - llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n"; - llvm::errs() << "didn't dominate inst: " << *inst - << " point: " << *BuilderM.GetInsertPoint() - << "\nbb: " << *BuilderM.GetInsertBlock() << "\n"; - } - } else { - if (inst->getParent() == BuilderM.GetInsertBlock() || - DT.dominates(inst, BuilderM.GetInsertBlock())) { - // allowed from block domination - return inst; - } else { - llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n"; - llvm::errs() << "didn't dominate inst: " << *inst - << "\nbb: " << *BuilderM.GetInsertBlock() << "\n"; - } - } - // This is a reverse block - } else if (BuilderM.GetInsertBlock() != inversionAllocs) { - // Something in the entry (or anything that dominates all returns, doesn't - // need caching) - BasicBlock *orig = isOriginal(inst->getParent()); - if (!orig) { - llvm::errs() << "oldFunc: " << *oldFunc << "\n"; - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "insertBlock: " << *BuilderM.GetInsertBlock() << "\n"; - llvm::errs() << "instP: " << *inst->getParent() << "\n"; - llvm::errs() << "inst: " << *inst << "\n"; - } - assert(orig); - - // TODO upgrade this to be all returns that this could enter from - bool legal = BlocksDominatingAllReturns.count(orig); - if (legal) { - - BasicBlock *forwardBlock = - isOriginal(originalForReverseBlock(*BuilderM.GetInsertBlock())); - assert(forwardBlock); - - // Don't allow this if we're not definitely using the last iteration of - // this value - // + either because the value isn't in a loop - // + or because the forward of the block usage location isn't in a - // loop (thus last iteration) - // + or because the loop nests share no ancestry - - bool loopLegal = true; - for (Loop *idx = 
OrigLI->getLoopFor(orig); idx != nullptr; - idx = idx->getParentLoop()) { - for (Loop *fdx = OrigLI->getLoopFor(forwardBlock); fdx != nullptr; - fdx = fdx->getParentLoop()) { - if (idx == fdx) { - loopLegal = false; - break; - } - } - } - - if (loopLegal) { - return inst; - } - } - } - } - - if (lookup_cache[BuilderM.GetInsertBlock()].find(val) != - lookup_cache[BuilderM.GetInsertBlock()].end()) { - auto result = lookup_cache[BuilderM.GetInsertBlock()][val]; - if (result == nullptr) { - lookup_cache[BuilderM.GetInsertBlock()].erase(val); - } else { - assert(result); - assert(result->getType()); - result = BuilderM.CreateBitCast(result, val->getType()); - assert(result->getType() == inst->getType()); - return result; - } - } - - ValueToValueMapTy available; - for (auto pair : incoming_available) { - if (pair.second) - assert(pair.first->getType() == pair.second->getType()); - available[pair.first] = pair.second; - } - - { - BasicBlock *forwardPass = BuilderM.GetInsertBlock(); - if (forwardPass != inversionAllocs && !isOriginalBlock(*forwardPass)) { - forwardPass = originalForReverseBlock(*forwardPass); - } - LoopContext lc; - bool inLoop = getContext(forwardPass, lc); - - if (inLoop) { - bool first = true; - for (LoopContext idx = lc;; getContext(idx.parent->getHeader(), idx)) { - if (available.count(idx.var) == 0) { - if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { - available[idx.var] = - BuilderM.CreateLoad(idx.var->getType(), idx.antivaralloc); - } else { - available[idx.var] = idx.var; - } - } - if (!first && idx.var == inst) - return available[idx.var]; - if (first) { - first = false; - } - if (idx.parent == nullptr) - break; - } - } - } - - if (available.count(inst)) { - assert(available[inst]->getType() == inst->getType()); - return available[inst]; - } - - // If requesting loop bound and not available from index per above - // we must be requesting the total size. 
Rather than generating - // a new lcssa variable, use the existing loop exact bound var - { - LoopContext lc; - bool loopVar = false; - if (getContext(inst->getParent(), lc) && lc.var == inst) { - loopVar = true; - } else if (auto phi = dyn_cast(inst)) { - Value *V = nullptr; - bool legal = true; - for (auto &val : phi->incoming_values()) { - if (isa(val)) - continue; - if (V == nullptr) - V = val; - else if (V != val) { - legal = false; - break; - } - } - if (legal) { - if (auto I = dyn_cast_or_null(V)) { - if (getContext(I->getParent(), lc) && lc.var == I) { - loopVar = true; - } - } - } - } - if (loopVar) { - Value *lim = nullptr; - if (lc.dynamic) { - // Must be in a reverse pass fashion for a lookup to index bound to be - // legal - assert(/*ReverseLimit*/ reverseBlocks.size() > 0); - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - lc.preheader); - lim = lookupValueFromCache( - lc.var->getType(), /*forwardPass*/ false, BuilderM, lctx, - getDynamicLoopLimit(LI.getLoopFor(lc.header)), - /*isi1*/ false, available); - } else { - lim = lookupM(lc.trueLimit, BuilderM); - } - lookup_cache[BuilderM.GetInsertBlock()][val] = lim; - return lim; - } - } - - Instruction *prelcssaInst = inst; - - assert(inst->getName() != ""); - val = fixLCSSA(inst, scope); - if (isa(val)) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *BuilderM.GetInsertBlock() << "\n"; - llvm::errs() << *scope << "\n"; - llvm::errs() << *val << " inst " << *inst << "\n"; - assert(0 && "undef value upon lcssa"); - } - inst = cast(val); - assert(prelcssaInst->getType() == inst->getType()); - assert(!this->isOriginalBlock(*BuilderM.GetInsertBlock())); - - // Update index and caching per lcssa - if (lookup_cache[BuilderM.GetInsertBlock()].find(val) != - lookup_cache[BuilderM.GetInsertBlock()].end()) { - auto result = lookup_cache[BuilderM.GetInsertBlock()][val]; - if (result == nullptr) { - lookup_cache[BuilderM.GetInsertBlock()].erase(val); - } else { - assert(result); - assert(result->getType()); - result = BuilderM.CreateBitCast(result, val->getType()); - assert(result->getType() == inst->getType()); - return result; - } - } - - // TODO consider call as part of - bool lrc = false, src = false; - if (tryLegalRecomputeCheck && - (lrc = legalRecompute(prelcssaInst, available, &BuilderM))) { - if ((src = shouldRecompute(prelcssaInst, available, &BuilderM))) { - auto op = unwrapM(prelcssaInst, BuilderM, available, - UnwrapMode::AttemptSingleUnwrap, scope); - if (op) { - assert(op); - assert(op->getType()); - if (op->getType() != inst->getType()) { - llvm::errs() << " op: " << *op << " inst: " << *inst << "\n"; - } - assert(op->getType() == inst->getType()); - if (!reduceRegister) - lookup_cache[BuilderM.GetInsertBlock()][val] = op; - return op; - } - } else { - if (isa(prelcssaInst)) { - } - } - } - - if (auto li = dyn_cast(inst)) - if (auto origInst = dyn_cast_or_null(isOriginal(inst))) { - auto liobj = getBaseObject(li->getPointerOperand()); - - auto orig_liobj = getBaseObject(origInst->getPointerOperand()); - - if (scopeMap.find(inst) == scopeMap.end()) { - for (auto pair : scopeMap) { - if (auto li2 = dyn_cast(const_cast(pair.first))) { - - auto li2obj = getBaseObject(li2->getPointerOperand()); - - if (liobj == li2obj && DT.dominates(li2, li)) { - auto orig2 = dyn_cast_or_null(isOriginal(li2)); - if (!orig2) - continue; - - bool failed = false; - - // llvm::errs() << "found potential candidate loads: oli:" - // << *origInst << " oli2: " << *orig2 << "\n"; - - auto scev1 = 
OrigSE->getSCEV(origInst->getPointerOperand()); - auto scev2 = OrigSE->getSCEV(orig2->getPointerOperand()); - // llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2 - // << "\n"; - - allInstructionsBetween( - *OrigLI, orig2, origInst, [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy(&TR, *OrigAA, TLI, - /*maybeReader*/ origInst, - /*maybeWriter*/ I)) { - failed = true; - // llvm::errs() << "FAILED: " << *I << "\n"; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (failed) - continue; - - if (auto ar1 = dyn_cast(scev1)) { - if (auto ar2 = dyn_cast(scev2)) { - if (ar1->getStart() != OrigSE->getCouldNotCompute() && - ar1->getStart() == ar2->getStart() && - ar1->getStepRecurrence(*OrigSE) != - OrigSE->getCouldNotCompute() && - ar1->getStepRecurrence(*OrigSE) == - ar2->getStepRecurrence(*OrigSE)) { - - LoopContext l1; - getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), - l1); - LoopContext l2; - getContext(getNewFromOriginal(ar2->getLoop()->getHeader()), - l2); - if (l1.dynamic || l2.dynamic) - continue; - - // TODO IF len(ar2) >= len(ar1) then we can replace li with - // li2 - if (SE.getSCEV(l1.trueLimit) != SE.getCouldNotCompute() && - SE.getSCEV(l1.trueLimit) == SE.getSCEV(l2.trueLimit)) { - // llvm::errs() - // << " step1: " << *ar1->getStepRecurrence(SE) - // << " step2: " << *ar2->getStepRecurrence(SE) << - // "\n"; - - inst = li2; - break; - } - } - } - } - } - } - } - - auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); - - auto Arch = - llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); - unsigned int SharedAddrSpace = - Arch == Triple::amdgcn - ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local - : 3; - if (EnzymeSharedForward && scev1 != OrigSE->getCouldNotCompute() && - cast(orig_liobj->getType())->getAddressSpace() == - SharedAddrSpace) { - Value *resultValue = nullptr; - ValueToValueMapTy newavail; - for (const auto &pair : available) { - assert(pair.first->getType() == pair.second->getType()); - newavail[pair.first] = pair.second; - } - allDomPredecessorsOf(origInst, *OrigDT, [&](Instruction *pred) { - if (auto SI = dyn_cast(pred)) { - // auto NewSI = cast(getNewFromOriginal(SI)); - auto si2obj = getBaseObject(SI->getPointerOperand()); - - if (si2obj != orig_liobj) - return false; - - bool lastStore = true; - bool interveningSync = false; - allInstructionsBetween( - *OrigLI, SI, origInst, [&](Instruction *potentialAlias) { - if (!potentialAlias->mayWriteToMemory()) - return false; - if (!writesToMemoryReadBy(&TR, *OrigAA, TLI, origInst, - potentialAlias)) - return false; - - if (auto II = dyn_cast(potentialAlias)) { - if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0 || - II->getIntrinsicID() == Intrinsic::amdgcn_s_barrier) { - interveningSync = - DT.dominates(SI, II) && DT.dominates(II, origInst); - allUnsyncdPredecessorsOf( - II, - [&](Instruction *mid) { - if (!mid->mayWriteToMemory()) - return false; - - if (mid == SI) - return false; - - if (!writesToMemoryReadBy(&TR, *OrigAA, TLI, - origInst, mid)) { - return false; - } - lastStore = false; - return true; - }, - [&]() { - // if gone past entry - if (mode != DerivativeMode::ReverseModeCombined) { - lastStore = false; - } - }); - if (!lastStore) - return true; - else - return false; - } - } - - lastStore = false; - return true; - }); - - if (!lastStore) - return false; - - auto scev2 = OrigSE->getSCEV(SI->getPointerOperand()); - bool legal = scev1 == scev2; - if (auto ar2 = dyn_cast(scev2)) { - if (auto ar1 = 
dyn_cast(scev1)) { - if (ar2->getStart() != OrigSE->getCouldNotCompute() && - ar1->getStart() == ar2->getStart() && - ar2->getStepRecurrence(*OrigSE) != - OrigSE->getCouldNotCompute() && - ar1->getStepRecurrence(*OrigSE) == - ar2->getStepRecurrence(*OrigSE)) { - - LoopContext l1; - getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), - l1); - LoopContext l2; - getContext(getNewFromOriginal(ar2->getLoop()->getHeader()), - l2); - if (!l1.dynamic && !l2.dynamic) { - // TODO IF len(ar2) >= len(ar1) then we can replace li - // with li2 - if (l1.trueLimit == l2.trueLimit) { - const Loop *L1 = ar1->getLoop(); - while (L1) { - if (L1 == ar2->getLoop()) - return false; - L1 = L1->getParentLoop(); - } - newavail[l2.var] = available[l1.var]; - legal = true; - } - } - } - } - } - if (!legal) { - Value *sval = SI->getPointerOperand(); - Value *lval = origInst->getPointerOperand(); - while (auto CI = dyn_cast(sval)) - sval = CI->getOperand(0); - while (auto CI = dyn_cast(lval)) - lval = CI->getOperand(0); - if (auto sgep = dyn_cast(sval)) { - if (auto lgep = dyn_cast(lval)) { - if (sgep->getPointerOperand() == - lgep->getPointerOperand()) { - SmallVector svals; - for (auto &v : sgep->indices()) { - Value *q = v; - while (auto CI = dyn_cast(q)) - q = CI->getOperand(0); - svals.push_back(q); - } - SmallVector lvals; - for (auto &v : lgep->indices()) { - Value *q = v; - while (auto CI = dyn_cast(q)) - q = CI->getOperand(0); - lvals.push_back(q); - } - ValueToValueMapTy ThreadLookup; - bool legal = true; - for (size_t i = 0; i < svals.size(); i++) { - auto ss = OrigSE->getSCEV(svals[i]); - auto ls = OrigSE->getSCEV(lvals[i]); - if (cast(ss->getType())->getBitWidth() > - cast(ls->getType())->getBitWidth()) { - ls = OrigSE->getZeroExtendExpr(ls, ss->getType()); - } - if (cast(ss->getType())->getBitWidth() < - cast(ls->getType())->getBitWidth()) { - ls = OrigSE->getTruncateExpr(ls, ss->getType()); - } - if (ls != ss) { - if (auto II = dyn_cast(svals[i])) { - switch (II->getIntrinsicID()) { - case Intrinsic::nvvm_read_ptx_sreg_tid_x: - case Intrinsic::nvvm_read_ptx_sreg_tid_y: - case Intrinsic::nvvm_read_ptx_sreg_tid_z: - case Intrinsic::amdgcn_workitem_id_x: - case Intrinsic::amdgcn_workitem_id_y: - case Intrinsic::amdgcn_workitem_id_z: - ThreadLookup[getNewFromOriginal(II)] = - BuilderM.CreateZExtOrTrunc( - lookupM(getNewFromOriginal(lvals[i]), - BuilderM, available), - II->getType()); - break; - default: - legal = false; - break; - } - } else { - legal = false; - break; - } - } - } - if (legal) { - for (auto pair : newavail) { - assert(pair.first->getType() == - pair.second->getType()); - ThreadLookup[pair.first] = pair.second; - } - Value *recomp = unwrapM( - getNewFromOriginal(SI->getValueOperand()), BuilderM, - ThreadLookup, UnwrapMode::AttemptFullUnwrap, scope, - /*permitCache*/ false); - if (recomp) { - resultValue = recomp; - return true; - ; - } - } - } - } - } - } - if (!legal) - return false; - return true; - } - return false; - }); - - if (resultValue) { - if (resultValue->getType() != val->getType()) - resultValue = BuilderM.CreateBitCast(resultValue, val->getType()); - return resultValue; - } - } - } - - auto loadSize = (li->getParent() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(li->getType()) + - 7) / - 8; - - // this is guarded because havent told cacheForReverse how to move - if (mode == DerivativeMode::ReverseModeCombined) - if (!li->isVolatile() && EnzymeLoopInvariantCache) { - if (auto AI = dyn_cast(liobj)) { - assert(isa(orig_liobj)); - if (auto AT 
= dyn_cast(AI->getAllocatedType())) - if (auto GEP = - dyn_cast(li->getPointerOperand())) { - if (GEP->getPointerOperand() == AI) { - LoopContext l1; - if (!getContext(li->getParent(), l1)) - goto noSpeedCache; - - BasicBlock *ctx = l1.preheader; - - auto origPH = cast_or_null(isOriginal(ctx)); - assert(origPH); - if (OrigPDT->dominates(origPH, origInst->getParent())) { - goto noSpeedCache; - } - - Instruction *origTerm = origPH->getTerminator(); - if (!origTerm) - llvm::errs() << *origPH << "\n"; - assert(origTerm); - IRBuilder<> OB(origTerm); - LoadInst *tmpload = OB.CreateLoad(AT, orig_liobj, "'tmpload"); - - bool failed = false; - allInstructionsBetween( - *OrigLI, &*origTerm, origInst, - [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy(&TR, *OrigAA, TLI, - /*maybeReader*/ tmpload, - /*maybeWriter*/ I)) { - failed = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (failed) { - tmpload->eraseFromParent(); - goto noSpeedCache; - } - while (Loop *L = LI.getLoopFor(ctx)) { - BasicBlock *nctx = L->getLoopPreheader(); - assert(nctx); - bool failed = false; - auto origPH = cast_or_null(isOriginal(nctx)); - assert(origPH); - if (OrigPDT->dominates(origPH, origInst->getParent())) { - break; - } - Instruction *origTerm = origPH->getTerminator(); - allInstructionsBetween( - *OrigLI, &*origTerm, origInst, - [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy(&TR, *OrigAA, TLI, - /*maybeReader*/ tmpload, - /*maybeWriter*/ I)) { - failed = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (failed) - break; - ctx = nctx; - } - - tmpload->eraseFromParent(); - - IRBuilder<> v(ctx->getTerminator()); - - AllocaInst *cache = nullptr; - - LoopContext tmp; - bool forceSingleIter = false; - if (!getContext(ctx, tmp)) { - forceSingleIter = true; - } - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - ctx, forceSingleIter); - - if (auto found = findInMap(scopeMap, (Value *)liobj)) { - cache = found->first; - } else { - // if freeing reverseblocks must exist - assert(reverseBlocks.size()); - cache = createCacheForScope(lctx, AT, li->getName(), - /*shouldFree*/ true, - /*allocate*/ true); - assert(cache); - scopeMap.insert( - std::make_pair(AI, std::make_pair(cache, lctx))); - - v.setFastMathFlags(getFast()); - assert(isOriginalBlock(*v.GetInsertBlock())); - Value *outer = - getCachePointer(AT, - /*inForwardPass*/ true, v, lctx, cache, - /*storeinstorecache*/ true, - /*available*/ ValueToValueMapTy(), - /*extraSize*/ nullptr); - - auto ld = v.CreateLoad(AT, AI); - ld->setAlignment(AI->getAlign()); - scopeInstructions[cache].push_back(ld); - auto st = v.CreateStore(ld, outer); - auto bsize = newFunc->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(AT) / - 8; - if ((bsize & (bsize - 1)) == 0) { - st->setAlignment(Align(bsize)); - } - scopeInstructions[cache].push_back(st); - for (auto post : PostCacheStore(st, v)) { - scopeInstructions[cache].push_back(post); - } - } - - assert(!isOriginalBlock(*BuilderM.GetInsertBlock())); - Value *outer = getCachePointer( - AT, - /*inForwardPass*/ false, BuilderM, lctx, cache, - /*storeinstorecache*/ true, available, - /*extraSize*/ nullptr); - SmallVector idxs; - for (auto &idx : GEP->indices()) { - idxs.push_back(lookupM(idx, BuilderM, available, - tryLegalRecomputeCheck)); - } - - auto cptr = BuilderM.CreateGEP(GEP->getSourceElementType(), - outer, idxs); - cast(cptr)->setIsInBounds(true); - - // Retrieve the actual result 
- auto result = loadFromCachePointer(val->getType(), BuilderM, - cptr, cache); - - assert(result->getType() == inst->getType()); - lookup_cache[BuilderM.GetInsertBlock()][val] = result; - return result; - } - } - } - - auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); - // Store in memcpy opt - Value *lim = nullptr; - BasicBlock *ctx = nullptr; - Value *start = nullptr; - Value *offset = nullptr; - if (auto ar1 = dyn_cast(scev1)) { - if (auto step = - dyn_cast(ar1->getStepRecurrence(*OrigSE))) { - if (step->getAPInt() != loadSize) - goto noSpeedCache; - - LoopContext l1; - getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), l1); - - if (l1.dynamic) - goto noSpeedCache; - - offset = available[l1.var]; - ctx = l1.preheader; - - IRBuilder<> v(ctx->getTerminator()); - - auto origPH = cast_or_null(isOriginal(ctx)); - assert(origPH); - if (OrigPDT->dominates(origPH, origInst->getParent())) { - goto noSpeedCache; - } - - lim = unwrapM(l1.trueLimit, v, - /*available*/ ValueToValueMapTy(), - UnwrapMode::AttemptFullUnwrapWithLookup); - if (!lim) { - goto noSpeedCache; - } - lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "", - true, true); - - { -#if LLVM_VERSION_MAJOR >= 12 - Value *start0; - SmallVector InsertedInstructions; - { - SCEVExpander OrigExp( - *OrigSE, ctx->getParent()->getParent()->getDataLayout(), - "enzyme", /*PreserveLCSSA = */ false); - - OrigExp.setInsertPoint( - isOriginal(l1.header)->getTerminator()); - - start0 = OrigExp.expandCodeFor( - ar1->getStart(), li->getPointerOperand()->getType()); - InsertedInstructions = OrigExp.getAllInsertedInstructions(); - } - - ValueToValueMapTy available; - for (const auto &pair : originalToNewFn) { - if (pair.first->getType() == pair.second->getType()) - available[pair.first] = pair.second; - } - - // Sort so that later instructions do not dominate earlier - // instructions. - llvm::stable_sort(InsertedInstructions, - [this](Instruction *A, Instruction *B) { - return OrigDT->dominates(A, B); - }); - for (auto a : InsertedInstructions) { - if (isa(a)) { - std::string str; - raw_string_ostream ss(str); - ss << "oldFunc: " << *oldFunc << "\n"; - ss << "newFunc: " << *newFunc << "\n"; - ss << "li: " << *li << "\n"; - ss << "start0: " << *start0 << "\n"; - ss << "Inserted a phi node (" << *a - << ") during unwrap of SCEV: " << *ar1->getStart() - << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(li), - ErrorType::InternalError, nullptr, - nullptr, nullptr); - } else { - EmitFailure("InsertedPHISCEV", li->getDebugLoc(), li, - ss.str()); - } - } - auto uwV = - unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, - /*scope*/ nullptr, /*cache*/ false); - auto uw = dyn_cast(uwV); - assert(uwV->getType() == a->getType()); -#ifndef NDEBUG - if (uw) { - for (size_t i = 0; i < uw->getNumOperands(); i++) { - auto op = uw->getOperand(i); - if (auto arg = dyn_cast(op)) - assert(arg->getParent() == newFunc); - else if (auto inst = dyn_cast(op)) - assert(inst->getParent()->getParent() == newFunc); - } - assert(uw->getParent()->getParent() == newFunc); - } -#endif - available[a] = uwV; - if (uw) - unwrappedLoads.erase(uw); - } - - start = - isa(start0) ? 
start0 : (Value *)available[start0]; - if (!start) { - llvm::errs() << "old: " << *oldFunc << "\n"; - llvm::errs() << "new: " << *newFunc << "\n"; - llvm::errs() << "start0: " << *start0 << "\n"; - } - assert(start); - - available.clear(); - for (auto I : llvm::reverse(InsertedInstructions)) { - assert(I->getNumUses() == 0); - OrigSE->forgetValue(I); - I->eraseFromParent(); - } -#endif - } - - if (!start) - goto noSpeedCache; - - Instruction *origTerm = origPH->getTerminator(); - - bool failed = false; - allInstructionsBetween( - *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy(&TR, *OrigAA, TLI, - /*maybeReader*/ origInst, - /*maybeWriter*/ I)) { - failed = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (failed) - goto noSpeedCache; - } - } - - if (ctx && lim && start && offset) { - Value *firstLim = lim; - Value *firstStart = start; - while (Loop *L = LI.getLoopFor(ctx)) { - BasicBlock *nctx = L->getLoopPreheader(); - assert(nctx); - bool failed = false; - auto origPH = cast_or_null(isOriginal(nctx)); - assert(origPH); - if (OrigPDT->dominates(origPH, origInst->getParent())) { - break; - } - Instruction *origTerm = origPH->getTerminator(); - allInstructionsBetween( - *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { - if (I->mayWriteToMemory() && - writesToMemoryReadBy(&TR, *OrigAA, TLI, - /*maybeReader*/ origInst, - /*maybeWriter*/ I)) { - failed = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (failed) - break; - IRBuilder<> nv(nctx->getTerminator()); - Value *nlim = unwrapM(firstLim, nv, - /*available*/ ValueToValueMapTy(), - UnwrapMode::AttemptFullUnwrapWithLookup); - if (!nlim) - break; - Value *nstart = unwrapM(firstStart, nv, - /*available*/ ValueToValueMapTy(), - UnwrapMode::AttemptFullUnwrapWithLookup); - if (!nstart) - break; - lim = nlim; - start = nstart; - ctx = nctx; - } - IRBuilder<> v(ctx->getTerminator()); - bool isi1 = val->getType()->isIntegerTy() && - cast(li->getType())->getBitWidth() == 1; - - AllocaInst *cache = nullptr; - - LoopContext tmp; - bool forceSingleIter = false; - if (!getContext(ctx, tmp)) { - forceSingleIter = true; - } else if (auto inst = dyn_cast(lim)) { - if (inst->getParent() == ctx || - !DT.dominates(inst->getParent(), ctx)) { - forceSingleIter = true; - } - } - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, ctx, - forceSingleIter); - - if (auto found = findInMap(scopeMap, (Value *)inst)) { - cache = found->first; - } else { - // if freeing reverseblocks must exist - assert(reverseBlocks.size()); - cache = createCacheForScope(lctx, li->getType(), li->getName(), - /*shouldFree*/ true, - /*allocate*/ true, /*extraSize*/ lim); - assert(cache); - scopeMap.insert( - std::make_pair(inst, std::make_pair(cache, lctx))); - - v.setFastMathFlags(getFast()); - assert(isOriginalBlock(*v.GetInsertBlock())); - Value *outer = - getCachePointer(li->getType(), - /*inForwardPass*/ true, v, lctx, cache, - /*storeinstorecache*/ true, - /*available*/ ValueToValueMapTy(), - /*extraSize*/ lim); - - auto dst_arg = v.CreateBitCast( - outer, - getInt8PtrTy( - inst->getContext(), - cast(outer->getType())->getAddressSpace())); - scopeInstructions[cache].push_back(cast(dst_arg)); - auto src_arg = v.CreateBitCast( - start, - getInt8PtrTy( - inst->getContext(), - cast(start->getType())->getAddressSpace())); - auto len_arg = - v.CreateMul(ConstantInt::get(lim->getType(), loadSize), lim, - "", true, true); - if 
(Instruction *I = dyn_cast(len_arg)) - scopeInstructions[cache].push_back(I); - auto volatile_arg = ConstantInt::getFalse(inst->getContext()); - - Value *nargs[] = {dst_arg, src_arg, len_arg, volatile_arg}; - - Type *tys[] = {dst_arg->getType(), src_arg->getType(), - len_arg->getType()}; - - auto memcpyF = getIntrinsicDeclaration(newFunc->getParent(), - Intrinsic::memcpy, tys); - auto mem = cast(v.CreateCall(memcpyF, nargs)); - - mem->addParamAttr(0, Attribute::NonNull); - mem->addParamAttr(1, Attribute::NonNull); - - auto bsize = - newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits( - li->getType()) / - 8; - if ((bsize & (bsize - 1)) == 0) { - mem->addParamAttr(0, Attribute::getWithAlignment( - memcpyF->getContext(), Align(bsize))); - } - - mem->addParamAttr(1, Attribute::getWithAlignment( - memcpyF->getContext(), li->getAlign())); - scopeInstructions[cache].push_back(mem); - } - - assert(!isOriginalBlock(*BuilderM.GetInsertBlock())); - Value *result = lookupValueFromCache( - inst->getType(), - /*isForwardPass*/ false, BuilderM, lctx, cache, isi1, available, - /*extraSize*/ lim, offset); - assert(result->getType() == inst->getType()); - lookup_cache[BuilderM.GetInsertBlock()][val] = result; - - EmitWarning("Uncacheable", *inst, "Caching instruction ", *inst, - " legalRecompute: ", lrc, " shouldRecompute: ", src, - " tryLegalRecomputeCheck: ", tryLegalRecomputeCheck); - return result; - } - } - noSpeedCache:; - } - - if (scopeMap.find(inst) == scopeMap.end()) { - EmitWarning("Uncacheable", *inst, "Caching instruction ", *inst, - " legalRecompute: ", lrc, " shouldRecompute: ", src, - " tryLegalRecomputeCheck: ", tryLegalRecomputeCheck); - } - - BasicBlock *scopeI = inst->getParent(); - if (auto origInst = isOriginal(inst)) { - auto found = rematerializableAllocations.find(origInst); - if (found != rematerializableAllocations.end()) - if (found->second.LI && found->second.LI->contains(origInst)) { - // If not caching whole allocation and rematerializing the allocation - // within the loop, force an entry-level scope so there is no need - // to cache. - if (!needsCacheWholeAllocation(origInst)) - scopeI = &newFunc->getEntryBlock(); - } - } else { - for (auto pair : backwardsOnlyShadows) { - if (auto pinst = dyn_cast(pair.first)) - if (!pair.second.primalInitialize && pair.second.LI && - pair.second.LI->contains(pinst->getParent())) { - auto found = invertedPointers.find(pair.first); - if (found != invertedPointers.end() && found->second == inst) { - scopeI = &newFunc->getEntryBlock(); - - // Prevent the phi node from being stored into the cache by creating - // it before the ensureLookupCached. 
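
// The load-caching logic in this region decides whether a value needed by the
// reverse pass can be recomputed on the spot or must be spilled to a cache
// ("tape") during the forward pass, keyed by the loop induction variable, and
// read back during the reverse sweep. A minimal standalone sketch of that
// store-then-replay idea, assuming a fixed trip count and leaving out the
// legality analysis, scope/limit contexts and allocation reuse handled above:
#include <cstdio>
#include <vector>

int main() {
  const int n = 8;
  std::vector<double> x(n);
  for (int i = 0; i < n; ++i)
    x[i] = 1.0 + i;

  // Forward sweep: a value the reverse sweep will need but cannot safely
  // recompute (here x[i] before it is overwritten) is pushed onto a tape
  // indexed by the induction variable.
  std::vector<double> tape(n);
  double y = 0.0;
  for (int i = 0; i < n; ++i) {
    tape[i] = x[i];   // cache the loaded value
    y += x[i] * x[i];
    x[i] = 0.0;       // clobber: the load is no longer recomputable afterwards
  }

  // Reverse sweep: iterate backwards and read the cached value at the same
  // index instead of re-executing the (now invalid) load.
  double dy = 1.0;
  std::vector<double> dx(n, 0.0);
  for (int i = n - 1; i >= 0; --i)
    dx[i] += 2.0 * tape[i] * dy;

  for (int i = 0; i < n; ++i)
    std::printf("dx[%d] = %g\n", i, dx[i]);
  return 0;
}
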
- if (scopeMap.find(inst) == scopeMap.end()) { - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - scopeI); - - AllocaInst *cache = createCacheForScope( - lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); - assert(cache); - Value *inst_tmp = inst; - insert_or_assign(scopeMap, inst_tmp, - std::pair, LimitContext>( - cache, lctx)); - } - break; - } - } - } - } - - ensureLookupCached(inst, /*shouldFree*/ true, scopeI, - inst->getMetadata(LLVMContext::MD_tbaa)); - bool isi1 = inst->getType()->isIntegerTy() && - cast(inst->getType())->getBitWidth() == 1; - assert(!isOriginalBlock(*BuilderM.GetInsertBlock())); - auto found = findInMap(scopeMap, (Value *)inst); - Value *result = - lookupValueFromCache(inst->getType(), /*isForwardPass*/ false, BuilderM, - found->second, found->first, isi1, available); - if (auto LI2 = dyn_cast(result)) - if (auto LI1 = dyn_cast(inst)) { - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - LI2->copyMetadata(*LI1, ToCopy2); - } - if (result->getType() != inst->getType()) { - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "result: " << *result << "\n"; - llvm::errs() << "inst: " << *inst << "\n"; - llvm::errs() << "val: " << *val << "\n"; - } - assert(result->getType() == inst->getType()); - lookup_cache[BuilderM.GetInsertBlock()][val] = result; - assert(result); - if (result->getType() != val->getType()) { - result = BuilderM.CreateBitCast(result, val->getType()); - } - assert(result->getType() == val->getType()); - assert(result->getType()); - return result; -} - -BasicBlock *GradientUtils::originalForReverseBlock(BasicBlock &BB2) const { - auto found = reverseBlockToPrimal.find(&BB2); - if (found == reverseBlockToPrimal.end()) { - errs() << "newFunc: " << *newFunc << "\n"; - errs() << BB2 << "\n"; - } - assert(found != reverseBlockToPrimal.end()); - return found->second; -} - -//! Given a map of edges we could have taken to desired target, compute a value -//! 
that determines which target should be branched to -// This function attempts to determine an equivalent condition from earlier in -// the code and use that if possible, falling back to creating a phi node of -// which edge was taken if necessary This function can be used in two ways: -// * If replacePHIs is null (usual case), this function does the branch -// * If replacePHIs isn't null, do not perform the branch and instead replace -// the PHI's with the derived condition as to whether we should branch to a -// particular target -void GradientUtils::branchToCorrespondingTarget( - BasicBlock *ctx, IRBuilder<> &BuilderM, - const std::map>> - &targetToPreds, - const std::map *replacePHIs) { - assert(targetToPreds.size() > 0); - if (replacePHIs) { - if (replacePHIs->size() == 0) - return; - -#ifndef NDEBUG - for (auto x : *replacePHIs) { - assert(targetToPreds.find(x.first) != targetToPreds.end()); - } -#endif - } - - if (targetToPreds.size() == 1) { - if (replacePHIs == nullptr) { - if (!(BuilderM.GetInsertBlock()->size() == 0 || - !isa(BuilderM.GetInsertBlock()->back()))) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *BuilderM.GetInsertBlock() << "\n"; - } - assert(BuilderM.GetInsertBlock()->size() == 0 || - !isa(BuilderM.GetInsertBlock()->back())); - BuilderM.CreateBr(targetToPreds.begin()->first); - } else { - for (auto pair : *replacePHIs) { - pair.second->replaceAllUsesWith( - ConstantInt::getTrue(pair.second->getContext())); - pair.second->eraseFromParent(); - } - } - return; - } - - // Map of function edges to list of targets this can branch to we have - std::map, - std::set> - done; - { - std::deque< - std::tuple, - BasicBlock *>> - Q; // newblock, target - - for (auto pair : targetToPreds) { - for (auto pred_edge : pair.second) { - Q.push_back(std::make_pair(pred_edge, pair.first)); - } - } - - for (std::tuple< - std::pair, - BasicBlock *> - trace; - Q.size() > 0;) { - trace = Q.front(); - Q.pop_front(); - auto edge = std::get<0>(trace); - auto block = edge.first; - auto target = std::get<1>(trace); - - if (done[edge].count(target)) - continue; - done[edge].insert(target); - - // If this block dominates the context, don't go back up as any - // predecessors won't contain the conditions. - if (DT.dominates(block, ctx)) - continue; - - Loop *blockLoop = LI.getLoopFor(block); - - for (BasicBlock *Pred : predecessors(block)) { - // Don't go up the backedge as we can use the last value if desired via - // lcssa - if (blockLoop && blockLoop->getHeader() == block && - blockLoop == LI.getLoopFor(Pred)) - continue; - - Q.push_back( - std::tuple, BasicBlock *>( - std::make_pair(Pred, block), target)); - } - } - } - - IntegerType *T; - if (targetToPreds.size() == 2) - T = Type::getInt1Ty(BuilderM.getContext()); - else if (targetToPreds.size() < 256) - T = Type::getInt8Ty(BuilderM.getContext()); - else - T = Type::getInt32Ty(BuilderM.getContext()); - - Instruction *equivalentTerminator = nullptr; - - std::set blocks; - - // llvm::errs() << "\n\ngetName() << ">\n"; - for (auto pair : done) { - const auto &edge = pair.first; - blocks.insert(edge.first); - // llvm::errs() << " edge (" << edge.first->getName() << ", " - // << edge.second->getName() << ") : ["; - // for (auto s : pair.second) - // llvm::errs() << s->getName() << ","; - // llvm::errs() << "]\n"; - } - // llvm::errs() << "\n"; - - if (targetToPreds.size() == 3) { - // Try `block` as a potential first split point. 
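
// Before resorting to a cache, the search below looks for a forward-pass block
// whose conditional terminator already separates the desired reverse targets,
// so its (looked-up) condition can simply be reused. A small standalone model
// of that test, assuming an edge-to-reachable-target map computed the way
// `done` is above; `Block`, `Edge` and `isEquivalentSplit` are illustrative
// names only:
#include <cstddef>
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <vector>

using Block = int;
using Edge = std::pair<Block, Block>; // (block, successor)

// Returns true if every outgoing edge of `b` reaches exactly one distinct
// reverse target and all targets are covered, i.e. the terminator of `b`
// already encodes the branch the reverse pass needs to reproduce.
static bool isEquivalentSplit(Block b, const std::vector<Block> &succs,
                              const std::map<Edge, std::set<Block>> &done,
                              std::size_t numTargets) {
  std::set<Block> seen;
  for (Block s : succs) {
    auto it = done.find({b, s});
    if (it == done.end() || it->second.size() != 1)
      return false;                 // edge can reach more than one target
    Block target = *it->second.begin();
    if (!seen.insert(target).second)
      return false;                 // two edges reach the same target
  }
  return seen.size() == numTargets; // all targets are distinguished
}

int main() {
  // Block 0 branches to 1 and 2; edge (0,1) only reaches target 10 and edge
  // (0,2) only reaches target 20, so block 0's condition can be reused.
  std::map<Edge, std::set<Block>> done = {{{0, 1}, {10}}, {{0, 2}, {20}}};
  std::printf("%s\n", isEquivalentSplit(0, {1, 2}, done, 2) ? "reusable"
                                                            : "needs cache");
  return 0;
}
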
- for (auto block : blocks) { - { - // The original split block must not have a parent with an edge - // to a block other than to itself, which can reach any targets. - if (!DT.dominates(block, ctx)) - continue; - - // For all successors and thus edges (block, succ): - // 1) Ensure that no successors have overlapping potential - // destinations (a list of destinations previously seen is in - // foundtargets). - // 2) The block branches to all 3 destinations (foundTargets==3) - std::set foundtargets; - // 3) The unique target split off from the others is stored in - // uniqueTarget. - std::set uniqueTargets; - for (BasicBlock *succ : successors(block)) { - auto edge = std::make_pair(block, succ); - for (BasicBlock *target : done[edge]) { - if (foundtargets.find(target) != foundtargets.end()) { - goto rnextpair; - } - foundtargets.insert(target); - if (done[edge].size() == 1) - uniqueTargets.insert(target); - } - } - if (foundtargets.size() != 3) - goto rnextpair; - if (uniqueTargets.size() != 1) - goto rnextpair; - - // Only handle cases where the split was due to a conditional - // branch. This branch, `bi`, splits off uniqueTargets[0] from - // the remainder of foundTargets. - auto bi1 = dyn_cast(block->getTerminator()); - if (!bi1) - goto rnextpair; - - { - // Find a second block `subblock` which splits the two merged - // targets from each other. - BasicBlock *subblock = nullptr; - for (auto block2 : blocks) { - { - // The second split block must not have a parent with an edge - // to a block other than to itself, which can reach any of its two - // targets. - // TODO verify this - for (auto P : predecessors(block2)) { - for (auto S : successors(P)) { - if (S == block2) - continue; - auto edge = std::make_pair(P, S); - if (done.find(edge) != done.end()) { - for (auto target : done[edge]) { - if (foundtargets.find(target) != foundtargets.end() && - uniqueTargets.find(target) == uniqueTargets.end()) { - goto nextblock; - } - } - } - } - } - - // Again, a successful split must have unique targets. - std::set seen2; - for (BasicBlock *succ : successors(block2)) { - auto edge = std::make_pair(block2, succ); - // Since there are only two targets, a successful split - // condition has only 1 target per successor of block2. - if (done[edge].size() != 1) { - goto nextblock; - } - for (BasicBlock *target : done[edge]) { - // block2 has non-unique targets. - if (seen2.find(target) != seen2.end()) { - goto nextblock; - } - seen2.insert(target); - // block2 has a target which is not part of the two needing - // to be split. The two needing to be split is equal to - // foundtargets-uniqueTargets. - if (foundtargets.find(target) == foundtargets.end()) { - goto nextblock; - } - if (uniqueTargets.find(target) != uniqueTargets.end()) { - goto nextblock; - } - } - } - // If we didn't find two valid successors, continue. - if (seen2.size() != 2) { - // llvm::errs() << " -- failed from not 2 seen\n"; - goto nextblock; - } - subblock = block2; - break; - } - nextblock:; - } - - // If no split block was found, try again. - if (subblock == nullptr) - goto rnextpair; - - // This branch, `bi2`, splits off the two blocks in - // (foundTargets-uniqueTargets) from each other. - auto bi2 = dyn_cast(subblock->getTerminator()); - if (!bi2) - goto rnextpair; - - // Condition cond1 splits off uniqueTargets[0] from - // the remainder of foundTargets. - auto cond1 = lookupM(bi1->getCondition(), BuilderM); - - // Condition cond2 splits off the two blocks in - // (foundTargets-uniqueTargets) from each other. 
- auto cond2 = lookupM(bi2->getCondition(), BuilderM); - - if (replacePHIs == nullptr) { - BasicBlock *staging = - BasicBlock::Create(oldFunc->getContext(), "staging", newFunc); - auto stagingIfNeeded = [&](BasicBlock *B) { - auto edge = std::make_pair(block, B); - if (done[edge].size() == 1) { - return *done[edge].begin(); - } else { - assert(done[edge].size() == 2); - return staging; - } - }; - BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)), - stagingIfNeeded(bi1->getSuccessor(1))); - BuilderM.SetInsertPoint(staging); - BuilderM.CreateCondBr( - cond2, - *done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(), - *done[std::make_pair(subblock, bi2->getSuccessor(1))].begin()); - } else { - Value *otherBranch = nullptr; - for (unsigned i = 0; i < 2; ++i) { - Value *val = cond1; - if (i == 1) - val = BuilderM.CreateNot(val, "anot1_"); - auto edge = std::make_pair(block, bi1->getSuccessor(i)); - if (done[edge].size() == 1) { - auto found = replacePHIs->find(*done[edge].begin()); - if (found == replacePHIs->end()) - continue; - if (&*BuilderM.GetInsertPoint() == found->second) { - if (found->second->getNextNode()) - BuilderM.SetInsertPoint(found->second->getNextNode()); - else - BuilderM.SetInsertPoint(found->second->getParent()); - } - found->second->replaceAllUsesWith(val); - found->second->eraseFromParent(); - } else { - otherBranch = val; - } - } - - for (unsigned i = 0; i < 2; ++i) { - auto edge = std::make_pair(subblock, bi2->getSuccessor(i)); - auto found = replacePHIs->find(*done[edge].begin()); - if (found == replacePHIs->end()) - continue; - - Value *val = cond2; - if (i == 1) - val = BuilderM.CreateNot(val, "bnot1_"); - val = BuilderM.CreateAnd(val, otherBranch, "andVal" + Twine(i)); - if (&*BuilderM.GetInsertPoint() == found->second) { - if (found->second->getNextNode()) - BuilderM.SetInsertPoint(found->second->getNextNode()); - else - BuilderM.SetInsertPoint(found->second->getParent()); - } - found->second->replaceAllUsesWith(val); - found->second->eraseFromParent(); - } - } - - return; - } - } - rnextpair:; - } - } - - BasicBlock *forwardBlock = BuilderM.GetInsertBlock(); - - if (!isOriginalBlock(*forwardBlock)) { - forwardBlock = originalForReverseBlock(*forwardBlock); - } - - for (auto block : blocks) { - { - // The original split block must not have a parent with an edge - // to a block other than to itself, which can reach any targets. 
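
// When no equivalent terminator exists, the nofast path further below stores a
// small integer at each predecessor edge during the forward pass, and the
// reverse pass branches on the value read back from the cache. A standalone
// sketch of that fallback, with plain containers standing in for
// storeInstructionInCache / lookupValueFromCache and the generated switch:
#include <cstdio>
#include <vector>

int main() {
  // Forward pass over a few "iterations": each time control reaches the merge
  // point, record which incoming edge (0, 1 or 2) was taken.
  std::vector<int> takenEdge;
  for (int i = 0; i < 6; ++i)
    takenEdge.push_back(i % 3); // stand-in for storeInstructionInCache

  // Reverse pass: walk the cache backwards and dispatch to the reverse block
  // corresponding to the recorded edge, mirroring the generated switch/condbr.
  for (int i = (int)takenEdge.size() - 1; i >= 0; --i) {
    switch (takenEdge[i]) { // stand-in for lookupValueFromCache
    case 0: std::printf("reverse of pred A (iter %d)\n", i); break;
    case 1: std::printf("reverse of pred B (iter %d)\n", i); break;
    default: std::printf("reverse of pred C (iter %d)\n", i); break;
    }
  }
  return 0;
}
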
- if (!DT.dominates(block, ctx)) - for (auto P : predecessors(block)) { - for (auto S : successors(P)) { - if (S == block) - continue; - auto edge = std::make_pair(P, S); - if (done.find(edge) != done.end() && done[edge].size()) - goto nextpair; - } - } - - std::set foundtargets; - for (BasicBlock *succ : successors(block)) { - auto edge = std::make_pair(block, succ); - if (done[edge].size() != 1) { - goto nextpair; - } - BasicBlock *target = *done[edge].begin(); - if (foundtargets.find(target) != foundtargets.end()) { - goto nextpair; - } - foundtargets.insert(target); - } - if (foundtargets.size() != targetToPreds.size()) { - goto nextpair; - } - - if (forwardBlock == block || DT.dominates(block, forwardBlock)) { - equivalentTerminator = block->getTerminator(); - goto fast; - } - } - nextpair:; - } - goto nofast; - -fast:; - assert(equivalentTerminator); - - if (auto branch = dyn_cast(equivalentTerminator)) { - BasicBlock *block = equivalentTerminator->getParent(); - assert(branch->getCondition()); - - assert(branch->getCondition()->getType() == T); - - if (replacePHIs == nullptr) { - if (!(BuilderM.GetInsertBlock()->size() == 0 || - !isa(BuilderM.GetInsertBlock()->back()))) { - llvm::errs() << "newFunc : " << *newFunc << "\n"; - llvm::errs() << "blk : " << *BuilderM.GetInsertBlock() << "\n"; - } - assert(BuilderM.GetInsertBlock()->size() == 0 || - !isa(BuilderM.GetInsertBlock()->back())); - BuilderM.CreateCondBr( - lookupM(branch->getCondition(), BuilderM), - *done[std::make_pair(block, branch->getSuccessor(0))].begin(), - *done[std::make_pair(block, branch->getSuccessor(1))].begin()); - } else { - for (auto pair : *replacePHIs) { - Value *phi = lookupM(branch->getCondition(), BuilderM); - Value *val = nullptr; - if (pair.first == - *done[std::make_pair(block, branch->getSuccessor(0))].begin()) { - val = phi; - } else if (pair.first == - *done[std::make_pair(block, branch->getSuccessor(1))] - .begin()) { - val = BuilderM.CreateNot(phi); - } else { - llvm::errs() << *pair.first->getParent() << "\n"; - llvm::errs() << *pair.first << "\n"; - llvm::errs() << *branch << "\n"; - llvm_unreachable("unknown successor for replacephi"); - } - if (&*BuilderM.GetInsertPoint() == pair.second) { - if (pair.second->getNextNode()) - BuilderM.SetInsertPoint(pair.second->getNextNode()); - else - BuilderM.SetInsertPoint(pair.second->getParent()); - } - pair.second->replaceAllUsesWith(val); - pair.second->eraseFromParent(); - } - } - } else if (auto si = dyn_cast(equivalentTerminator)) { - BasicBlock *block = equivalentTerminator->getParent(); - - IRBuilder<> pbuilder(equivalentTerminator); - pbuilder.setFastMathFlags(getFast()); - - if (replacePHIs == nullptr) { - SwitchInst *swtch = BuilderM.CreateSwitch( - lookupM(si->getCondition(), BuilderM), - *done[std::make_pair(block, si->getDefaultDest())].begin()); - for (auto switchcase : si->cases()) { - swtch->addCase( - switchcase.getCaseValue(), - *done[std::make_pair(block, switchcase.getCaseSuccessor())] - .begin()); - } - } else { - for (auto pair : *replacePHIs) { - Value *cas = nullptr; - for (auto c : si->cases()) { - if (pair.first == - *done[std::make_pair(block, c.getCaseSuccessor())].begin()) { - cas = c.getCaseValue(); - break; - } - } - if (cas == nullptr) { - assert(pair.first == - *done[std::make_pair(block, si->getDefaultDest())].begin()); - } - Value *val = nullptr; - Value *phi = lookupM(si->getCondition(), BuilderM); - - if (cas) { - val = BuilderM.CreateICmpEQ(cas, phi); - } else { - // default case - val = 
ConstantInt::getFalse(pair.second->getContext()); - for (auto switchcase : si->cases()) { - val = BuilderM.CreateOr( - val, BuilderM.CreateICmpEQ(switchcase.getCaseValue(), phi)); - } - val = BuilderM.CreateNot(val); - } - if (&*BuilderM.GetInsertPoint() == pair.second) { - if (pair.second->getNextNode()) - BuilderM.SetInsertPoint(pair.second->getNextNode()); - else - BuilderM.SetInsertPoint(pair.second->getParent()); - } - pair.second->replaceAllUsesWith(val); - pair.second->eraseFromParent(); - } - } - } else { - llvm::errs() << "unknown equivalent terminator\n"; - llvm::errs() << *equivalentTerminator << "\n"; - llvm_unreachable("unknown equivalent terminator"); - } - return; - -nofast:; - - // if freeing reverseblocks must exist - assert(reverseBlocks.size()); - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, ctx); - AllocaInst *cache = createCacheForScope(lctx, T, "", /*shouldFree*/ true); - SmallVector targets; - { - size_t idx = 0; - std::map /*predecessors*/>> - storing; - for (const auto &pair : targetToPreds) { - for (auto pred : pair.second) { - storing[pred.first][ConstantInt::get(T, idx)].push_back(pred.second); - } - targets.push_back(pair.first); - ++idx; - } - assert(targets.size() > 0); - - for (const auto &pair : storing) { - IRBuilder<> pbuilder(pair.first); - - if (pair.first->getTerminator()) - pbuilder.SetInsertPoint(pair.first->getTerminator()); - - pbuilder.setFastMathFlags(getFast()); - - Value *tostore = ConstantInt::get(T, 0); - - if (pair.second.size() == 1) { - tostore = pair.second.begin()->first; - } else { - assert(0 && "multi exit edges not supported"); - exit(1); - // for(auto targpair : pair.second) { - // tostore = pbuilder.CreateOr(tostore, pred); - //} - } - storeInstructionInCache(lctx, pbuilder, tostore, cache); - } - } - - bool isi1 = T->isIntegerTy() && cast(T)->getBitWidth() == 1; - Value *which = lookupValueFromCache( - T, - /*forwardPass*/ isOriginalBlock(*BuilderM.GetInsertBlock()), BuilderM, - LimitContext(/*reversePass*/ reverseBlocks.size() > 0, ctx), cache, isi1, - /*available*/ ValueToValueMapTy()); - assert(which); - assert(which->getType() == T); - - if (replacePHIs == nullptr) { - if (targetToPreds.size() == 2) { - assert(BuilderM.GetInsertBlock()->size() == 0 || - !isa(BuilderM.GetInsertBlock()->back())); - BuilderM.CreateCondBr(which, /*true*/ targets[1], /*false*/ targets[0]); - } else { - assert(targets.size() > 0); - auto swit = - BuilderM.CreateSwitch(which, targets.back(), targets.size() - 1); - for (unsigned i = 0; i < targets.size() - 1; ++i) { - swit->addCase(ConstantInt::get(T, i), targets[i]); - } - } - } else { - for (unsigned i = 0; i < targets.size(); ++i) { - auto found = replacePHIs->find(targets[i]); - if (found == replacePHIs->end()) - continue; - - Value *val = nullptr; - if (targets.size() == 2 && i == 0) { - val = BuilderM.CreateNot(which); - } else if (targets.size() == 2 && i == 1) { - val = which; - } else { - val = BuilderM.CreateICmpEQ(ConstantInt::get(T, i), which); - } - if (&*BuilderM.GetInsertPoint() == found->second) { - if (found->second->getNextNode()) - BuilderM.SetInsertPoint(found->second->getNextNode()); - else - BuilderM.SetInsertPoint(found->second->getParent()); - } - found->second->replaceAllUsesWith(val); - found->second->eraseFromParent(); - } - } - return; -} - -void GradientUtils::computeMinCache() { - if (EnzymeMinCutCache) { - SetVector Recomputes; - - std::map FullSeen; - std::map OneLevelSeen; - - ValueToValueMapTy Available; - - std::map> LoopAvail; - - for (BasicBlock &BB 
: *oldFunc) { - if (notForAnalysis.count(&BB)) - continue; - auto L = OrigLI->getLoopFor(&BB); - - auto invariant = [&](Value *V) { - if (isa(V)) - return true; - if (isa(V)) - return true; - if (auto I = dyn_cast(V)) { - if (!L->contains(OrigLI->getLoopFor(I->getParent()))) - return true; - } - return false; - }; - for (Instruction &I : BB) { - if (auto PN = dyn_cast(&I)) { - if (!OrigLI->isLoopHeader(&BB)) - continue; - if (PN->getType()->isIntegerTy()) { - bool legal = true; - SmallPtrSet Increment; - for (auto B : PN->blocks()) { - if (OrigLI->getLoopFor(B) == L) { - if (auto BO = dyn_cast( - PN->getIncomingValueForBlock(B))) { - if (BO->getOpcode() == BinaryOperator::Add) { - if ((BO->getOperand(0) == PN && - invariant(BO->getOperand(1))) || - (BO->getOperand(1) == PN && - invariant(BO->getOperand(0)))) { - Increment.insert(BO); - } else { - legal = false; - } - } else if (BO->getOpcode() == BinaryOperator::Sub) { - if (BO->getOperand(0) == PN && - invariant(BO->getOperand(1))) { - Increment.insert(BO); - } else { - legal = false; - } - } else { - legal = false; - } - } else { - legal = false; - } - } - } - if (legal) { - LoopAvail[L].insert(PN); - for (auto I : Increment) - LoopAvail[L].insert(I); - } - } - } else if (auto CI = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(CI); - if (isAllocationFunction(funcName, TLI)) { - bool legal = true; - auto found = rematerializableAllocations.find(CI); - if (found != rematerializableAllocations.end()) { - if (found->second.nonRepeatableWritingCall) - legal = false; - } - if (legal) - Available[CI] = CI; - } - } - } - } - - SmallPtrSet NewLoopBoundReq; - { - std::deque LoopBoundRequirements; - - for (auto &context : loopContexts) { - for (auto val : {context.second.maxLimit, context.second.trueLimit}) { - if (val) - if (auto inst = dyn_cast(&*val)) { - LoopBoundRequirements.push_back(inst); - } - } - } - SmallPtrSet Seen; - while (LoopBoundRequirements.size()) { - Instruction *val = LoopBoundRequirements.front(); - LoopBoundRequirements.pop_front(); - if (NewLoopBoundReq.count(val)) - continue; - if (Seen.count(val)) - continue; - Seen.insert(val); - if (auto orig = isOriginal(val)) { - NewLoopBoundReq.insert(orig); - } else { - for (auto &op : val->operands()) { - if (auto inst = dyn_cast(op)) { - LoopBoundRequirements.push_back(inst); - } - } - } - } - for (auto inst : NewLoopBoundReq) { - OneLevelSeen[UsageKey(inst, QueryType::Primal)] = true; - FullSeen[UsageKey(inst, QueryType::Primal)] = true; - } - } - - auto minCutMode = (mode == DerivativeMode::ReverseModePrimal) - ? 
DerivativeMode::ReverseModeGradient - : mode; - - for (BasicBlock &BB : *oldFunc) { - if (notForAnalysis.count(&BB)) - continue; - ValueToValueMapTy Available2; - for (auto a : Available) - Available2[a.first] = a.second; - for (Loop *L = OrigLI->getLoopFor(&BB); L != nullptr; - L = L->getParentLoop()) { - for (auto v : LoopAvail[L]) { - Available2[v] = v; - } - } - for (Instruction &I : BB) { - if (!legalRecompute(&I, Available2, nullptr)) { - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(this, &I, minCutMode, FullSeen, - notForAnalysis)) { - bool oneneed = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal, - /*OneLevel*/ true>(this, &I, minCutMode, OneLevelSeen, - notForAnalysis); - if (oneneed) { - knownRecomputeHeuristic[&I] = false; - - CountTrackedPointers T(I.getType()); - assert(!T.derived); - } else - Recomputes.insert(&I); - } - } - } - } - - SetVector Intermediates; - SetVector Required; - std::deque todo(Recomputes.begin(), Recomputes.end()); - - while (todo.size()) { - Value *V = todo.front(); - todo.pop_front(); - if (Intermediates.count(V)) - continue; - bool multiLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(this, V, minCutMode, FullSeen, notForAnalysis); - if (!multiLevel) { - continue; - } - if (!Recomputes.count(V)) { - ValueToValueMapTy Available2; - for (auto a : Available) - Available2[a.first] = a.second; - for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); - L != nullptr; L = L->getParentLoop()) { - for (auto v : LoopAvail[L]) { - Available2[v] = v; - } - } - if (!legalRecompute(V, Available2, nullptr)) { - // if not legal to recompute, we would've already explicitly marked - // this for caching if it was needed in reverse pass - continue; - } - } - Intermediates.insert(V); - bool singleLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal, /*OneLevel*/ true>(this, V, minCutMode, - OneLevelSeen, notForAnalysis); - if (singleLevel) { - Required.insert(V); - } else { - DifferentialUseAnalysis::forEachDifferentialUser( - [&](Value *V2) { todo.push_back(V2); }, this, V); - } - } - - SetVector MinReq; - DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), - *OrigLI, Recomputes, Intermediates, - Required, MinReq, this, TLI); - SmallPtrSet NeedGraph; - for (Value *V : MinReq) - NeedGraph.insert(V); - for (Value *V : Required) - todo.push_back(V); - while (todo.size()) { - Value *V = todo.front(); - todo.pop_front(); - if (NeedGraph.count(V)) - continue; - NeedGraph.insert(V); - if (auto I = dyn_cast(V)) - for (auto &V2 : I->operands()) { - if (Intermediates.count(V2)) - todo.push_back(V2); - } - } - - for (auto V : Intermediates) { - knownRecomputeHeuristic[V] = !MinReq.count(V); - if (!MinReq.count(V) && NeedGraph.count(V)) { - if (auto CI = dyn_cast(V)) - if (getFuncNameFromCall(CI) == "julia.call") - assert(0); - - ValueToValueMapTy Available2; - for (auto a : Available) - Available2[a.first] = a.second; - for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); - L != nullptr; L = L->getParentLoop()) { - for (auto v : LoopAvail[L]) { - Available2[v] = v; - } - } - assert(legalRecompute(V, Available2, nullptr)); - } - if (!NeedGraph.count(V)) { - assert(!MinReq.count(V)); - unnecessaryIntermediates.insert(cast(V)); - } - - if (NeedGraph.count(V) && MinReq.count(V)) { - CountTrackedPointers T(V->getType()); - assert(!T.derived); - } - } - } -} - -bool GradientUtils::isOriginalBlock(const BasicBlock &BB) const { - for (auto A : 
originalBlocks) { - if (A == &BB) - return true; - } - return false; -} - -void GradientUtils::eraseFictiousPHIs() { - { - for (auto P : rematerializedPrimalOrShadowAllocations) { - Value *replacement = - getUndefinedValueForType(*oldFunc->getParent(), P->getType()); - P->replaceAllUsesWith(replacement); - erase(P); - } - } - SmallVector, 4> phis; - for (auto pair : fictiousPHIs) - phis.emplace_back(pair.first, pair.second); - fictiousPHIs.clear(); - - for (auto pair : phis) { - auto pp = pair.first; - if (pp->getNumUses() != 0) { - if (CustomErrorHandler) { - std::string str; - raw_string_ostream ss(str); - ss << "Illegal replace ficticious phi for: " << *pp << " of " - << *pair.second; - CustomErrorHandler(str.c_str(), wrap(pair.second), - ErrorType::IllegalReplaceFicticiousPHIs, this, - wrap(pp), nullptr); - } else { - llvm::errs() << "mod:" << *oldFunc->getParent() << "\n"; - llvm::errs() << "oldFunc:" << *oldFunc << "\n"; - llvm::errs() << "newFunc:" << *newFunc << "\n"; - llvm::errs() << " pp: " << *pp << " of " << *pair.second << "\n"; - assert(pp->getNumUses() == 0); - } - } - pp->replaceAllUsesWith(UndefValue::get(pp->getType())); - erase(pp); - } -} - -void GradientUtils::forceActiveDetection() { - - TimeTraceScope timeScope("Activity Analysis", oldFunc->getName()); - - for (auto &Arg : oldFunc->args()) { - ATA->isConstantValue(TR, &Arg); - } - - for (BasicBlock &BB : *oldFunc) { - for (Instruction &I : BB) { - bool const_inst = ATA->isConstantInstruction(TR, &I); - bool const_value = ATA->isConstantValue(TR, &I); - if (EnzymePrintActivity) - llvm::errs() << I << " cv=" << const_value << " ci=" << const_inst - << "\n"; - } - } -} - -bool GradientUtils::isConstantValue(Value *val) const { - if (auto inst = dyn_cast(val)) { - (void)inst; - assert(inst->getParent()->getParent() == oldFunc); - return ATA->isConstantValue(TR, val); - } - - if (auto arg = dyn_cast(val)) { - (void)arg; - assert(arg->getParent() == oldFunc); - return ATA->isConstantValue(TR, val); - } - - //! Functions must be false so we can replace function with augmentation, - //! 
fallback to analysis - if (isa(val) || isa(val) || isa(val) || - isa(val) || isa(val)) { - // llvm::errs() << "calling icv on: " << *val << "\n"; - return ATA->isConstantValue(TR, val); - } - - if (auto gv = dyn_cast(val)) { - if (hasMetadata(gv, "enzyme_shadow")) - return false; - if (auto md = gv->getMetadata("enzyme_activity_value")) { - auto res = cast(md->getOperand(0))->getString(); - if (res == "const") - return true; - if (res == "active") - return false; - } - if (EnzymeNonmarkedGlobalsInactive) - return true; - goto err; - } - if (isa(val)) { - if (EnzymeNonmarkedGlobalsInactive) - return true; - goto err; - } - -err:; - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *val << "\n"; - llvm::errs() << " unknown did status attribute\n"; - assert(0 && "bad"); - exit(1); -} - -bool GradientUtils::isConstantInstruction(const Instruction *inst) const { - assert(inst->getParent()->getParent() == oldFunc); - return ATA->isConstantInstruction(TR, const_cast(inst)); -} - -bool GradientUtils::getContext(llvm::BasicBlock *BB, LoopContext &lc) { - return CacheUtility::getContext(BB, lc, - /*ReverseLimit*/ reverseBlocks.size() > 0); -} - -void GradientUtils::forceAugmentedReturns() { - assert(TR.getFunction() == oldFunc); - - for (BasicBlock &oBB : *oldFunc) { - // Don't create derivatives for code that results in termination - if (notForAnalysis.find(&oBB) != notForAnalysis.end()) - continue; - - LoopContext loopContext; - getContext(cast(getNewFromOriginal(&oBB)), loopContext); - - for (Instruction &I : oBB) { - Instruction *inst = &I; - - if (inst->getType()->isEmptyTy() || inst->getType()->isVoidTy()) - continue; - - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || - mode == DerivativeMode::ForwardModeError) { - if (!isConstantValue(inst)) { - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - Type *antiTy = getShadowType(inst->getType()); - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'dual_phi"); - invertedPointers.insert(std::make_pair( - (const Value *)inst, InvertedPointerVH(this, anti))); - } - continue; - } - - if (inst->getType()->isFPOrFPVectorTy()) - continue; //! op->getType()->isPointerTy() && - //! 
!op->getType()->isIntegerTy()) { - - if (!TR.query(inst)[{-1}].isPossiblePointer()) - continue; - - if (isa(inst)) { - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - Type *antiTy = getShadowType(inst->getType()); - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'il_phi"); - invertedPointers.insert( - std::make_pair((const Value *)inst, InvertedPointerVH(this, anti))); - continue; - } - - if (!isa(inst)) { - continue; - } - - CallInst *op = cast(inst); - Function *called = op->getCalledFunction(); - - if ((mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined) && - called && called->getName() == "llvm.julia.gc_preserve_begin") { - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - auto anti = BuilderZ.CreateCall(called, ArrayRef(), - op->getName() + "'ip"); - anti->setDebugLoc(getNewFromOriginal(op->getDebugLoc())); - invertedPointers.insert( - std::make_pair((const Value *)inst, InvertedPointerVH(this, anti))); - continue; - } - - if (isa(inst)) { - continue; - } - - if (isConstantValue(inst)) { - continue; - } - - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - - // Shadow allocations must strictly preceede the primal, lest Julia have - // GC issues. Consider the following: %r = gc_alloc() init %r - // ... - // if the shadow did not preceed - // %r = gc_alloc() - // %dr = gc_alloc() - // zero %dr - // init %r, %dr - // ... - // After %r, before %dr the %r memory would be uninit, so the allocator - // inside %dr would hit garbage and segfault. However, by having the %dr - // first, then it will be zero'd before the %r allocation, preventing the - // issue. 
- if (isAllocationCall(inst, TLI)) { - BuilderZ.SetInsertPoint(getNewFromOriginal(inst)); -#if LLVM_VERSION_MAJOR >= 18 - auto It = BuilderZ.GetInsertPoint(); - It.setHeadBit(true); - BuilderZ.SetInsertPoint(It); -#endif - } - Type *antiTy = getShadowType(inst->getType()); - - PHINode *anti = BuilderZ.CreatePHI(antiTy, 1, op->getName() + "'ip_phi"); - anti->setDebugLoc(getNewFromOriginal(op->getDebugLoc())); - invertedPointers.insert( - std::make_pair((const Value *)inst, InvertedPointerVH(this, anti))); - - if (isAllocationCall(inst, TLI)) { - anti->setName(op->getName() + "'mi"); - } - } - } -} - -void InvertedPointerVH::deleted() { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - gutils->dumpPointers(); - llvm::errs() << **this << "\n"; - assert(0 && "erasing something in invertedPointers map"); -} - -void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, - Type *secretty, Intrinsic::ID intrinsic, - unsigned dstalign, unsigned srcalign, unsigned offset, - bool dstConstant, Value *shadow_dst, bool srcConstant, - Value *shadow_src, Value *length, Value *isVolatile, - llvm::CallInst *MTI, bool allowForward, - bool shadowsLookedUp, bool backwardsShadow) { - // TODO offset - if (secretty) { - // no change to forward pass if represents floats - if (mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ForwardModeSplit) { - IRBuilder<> Builder2(MTI); - if (mode == DerivativeMode::ForwardModeSplit) - gutils->getForwardBuilder(Builder2); - else - gutils->getReverseBuilder(Builder2); - - // If the src is constant simply zero d_dst and don't propagate to d_src - // (which thus == src and may be illegal) - if (srcConstant) { - // Don't zero in forward mode. - if (mode != DerivativeMode::ForwardModeSplit) { - - Value *args[] = { - shadowsLookedUp ? shadow_dst - : gutils->lookupM(shadow_dst, Builder2), - ConstantInt::get(Type::getInt8Ty(MTI->getContext()), 0), - gutils->lookupM(length, Builder2), - ConstantInt::getFalse(MTI->getContext())}; - - if (args[0]->getType()->isIntegerTy()) - args[0] = Builder2.CreateIntToPtr(args[0], - getInt8PtrTy(MTI->getContext())); - - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memsetIntr = getIntrinsicDeclaration( - MTI->getParent()->getParent()->getParent(), Intrinsic::memset, - tys); - auto cal = Builder2.CreateCall(memsetIntr, args); - cal->setCallingConv(memsetIntr->getCallingConv()); - if (dstalign != 0) { - cal->addParamAttr(0, Attribute::getWithAlignment(MTI->getContext(), - Align(dstalign))); - } - } - - } else { - auto dsto = - (shadowsLookedUp || mode == DerivativeMode::ForwardModeSplit) - ? shadow_dst - : gutils->lookupM(shadow_dst, Builder2); - if (dsto->getType()->isIntegerTy()) - dsto = - Builder2.CreateIntToPtr(dsto, getInt8PtrTy(dsto->getContext())); - unsigned dstaddr = - cast(dsto->getType())->getAddressSpace(); - if (offset != 0) { - dsto = Builder2.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(dsto->getContext()), dsto, offset); - } - auto srco = - (shadowsLookedUp || mode == DerivativeMode::ForwardModeSplit) - ? 
shadow_src - : gutils->lookupM(shadow_src, Builder2); - if (mode != DerivativeMode::ForwardModeSplit) - dsto = Builder2.CreatePointerCast( - dsto, PointerType::get(secretty, dstaddr)); - if (srco->getType()->isIntegerTy()) - srco = - Builder2.CreateIntToPtr(srco, getInt8PtrTy(srco->getContext())); - unsigned srcaddr = - cast(srco->getType())->getAddressSpace(); - if (offset != 0) { - srco = Builder2.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(srco->getContext()), srco, offset); - } - if (mode != DerivativeMode::ForwardModeSplit) - srco = Builder2.CreatePointerCast( - srco, PointerType::get(secretty, srcaddr)); - - if (mode == DerivativeMode::ForwardModeSplit) { - MaybeAlign dalign; - if (dstalign) - dalign = MaybeAlign(dstalign); - MaybeAlign salign; - if (srcalign) - salign = MaybeAlign(srcalign); - - if (intrinsic == Intrinsic::memmove) { - Builder2.CreateMemMove(dsto, dalign, srco, salign, length); - } else { - Builder2.CreateMemCpy(dsto, dalign, srco, salign, length); - } - } else { - Value *args[]{ - Builder2.CreatePointerCast(dsto, - PointerType::get(secretty, dstaddr)), - Builder2.CreatePointerCast(srco, - PointerType::get(secretty, srcaddr)), - Builder2.CreateUDiv( - gutils->lookupM(length, Builder2), - ConstantInt::get(length->getType(), - Builder2.GetInsertBlock() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(secretty) / - 8))}; - - auto dmemcpy = ((intrinsic == Intrinsic::memcpy) - ? getOrInsertDifferentialFloatMemcpy - : getOrInsertDifferentialFloatMemmove)( - *MTI->getParent()->getParent()->getParent(), secretty, dstalign, - srcalign, dstaddr, srcaddr, - cast(length->getType())->getBitWidth()); - Builder2.CreateCall(dmemcpy, args); - } - } - } - } else { - - // if represents pointer or integer type then only need to modify forward - // pass with the copy - if ((allowForward && (mode == DerivativeMode::ReverseModePrimal || - mode == DerivativeMode::ReverseModeCombined)) || - (backwardsShadow && (mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ForwardModeSplit))) { - assert(!shadowsLookedUp); - - // It is questionable how the following case would even occur, but if - // the dst is constant, we shouldn't do anything extra - if (dstConstant) { - return; - } - - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(MTI)); - - // If src is inactive, then we should copy from the regular pointer - // (i.e. suppose we are copying constant memory representing dimensions - // into a tensor) - // to ensure that the differential tensor is well formed for use - // OUTSIDE the derivative generation (as enzyme doesn't need this), we - // should also perform the copy onto the differential. Future - // Optimization (not implemented): If dst can never escape Enzyme code, - // we may omit this copy. 
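// A host-side illustration of the forward-pass rule implemented below,
// assuming plain byte buffers in place of the LLVM shadow pointers: for
// pointer/integer data the copy performed on the primal is replayed on the
// shadow operands at the same offset and length, so the differential memory
// is well formed even outside the generated derivative.
#include <cstddef>
#include <cstring>

void forwardShadowCopy(unsigned char *shadowDst, const unsigned char *shadowSrc,
                       std::size_t offset, std::size_t length) {
  // Mirrors the memcpy/memmove intrinsic emitted on
  // (shadow_dst + offset, shadow_src + offset, length) in the forward pass.
  std::memcpy(shadowDst + offset, shadowSrc + offset, length);
}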
- // no need to update pointers, even if dst is active - auto dsto = shadow_dst; - if (dsto->getType()->isIntegerTy()) - dsto = BuilderZ.CreateIntToPtr(dsto, getInt8PtrTy(MTI->getContext())); - if (offset != 0) { - dsto = BuilderZ.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(dsto->getContext()), dsto, offset); - } - auto srco = shadow_src; - if (srco->getType()->isIntegerTy()) - srco = BuilderZ.CreateIntToPtr(srco, getInt8PtrTy(MTI->getContext())); - if (offset != 0) { - srco = BuilderZ.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(srco->getContext()), srco, offset); - } - Value *args[] = {dsto, srco, length, isVolatile}; - - Type *tys[] = {args[0]->getType(), args[1]->getType(), - args[2]->getType()}; - - auto memtransIntr = - getIntrinsicDeclaration(gutils->newFunc->getParent(), intrinsic, tys); - auto cal = BuilderZ.CreateCall(memtransIntr, args); - cal->setAttributes(MTI->getAttributes()); - cal->setCallingConv(memtransIntr->getCallingConv()); - cal->setTailCallKind(MTI->getTailCallKind()); - - if (dstalign != 0) { - cal->addParamAttr( - 0, Attribute::getWithAlignment(MTI->getContext(), Align(dstalign))); - } - if (srcalign != 0) { - cal->addParamAttr( - 1, Attribute::getWithAlignment(MTI->getContext(), Align(srcalign))); - } - } - } -} - -void GradientUtils::computeForwardingProperties(Instruction *V) { - if (!EnzymeRematerialize) - return; - - // For the piece of memory V allocated within this scope, it will be - // initialized in some way by the (augmented) forward pass. Loads and other - // load-like operations will either require the allocation V itself to be - // preserved for the reverse pass, or alternatively the tape for those - // operations. - // - // Instead, we ask here whether or not we can restore the memory state of V in - // the reverse pass by recreating all of the stores and store-like operations - // into the V prior to their load-like uses. - // - // Notably, we only need to preserve the ability to reload any values actually - // used in the reverse pass. - - std::map Seen; - bool primalNeededInReverse = - DifferentialUseAnalysis::is_value_needed_in_reverse( - this, V, DerivativeMode::ReverseModeGradient, Seen, notForAnalysis); - - SmallVector loads; - SmallVector loadLikeCalls; - SmallPtrSet stores; - SmallPtrSet storingOps; - SmallPtrSet frees; - SmallPtrSet LifetimeStarts; - bool promotable = true; - bool shadowpromotable = true; - - CallInst *nonRepeatableWritingCall = nullptr; - SmallVector shadowPointerLoads; - - std::set> seen; - SmallVector, 1> todo; - for (auto U : V->users()) - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, V)); - while (todo.size()) { - auto tup = todo.back(); - Instruction *cur = tup.first; - Value *prev = tup.second; - todo.pop_back(); - if (seen.count(tup)) - continue; - seen.insert(tup); - if (notForAnalysis.count(cur->getParent())) - continue; - if (isPointerArithmeticInst(cur)) { - for (auto u : cur->users()) { - if (auto I = dyn_cast(u)) - todo.push_back(std::make_pair(I, (Value *)cur)); - } - } else if (auto load = dyn_cast(cur)) { - - // If loaded value is an int or pointer, may need - // to preserve initialization within the primal. 
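// A small sketch of the rule applied here, with a plain enum standing in for
// the type-analysis query: only loads of non-float (pointer/integer) data
// force the shadow allocation to be pre-initialized in the primal, since
// float shadows are produced entirely by the reverse pass.
enum class LoadedKind { Float, IntegerOrPointer };

bool loadForcesShadowPrimalInit(LoadedKind k) {
  return k == LoadedKind::IntegerOrPointer;
}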
- auto TT = TR.query(load)[{-1}]; - if (!TT.isFloat()) { - shadowPointerLoads.push_back(cur); - } - loads.push_back(load); - } else if (auto store = dyn_cast(cur)) { - // TODO only add store to shadow iff non float type - if (store->getValueOperand() == prev) { - EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V, - " due to capturing store ", *cur); - promotable = false; - shadowpromotable = false; - break; - } else { - stores.insert(store); - storingOps.insert(store); - } - } else if (auto II = dyn_cast(cur)) { - switch (II->getIntrinsicID()) { - case Intrinsic::lifetime_start: - LifetimeStarts.insert(II); - break; - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: - case Intrinsic::dbg_label: -#if LLVM_VERSION_MAJOR <= 16 - case llvm::Intrinsic::dbg_addr: -#endif - case Intrinsic::lifetime_end: - break; - case Intrinsic::memset: { - stores.insert(II); - storingOps.insert(II); - break; - } - // TODO memtransfer(cpy/move) - case Intrinsic::memcpy: - case Intrinsic::memmove: - default: - promotable = false; - shadowpromotable = false; - EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V, - " due to unknown intrinsic ", *cur); - break; - } - } else if (auto CI = dyn_cast(cur)) { - StringRef funcName = getFuncNameFromCall(CI); - if (isDeallocationFunction(funcName, TLI)) { - frees.insert(CI); - continue; - } - // The allocation arg is the first arg of the write barrier. - // The capturing store in subsequent args should be handled by forbidding - // capturing stores - if (funcName == "julia.write_barrier" || - funcName == "julia.write_barrier_binding") { - if (CI->getArgOperand(0) == prev) { - stores.insert(CI); - } - continue; - } - if (funcName == "enzyme_zerotype") { - stores.insert(CI); - continue; - } - - size_t idx = 0; - bool seenLoadLikeCall = false; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) -#else - for (auto &arg : CI->arg_operands()) -#endif - { - if (arg != prev) { - idx++; - continue; - } - auto TT = TR.query(prev)[{-1, -1}]; - - bool NoCapture = isNoCapture(CI, idx); - - bool ReadOnly = isReadOnly(CI, idx); - - bool WriteOnly = isWriteOnly(CI, idx); - - // If the pointer is captured, conservatively assume it is used in - // nontrivial ways that make both the primal and shadow not promotable. - if (!NoCapture) { - shadowpromotable = false; - promotable = false; - EmitWarning("NotPromotable", *cur, " Could not promote allocation ", - *V, " due to unknown capturing call ", *cur); - idx++; - continue; - } - - // From here on out we can assume the pointer is not captured, and only - // written to or read from. - - // If we may read from the memory, consider this a load-like call - // that must have all writes done in preparation for any reverse-pass - // users. - if (!WriteOnly) { - if (!seenLoadLikeCall) { - loadLikeCalls.push_back(LoadLikeCall(CI, prev)); - seenLoadLikeCall = true; - } - } - - // If we may write to memory, we cannot promote if any values - // need the allocation or any descendants for the reverse pass. - if (!ReadOnly) { - // There is an exception for Julia returnRoots which will be - // separately handled in a GC postprocessing pass. Moreover these - // values are never `needed` in the reverse pass (just we need to mark - // those values as being GC'd by the function). 
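// A compact sketch of the per-argument decision made in this block, assuming
// plain flags in place of the isNoCapture/isReadOnly/isWriteOnly queries and
// of the returnRoots attribute test just below. It models only the decision
// table, not the warnings or bookkeeping.
struct ArgDecision {
  bool blocksPromotion; // primal allocation can no longer be promoted
  bool blocksShadow;    // shadow allocation can no longer be promoted
  bool loadLike;        // call may read the allocation; reads must be replayable
  bool storing;         // call may write the allocation; treat as a storing op
};

ArgDecision classifyCallArg(bool noCapture, bool readOnly, bool writeOnly,
                            bool primalNeededInReverse, bool returnRoots) {
  if (!noCapture)
    return {true, true, false, false}; // captured: give up on both
  ArgDecision d{false, false, false, false};
  if (!writeOnly)
    d.loadLike = true;
  if (!readOnly) {
    d.storing = true;
    d.blocksPromotion = primalNeededInReverse && !returnRoots;
  }
  return d;
}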
- bool returnRoots = - CI->getAttributes().hasParamAttr(idx, "enzymejl_returnRoots") || - CI->getAttributes().hasParamAttr(idx, "enzymejl_returnRoots_v"); - if (primalNeededInReverse && !returnRoots) { - promotable = false; - EmitWarning("NotPromotable", *cur, " Could not promote allocation ", - *V, " due to unknown writing call ", *cur); - } - if (!nonRepeatableWritingCall) - nonRepeatableWritingCall = CI; - storingOps.insert(cur); - } - - // Consider shadow memory now. - // - // If the memory is all floats, there's no issue, since besides zero - // initialization nothing should occur for them in the forward pass - if (TT.isFloat()) { - } else if (WriteOnly) { - // Don't need in the case of int/pointer stores, (should be done by - // fwd pass), and as isFloat above described does not prevent the - // shadow - } else { - shadowPointerLoads.push_back(cur); - } - - idx++; - } - - } else { - promotable = false; - shadowpromotable = false; - EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V, - " due to unknown instruction ", *cur); - } - } - - // Find the outermost loop of all stores, and the allocation/lifetime - Loop *outer = OrigLI->getLoopFor(V->getParent()); - if (LifetimeStarts.size() == 1) { - outer = OrigLI->getLoopFor((*LifetimeStarts.begin())->getParent()); - } - - for (auto S : stores) { - outer = getAncestor(outer, OrigLI->getLoopFor(S->getParent())); - } - - // May now read pointers for storing into other pointers. Therefore we - // need to pre initialize the shadow. - bool primalInitializationOfShadow = shadowPointerLoads.size() > 0; - - if (shadowpromotable && !isConstantValue(V)) { - for (auto LI : shadowPointerLoads) { - // Is there a store which could occur after the load. - // This subsequent store would invalidate any loads being re-performed. - SmallVector results; - mayExecuteAfter(results, LI, storingOps, outer); - for (auto res : results) { - if (overwritesToMemoryReadBy(&TR, *OrigAA, TLI, *OrigSE, *OrigLI, - *OrigDT, LI, res, outer)) { - EmitWarning("NotPromotable", *LI, - " Could not promote shadow allocation ", *V, - " due to pointer load ", *LI, - " which does not postdominates store ", *res); - shadowpromotable = false; - goto exitL; - } - } - } - // If there is a store not reproduced in the reverse pass (e.g. as part - // of a write in a call), and this store is necessary to a pointer load of - // the shadow, this is not materializable since the load will not return - // the same value. - { - SmallVector nonReproducedStores; - for (auto S : storingOps) - if (!stores.count(S)) { - SmallVector results; - SmallPtrSet shadowPtrLoadSet( - shadowPointerLoads.begin(), shadowPointerLoads.end()); - mayExecuteAfter(results, S, shadowPtrLoadSet, outer); - if (results.size()) { - EmitWarning("NotPromotable", *results[0], - " Could not promote shadow allocation ", *V, - " due to non-reproduced store ", *S, - " which may impact pointer load ", *results[0]); - shadowpromotable = false; - goto exitL; - } - } - } - exitL:; - if (shadowpromotable) { - backwardsOnlyShadows[V] = ShadowRematerializer( - stores, frees, primalInitializationOfShadow, outer); - } - } - - if (!promotable) - return; - - SmallPtrSet rematerializable; - - // We currently require a rematerializable allocation to have - // all of its loads be able to be performed again. Thus if - // there is an overwriting store after a load in context, - // it may no longer be rematerializable. - for (auto LI : loads) { - // Is there a store which could occur after the load. 
- // In other words - SmallVector results; - mayExecuteAfter(results, LI, storingOps, outer); - for (auto res : results) { - if (overwritesToMemoryReadBy(&TR, *OrigAA, TLI, *OrigSE, *OrigLI, *OrigDT, - LI, res, outer)) { - EmitWarning("NotPromotable", *LI, " Could not promote allocation ", *V, - " due to load ", *LI, - " which does not postdominates store ", *res); - return; - } - } - rematerializable.insert(LI); - } - for (auto LI : loadLikeCalls) { - // Is there a store which could occur after the load. - // In other words - SmallVector results; - mayExecuteAfter(results, LI.loadCall, storingOps, outer); - for (auto res : results) { - if (overwritesToMemoryReadBy(&TR, *OrigAA, TLI, *OrigSE, *OrigLI, *OrigDT, - LI.loadCall, res, outer)) { - EmitWarning("NotPromotable", *LI.loadCall, - " Could not promote allocation ", *V, - " due to load-like call ", *LI.loadCall, - " which does not postdominates store ", *res); - return; - } - } - } - rematerializableAllocations[V] = Rematerializer( - loads, loadLikeCalls, stores, frees, outer, nonRepeatableWritingCall); -} - -BasicBlock *GradientUtils::addReverseBlock(BasicBlock *currentBlock, - Twine const &name, bool forkCache, - bool push) { - assert(reverseBlocks.size()); - auto found = reverseBlockToPrimal.find(currentBlock); - assert(found != reverseBlockToPrimal.end()); - - SmallVector &vec = reverseBlocks[found->second]; - assert(vec.size()); - assert(vec.back() == currentBlock); - - BasicBlock *rev = - BasicBlock::Create(currentBlock->getContext(), name, newFunc); - rev->moveAfter(currentBlock); - if (push) - vec.push_back(rev); - reverseBlockToPrimal[rev] = found->second; - if (forkCache) { - for (auto pair : unwrap_cache[currentBlock]) - unwrap_cache[rev].insert(pair); - for (auto pair : lookup_cache[currentBlock]) - lookup_cache[rev].insert(pair); - } - return rev; -} - -void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { - if (A == B) - return; - assert(A->getType() == B->getType()); - - if (auto iA = dyn_cast(A)) { - if (unwrappedLoads.find(iA) != unwrappedLoads.end()) { - auto iB = cast(B); - unwrappedLoads[iB] = unwrappedLoads[iA]; - unwrappedLoads.erase(iA); - } - } - - // Check that the replacement doesn't already exist in the mapping - // thereby resulting in a conflict. 
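// A minimal sketch of the debug-only invariant checked below, assuming a
// plain std::map in place of newToOriginalFn (and ignoring the constant
// exemption): if the value being replaced maps back to an original value,
// its replacement must not already be mapped, otherwise two new values
// would claim the same original.
#include <cassert>
#include <map>

template <typename V>
void checkNoMappingConflict(const std::map<V, V> &newToOriginal, const V &A,
                            const V &B) {
  if (newToOriginal.count(A))
    assert(!newToOriginal.count(B) && "replacement value already mapped");
}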
-#ifndef NDEBUG - if (!isa(B)) { - auto found = newToOriginalFn.find(A); - if (found != newToOriginalFn.end()) { - auto foundB = newToOriginalFn.find(B); - assert(foundB == newToOriginalFn.end()); - } - } -#endif - - CacheUtility::replaceAWithB(A, B, storeInCache); -} - -void GradientUtils::erase(Instruction *I) { - assert(I); - if (I->getParent()->getParent() != newFunc) { - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "paren: " << *I->getParent()->getParent() << "\n"; - llvm::errs() << "I: " << *I << "\n"; - } - assert(I->getParent()->getParent() == newFunc); - - // not original, should not contain - assert(!invertedPointers.count(I)); - // not original, should not contain - assert(!originalToNewFn.count(I)); - - originalToNewFn.erase(I); - { - auto found = newToOriginalFn.find(I); - if (found != newToOriginalFn.end()) { - Value *orig = found->second; - newToOriginalFn.erase(found); - originalToNewFn.erase(orig); - } - } - { - auto found = UnwrappedWarnings.find(I); - if (found != UnwrappedWarnings.end()) { - UnwrappedWarnings.erase(found); - } - } - unwrappedLoads.erase(I); - - for (auto &pair : unwrap_cache) { - if (pair.second.find(I) != pair.second.end()) - pair.second.erase(I); - } - - for (auto &pair : lookup_cache) { - if (pair.second.find(I) != pair.second.end()) - pair.second.erase(I); - } - CacheUtility::erase(I); -} - -void GradientUtils::eraseWithPlaceholder(Instruction *I, Instruction *orig, - const Twine &suffix, bool erase) { - if (!I->getType()->isVoidTy() && !I->getType()->isTokenTy()) { - auto inspos = I->getIterator(); -#if LLVM_VERSION_MAJOR >= 18 - if (I->getParent()->IsNewDbgInfoFormat) { - if (!inspos.getHeadBit()) { - auto srcmarker = I->getParent()->getMarker(inspos); - if (srcmarker && !srcmarker->empty()) { - inspos.setHeadBit(true); - } - } - } -#endif - IRBuilder<> BuilderZ(I->getParent(), inspos); - auto pn = BuilderZ.CreatePHI(I->getType(), 1, I->getName() + suffix); - fictiousPHIs[pn] = orig; - replaceAWithB(I, pn); - } - - if (erase) { - this->erase(I); - } -} - -void GradientUtils::setTape(Value *newtape) { - assert(tape == nullptr); - assert(newtape != nullptr); - assert(tapeidx == 0); - assert(addedTapeVals.size() == 0); - tape = newtape; -} - -void GradientUtils::dumpPointers() { - errs() << "invertedPointers:\n"; - for (auto a : invertedPointers) { - errs() << " invertedPointers[" << *a.first << "] = " << *a.second << "\n"; - } - errs() << "end invertedPointers\n"; -} - -int GradientUtils::getIndex( - std::pair idx, - const std::map, int> &mapping, - IRBuilder<> &B) { - assert(tape); - auto found = mapping.find(idx); - if (found == mapping.end()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << *oldFunc << "\n"; - ss << *newFunc << "\n"; - ss << " \n"; - for (auto &p : mapping) { - ss << " idx: " << *p.first.first << ", " << p.first.second - << " pos=" << p.second << "\n"; - } - ss << " \n"; - ss << "idx: " << *idx.first << ", " << idx.second << "\n"; - ss << " could not find index in mapping\n"; - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(idx.first), - ErrorType::GetIndexError, this, nullptr, wrap(&B)); - } else { - EmitFailure("GetIndexError", idx.first->getDebugLoc(), idx.first, - ss.str()); - } - return IndexMappingError; - } - return found->second; -} - -int GradientUtils::getIndex( - std::pair idx, - std::map, int> &mapping, - IRBuilder<> &B) { - if (tape) { - return getIndex( - idx, - (const std::map, int> &)mapping, B); - } else { - if (mapping.find(idx) != mapping.end()) { - 
return mapping[idx]; - } - mapping[idx] = tapeidx; - ++tapeidx; - return mapping[idx]; - } -} - -void GradientUtils::computeGuaranteedFrees() { - SmallPtrSet allocsToPromote; - for (auto &BB : *oldFunc) { - if (notForAnalysis.count(&BB)) - continue; - for (auto &I : BB) { - if (auto AI = dyn_cast(&I)) - computeForwardingProperties(AI); - - auto CI = dyn_cast(&I); - if (!CI) - continue; - - StringRef funcName = getFuncNameFromCall(CI); - - if (isDeallocationFunction(funcName, TLI)) { - llvm::Value *val = getBaseObject(CI->getArgOperand(0)); - - if (auto dc = dyn_cast(val)) { - StringRef sfuncName = getFuncNameFromCall(dc); - if (isAllocationFunction(sfuncName, TLI)) { - - bool hasPDFree = false; - if (dc->getParent() == CI->getParent() || - OrigPDT->dominates(CI->getParent(), dc->getParent())) { - hasPDFree = true; - } - - if (hasPDFree) { - allocationsWithGuaranteedFree[dc].insert(CI); - } - } - } - } - if (isAllocationFunction(funcName, TLI)) { - allocsToPromote.insert(CI); - if (hasMetadata(CI, "enzyme_fromstack")) { - allocationsWithGuaranteedFree[CI].insert(CI); - } - // TODO: special case object managed by the GC as it is automatically - // freed. - if (EnzymeJuliaAddrLoad && isa(CI->getType()) && - cast(CI->getType())->getAddressSpace() == 10) { - } - } - } - } - for (CallInst *V : allocsToPromote) { - // TODO compute if an only load/store (non capture) - // allocaion by traversing its users. If so, mark - // all of its load/stores, as now the loads can - // potentially be rematerialized without a cache - // of the allocation, but the operands of all stores. - // This info needs to be provided to minCutCache - // the derivative of store needs to redo the store, - // isValueNeededInReverse needs to know to preserve the - // store operands in this case, etc - computeForwardingProperties(V); - } -} - -/// Perform the corresponding deallocation of tofree, given it was allocated by -/// allocationfn -// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp -llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, - llvm::Value *tofree, - llvm::StringRef allocationfn, - const llvm::DebugLoc &debuglocation, - const llvm::TargetLibraryInfo &TLI, - llvm::CallInst *orig, - GradientUtils *gutils) { - assert(isAllocationFunction(allocationfn, TLI)); - - if (allocationfn == "__rust_alloc" || allocationfn == "__rust_alloc_zeroed") { - Type *VoidTy = Type::getVoidTy(tofree->getContext()); - Type *IntPtrTy = orig->getType(); - Type *RustSz = orig->getArgOperand(0)->getType(); - Type *inTys[3] = {IntPtrTy, RustSz, RustSz}; - - auto FT = FunctionType::get(VoidTy, inTys, false); - Value *freevalue = builder.GetInsertBlock() - ->getParent() - ->getParent() - ->getOrInsertFunction("__rust_dealloc", FT) - .getCallee(); - Value *vals[3]; - vals[0] = builder.CreatePointerCast(tofree, IntPtrTy); - // size - vals[1] = gutils->lookupM( - gutils->getNewFromOriginal(orig->getArgOperand(0)), builder); - // alignment - vals[2] = gutils->lookupM( - gutils->getNewFromOriginal(orig->getArgOperand(1)), builder); - CallInst *freecall = cast( - CallInst::Create(FT, freevalue, vals, "", builder.GetInsertBlock())); - freecall->setDebugLoc(debuglocation); - if (isa(tofree) && - cast(tofree)->getAttributes().hasAttribute( - AttributeList::ReturnIndex, Attribute::NonNull)) { - freecall->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); - } - if (Function *F = dyn_cast(freevalue)) - freecall->setCallingConv(F->getCallingConv()); - if (freecall->getParent() == nullptr) - 
builder.Insert(freecall); - return freecall; - } - if (allocationfn == "julia.gc_alloc_obj" || - allocationfn == "jl_gc_alloc_typed" || - allocationfn == "ijl_gc_alloc_typed" || - allocationfn == "jl_alloc_array_1d" || - allocationfn == "ijl_alloc_array_1d" || - allocationfn == "jl_alloc_array_2d" || - allocationfn == "ijl_alloc_array_2d" || - allocationfn == "jl_alloc_array_3d" || - allocationfn == "ijl_alloc_array_3d" || allocationfn == "jl_new_array" || - allocationfn == "ijl_new_array" || - allocationfn == "jl_alloc_genericmemory" || - allocationfn == "ijl_alloc_genericmemory") - return nullptr; - - if (allocationfn == "enzyme_allocator") { - auto inds = getDeallocationIndicesFromCall(orig); - SmallVector vals; - for (auto ind : inds) { - if (ind == -1) - vals.push_back(tofree); - else - vals.push_back(gutils->lookupM( - gutils->getNewFromOriginal(orig->getArgOperand(ind)), builder)); - } - auto tocall = getDeallocatorFnFromCall(orig); - auto freecall = builder.CreateCall(tocall, vals); - freecall->setDebugLoc(debuglocation); - return freecall; - } - - if (allocationfn == "swift_allocObject") { - Type *VoidTy = Type::getVoidTy(tofree->getContext()); - Type *IntPtrTy = getInt8PtrTy(tofree->getContext()); - - auto FT = FunctionType::get(VoidTy, ArrayRef(IntPtrTy), false); - Value *freevalue = builder.GetInsertBlock() - ->getParent() - ->getParent() - ->getOrInsertFunction("swift_release", FT) - .getCallee(); - CallInst *freecall = cast(CallInst::Create( - FT, freevalue, - ArrayRef(builder.CreatePointerCast(tofree, IntPtrTy)), "", - builder.GetInsertBlock())); - freecall->setDebugLoc(debuglocation); - if (isa(tofree) && - cast(tofree)->getAttributes().hasAttribute( - AttributeList::ReturnIndex, Attribute::NonNull)) { - freecall->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); - } - if (Function *F = dyn_cast(freevalue)) - freecall->setCallingConv(F->getCallingConv()); - if (freecall->getParent() == nullptr) - builder.Insert(freecall); - return freecall; - } - - if (shadowErasers.find(allocationfn) != shadowErasers.end()) { - return shadowErasers[allocationfn](builder, tofree); - } - - if (allocationfn == "__size_returning_new_experiment") { - allocationfn = "malloc"; - tofree = builder.CreateExtractValue(tofree, 0); - } - - if (tofree->getType()->isIntegerTy()) - tofree = builder.CreateIntToPtr(tofree, getInt8PtrTy(tofree->getContext())); - - llvm::LibFunc libfunc; - if (allocationfn == "calloc" || allocationfn == "malloc" || - allocationfn == "_mlir_memref_to_llvm_alloc") { - libfunc = LibFunc_malloc; - } else { - bool res = TLI.getLibFunc(allocationfn, libfunc); - (void)res; - assert(res && "ought find known allocation fn"); - } - - llvm::LibFunc freefunc; - - switch (libfunc) { - case LibFunc_malloc: // malloc(unsigned int); - case LibFunc_valloc: // valloc(unsigned int); - freefunc = LibFunc_free; - break; - - case LibFunc_Znwj: // new(unsigned int); - case LibFunc_ZnwjRKSt9nothrow_t: // new(unsigned int, nothrow); - case LibFunc_ZnwjSt11align_val_t: // new(unsigned int, align_val_t) - case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t: // new(unsigned int, - // align_val_t, nothrow) - - case LibFunc_Znwm: // new(unsigned long); - case LibFunc_ZnwmRKSt9nothrow_t: // new(unsigned long, nothrow); - case LibFunc_ZnwmSt11align_val_t: // new(unsigned long, align_val_t) - case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: // new(unsigned long, - // align_val_t, nothrow) - freefunc = LibFunc_ZdlPv; - break; - - case LibFunc_Znaj: // new[](unsigned int); - case 
LibFunc_ZnajRKSt9nothrow_t: // new[](unsigned int, nothrow); - case LibFunc_ZnajSt11align_val_t: // new[](unsigned int, align_val_t) - case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t: // new[](unsigned int, - // align_val_t, nothrow - - case LibFunc_Znam: // new[](unsigned long); - case LibFunc_ZnamRKSt9nothrow_t: // new[](unsigned long, nothrow); - case LibFunc_ZnamSt11align_val_t: // new[](unsigned long, align_val_t) - case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: // new[](unsigned long, - // align_val_t, nothrow) - freefunc = LibFunc_ZdaPv; - break; - - case LibFunc_msvc_new_longlong: // new(unsigned long long); - case LibFunc_msvc_new_longlong_nothrow: // new(unsigned long long, nothrow); - freefunc = LibFunc_msvc_delete_ptr64_longlong; - break; - - case LibFunc_msvc_new_array_longlong: // new[](unsigned long long); - case LibFunc_msvc_new_array_longlong_nothrow: // new[](unsigned long long, - // nothrow); - freefunc = LibFunc_msvc_delete_array_ptr64_longlong; - break; - - case LibFunc_msvc_new_int: // new(unsigned int); - case LibFunc_msvc_new_int_nothrow: // new(unsigned int, nothrow); - case LibFunc_msvc_new_array_int: // new[](unsigned int); - case LibFunc_msvc_new_array_int_nothrow: // new[](unsigned int, nothrow); - llvm_unreachable("msvc deletion not handled"); - - default: - llvm_unreachable("unknown allocation function"); - } - llvm::StringRef freename = TLI.getName(freefunc); - if (freefunc == LibFunc_free) { - freename = "free"; - assert(freename == "free"); - if (freename != "free") - llvm_unreachable("illegal free"); - } - if (allocationfn == "_mlir_memref_to_llvm_alloc") - freename = "_mlir_memref_to_llvm_free"; - - Type *VoidTy = Type::getVoidTy(tofree->getContext()); - Type *IntPtrTy = getInt8PtrTy(tofree->getContext()); - - auto FT = FunctionType::get(VoidTy, {IntPtrTy}, false); - Value *freevalue = builder.GetInsertBlock() - ->getParent() - ->getParent() - ->getOrInsertFunction(freename, FT) - .getCallee(); - CallInst *freecall = cast(CallInst::Create( - FT, freevalue, {builder.CreatePointerCast(tofree, IntPtrTy)}, "", - builder.GetInsertBlock())); - freecall->setDebugLoc(debuglocation); - if (isa(tofree) && - cast(tofree)->getAttributes().hasAttribute( - AttributeList::ReturnIndex, Attribute::NonNull)) { - freecall->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); - } - if (Function *F = dyn_cast(freevalue)) - freecall->setCallingConv(F->getCallingConv()); - if (freecall->getParent() == nullptr) - builder.Insert(freecall); - return freecall; -} - -bool GradientUtils::needsCacheWholeAllocation( - const llvm::Value *origInst) const { - auto found = knownRecomputeHeuristic.find(origInst); - if (found == knownRecomputeHeuristic.end()) - return false; - if (!found->second) - return true; - // User, operand of input, whehter the input is the original allocation - SmallVector, 1> todo; - for (auto &use : origInst->uses()) - todo.push_back(std::make_tuple(cast(use.getUser()), - use.getOperandNo(), true)); - SmallSet, 1> seen; - while (todo.size()) { - auto pair = todo.back(); - auto [cur, idx, orig] = pair; - todo.pop_back(); - if (seen.count(pair)) - continue; - seen.insert(pair); - // Loads are always fine - if (isa(cur) || isNVLoad(cur)) - continue; - - if (auto II = dyn_cast(cur)) - if (II->getIntrinsicID() == Intrinsic::masked_load) - continue; - - bool returnedSameValue = false; - - if (auto CI = dyn_cast(cur)) { -#if LLVM_VERSION_MAJOR >= 14 - if (idx < CI->arg_size()) -#else - if (idx < CI->getNumArgOperands()) -#endif - { - - // Calling a non-empty 
function with a julia base object, this is fine. - // as GC will deal with any issues with. - if (auto PT = dyn_cast(CI->getArgOperand(idx)->getType())) - if (PT->getAddressSpace() == 10) - if (EnzymeJuliaAddrLoad) - if (auto F = getFunctionFromCall(CI)) - if (!F->empty()) - continue; - - if (isNoCapture(CI, idx)) - continue; - - if (auto F = CI->getCalledFunction()) - if (F->getCallingConv() == CI->getCallingConv() && !F->empty()) { - bool onlyReturnUses = true; - bool hasReturnUse = true; - - if (CI->getFunctionType() != F->getFunctionType() || - idx >= F->getFunctionType()->getNumParams()) { - onlyReturnUses = false; - } else { - for (auto u : F->getArg(idx)->users()) { - if (isa(u)) { - hasReturnUse = true; - continue; - } - onlyReturnUses = false; - continue; - } - } - // The arg itself has no use in the function - if (onlyReturnUses && !hasReturnUse) - continue; - - // If this is the original allocation, we return it guaranteed, and - // cache the return, that's still fine - if (onlyReturnUses && orig) { - found = knownRecomputeHeuristic.find(cur); - if (found == knownRecomputeHeuristic.end()) - continue; - - if (!found->second) - continue; - returnedSameValue = true; - } - } - } - } - - found = knownRecomputeHeuristic.find(cur); - if (found == knownRecomputeHeuristic.end()) - continue; - - // If caching a julia base object, this is fine as - // GC will deal with any issues with. - if (auto PT = dyn_cast(cur->getType())) - if (PT->getAddressSpace() == 10) - if (EnzymeJuliaAddrLoad) - continue; - - // If caching this user, it cannot be a gep/cast of original - if (!found->second) { - llvm::errs() << " mod: " << *oldFunc->getParent() << "\n"; - llvm::errs() << " oldFunc: " << *oldFunc << "\n"; - for (auto &pair : knownRecomputeHeuristic) - llvm::errs() << " krc[" << *pair.first << "] = " << pair.second << "\n"; - llvm::errs() << " cur: " << *cur << "\n"; - llvm::errs() << " origInst: " << *origInst << "\n"; - assert(false && "caching potentially capturing/offset of allocation"); - } else { - // if not caching this user, it is legal to recompute, consider its users - for (auto &use : cur->uses()) { - todo.push_back(std::make_tuple(cast(use.getUser()), - use.getOperandNo(), - returnedSameValue && orig)); - } - } - } - return false; -} - -void GradientUtils::replaceAndRemoveUnwrapCacheFor(llvm::Value *A, - llvm::Value *B) { - SmallVector toErase; - for (auto &pair : unwrap_cache) { - auto found = pair.second.find(A); - if (found != pair.second.end()) { - for (auto &p : found->second) { - Value *pre = p.second; - replaceAWithB(pre, B); - if (auto I = dyn_cast(pre)) { - toErase.push_back(I); - } - } - pair.second.erase(A); - } - } - for (auto I : toErase) { - erase(I); - } -} diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h deleted file mode 100644 index 7475ba7ff186..000000000000 --- a/enzyme/Enzyme/GradientUtils.h +++ /dev/null @@ -1,661 +0,0 @@ -//===- GradientUtils.h - Helper class and utilities for AD ---------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. 
and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares two helper classes GradientUtils and subclass -// DiffeGradientUtils. These classes contain utilities for managing the cache, -// recomputing statements, and in the case of DiffeGradientUtils, managing -// adjoint values and shadow pointers. -// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_GUTILS_H_ -#define ENZYME_GUTILS_H_ - -#include -#include -#include - -#include - -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/Analysis/ValueTracking.h" - -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" - -#include "ActivityAnalysis.h" -#include "CacheUtility.h" -#include "EnzymeLogic.h" -#include "Utils.h" - -#include "llvm-c/Core.h" - -class GradientUtils; -extern llvm::StringMap &, llvm::CallInst *, llvm::ArrayRef, - GradientUtils *)>> - shadowHandlers; - -class DiffeGradientUtils; -extern llvm::StringMap &, llvm::CallInst *, GradientUtils &, - llvm::Value *&, llvm::Value *&, llvm::Value *&)>, - std::function &, llvm::CallInst *, - DiffeGradientUtils &, llvm::Value *)>>> - customCallHandlers; - -extern llvm::StringMap< - std::function &, llvm::CallInst *, GradientUtils &, - llvm::Value *&, llvm::Value *&)>> - customFwdCallHandlers; - -constexpr int IndexMappingError = 0x0000fffd; - -extern "C" { -extern llvm::cl::opt EnzymeInactiveDynamic; -extern llvm::cl::opt EnzymeFreeInternalAllocations; -extern llvm::cl::opt EnzymeRematerialize; -} -extern llvm::SmallVector MD_ToCopy; - -struct InvertedPointerConfig : llvm::ValueMapConfig { - typedef GradientUtils *ExtraData; - static void onDelete(ExtraData gutils, const llvm::Value *old); -}; - -class InvertedPointerVH final : public llvm::CallbackVH { -public: - GradientUtils *gutils; - InvertedPointerVH(GradientUtils *gutils) : gutils(gutils) {} - InvertedPointerVH(GradientUtils *gutils, llvm::Value *V) - : InvertedPointerVH(gutils) { - setValPtr(V); - } - void deleted() override final; - - void allUsesReplacedWith(llvm::Value *new_value) override final { - setValPtr(new_value); - } - virtual ~InvertedPointerVH() {} -}; - -enum class AugmentedStruct; -class GradientUtils : public CacheUtility { -public: - EnzymeLogic &Logic; - bool AtomicAdd; - DerivativeMode mode; - llvm::Function *oldFunc; - llvm::ValueMap invertedPointers; - llvm::DominatorTree *OrigDT; - llvm::PostDominatorTree *OrigPDT; - llvm::LoopInfo *OrigLI; - llvm::ScalarEvolution *OrigSE; - - /// (Original) Blocks which dominate all returns - llvm::SmallPtrSet BlocksDominatingAllReturns; - - llvm::SmallPtrSet notForAnalysis; - std::shared_ptr 
ATA; - llvm::SmallVector originalBlocks; - - /// Allocations which are known to always be freed before the - /// reverse, to the list of frees that must apply to this allocation. - llvm::ValueMap> - allocationsWithGuaranteedFree; - - /// Frees which can always be eliminated as the post dominate - /// an allocation (which will itself be freed). - llvm::SmallPtrSet postDominatingFrees; - - /// Deallocations that should be kept in the forward pass because - /// they deallocation memory which isn't necessary for the reverse - /// pass - llvm::SmallPtrSet forwardDeallocations; - - /// Map of primal block to corresponding block(s) in reverse - std::map> - reverseBlocks; - /// Map of block in reverse to corresponding primal block - std::map reverseBlockToPrimal; - - /// A set of tape extractions to enforce a cache of - /// rather than attempting to recompute. - llvm::SmallPtrSet TapesToPreventRecomputation; - - llvm::ValueMap fictiousPHIs; - llvm::ValueMap originalToNewFn; - llvm::ValueMap newToOriginalFn; - llvm::SmallVector originalCalls; - - llvm::SmallPtrSet unnecessaryIntermediates; - - const std::map *can_modref_map; - const std::map>> - *overwritten_args_map_ptr; - const llvm::SmallPtrSetImpl *unnecessaryValuesP; - - llvm::SmallVector getInvertedBundles( - llvm::CallInst *orig, llvm::ArrayRef types, - llvm::IRBuilder<> &Builder2, bool lookup, - const llvm::ValueToValueMapTy &available = llvm::ValueToValueMapTy()); - - bool usedInRooting(const llvm::CallBase *orig, - llvm::ArrayRef types, const llvm::Value *val, - bool shadow) const; - - llvm::Value *getNewIfOriginal(llvm::Value *originst) const; - - llvm::Value *tid; - llvm::Value *ompThreadId(); - - llvm::Value *numThreads; - llvm::Value *ompNumThreads(); - - llvm::Value *getOrInsertTotalMultiplicativeProduct(llvm::Value *val, - LoopContext &lc); - - llvm::Value *getOrInsertConditionalIndex(llvm::Value *val, LoopContext &lc, - bool pickTrue); - - bool assumeDynamicLoopOfSizeOne(llvm::Loop *L) const override; - - llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const; - - llvm::Value *getNewFromOriginal(const llvm::Value *originst) const; - - llvm::Instruction *getNewFromOriginal(const llvm::Instruction *newinst) const; - - llvm::BasicBlock *getNewFromOriginal(const llvm::BasicBlock *newinst) const; - - llvm::Value *hasUninverted(const llvm::Value *inverted) const; - - llvm::BasicBlock *getOriginalFromNew(const llvm::BasicBlock *newinst) const; - - llvm::Value *isOriginal(const llvm::Value *newinst) const; - - llvm::Instruction *isOriginal(const llvm::Instruction *newinst) const; - - llvm::BasicBlock *isOriginal(const llvm::BasicBlock *newinst) const; - - struct LoadLikeCall { - llvm::CallInst *loadCall; - llvm::Value *operand; - LoadLikeCall() = default; - LoadLikeCall(llvm::CallInst *a, llvm::Value *b) : loadCall(a), operand(b) {} - }; - - struct Rematerializer { - // Loads which may need to be rematerialized. - llvm::SmallVector loads; - - // Loads-like calls which need the memory initialized for the reverse. - llvm::SmallVector loadLikeCalls; - - // Operations which must be rerun to rematerialize - // the value. - llvm::SmallPtrSet stores; - - // Operations which deallocate the value. - llvm::SmallPtrSet frees; - - // Loop scope (null if not loop scoped). - llvm::Loop *LI; - - // If non-null, a call which writes to the value which cannot be reproduced - // in the reverse pass. If any values of this allocation are needed in the - // reverse pass and this is non-null, this allocation cannot be - // rematerialized. 
- llvm::CallInst *nonRepeatableWritingCall; - - Rematerializer() : loads(), stores(), frees(), LI(nullptr) {} - Rematerializer(llvm::ArrayRef loads, - llvm::ArrayRef loadLikeCalls, - const llvm::SmallPtrSetImpl &stores, - const llvm::SmallPtrSetImpl &frees, - llvm::Loop *LI, llvm::CallInst *nonRepeatableWritingCall) - : loads(loads.begin(), loads.end()), - loadLikeCalls(loadLikeCalls.begin(), loadLikeCalls.end()), - stores(stores.begin(), stores.end()), - frees(frees.begin(), frees.end()), LI(LI), - nonRepeatableWritingCall(nonRepeatableWritingCall) {} - }; - - struct ShadowRematerializer { - /// Operations which must be rerun to rematerialize - /// the original value. - llvm::SmallPtrSet stores; - - /// Operations which deallocate the value. - llvm::SmallPtrSet frees; - - /// Whether the shadow must be initialized in the primal. - bool primalInitialize; - - /// Loop scope (null if not loop scoped). - llvm::Loop *LI; - - ShadowRematerializer() - : stores(), frees(), primalInitialize(), LI(nullptr) {} - ShadowRematerializer( - const llvm::SmallPtrSetImpl &stores, - const llvm::SmallPtrSetImpl &frees, - bool primalInitialize, llvm::Loop *LI) - : stores(stores.begin(), stores.end()), - frees(frees.begin(), frees.end()), primalInitialize(primalInitialize), - LI(LI) {} - }; - - llvm::ValueMap rematerializableAllocations; - - /// Only loaded from and stored to (not captured), mapped to the stores (and - /// memset). Boolean denotes whether the primal initializes the shadow as well - /// (for use) as a structure which carries data. - llvm::ValueMap backwardsOnlyShadows; - - void computeForwardingProperties(llvm::Instruction *V); - void computeGuaranteedFrees(); - -private: - llvm::SmallVector addedTapeVals; - unsigned tapeidx; - llvm::Value *tape; - - std::map>> - unwrap_cache; - std::map> - lookup_cache; - -public: - void replaceAndRemoveUnwrapCacheFor(llvm::Value *A, llvm::Value *B); - - llvm::BasicBlock *addReverseBlock(llvm::BasicBlock *currentBlock, - llvm::Twine const &name, - bool forkCache = true, bool push = true); - - bool legalRecompute(const llvm::Value *val, - const llvm::ValueToValueMapTy &available, - llvm::IRBuilder<> *BuilderM, bool reverse = false, - bool legalRecomputeCache = true) const; - - std::map knownRecomputeHeuristic; - bool shouldRecompute(const llvm::Value *val, - const llvm::ValueToValueMapTy &available, - llvm::IRBuilder<> *BuilderM); - - llvm::ValueMap - unwrappedLoads; - - void replaceAWithB(llvm::Value *A, llvm::Value *B, - bool storeInCache = false) override; - - void erase(llvm::Instruction *I) override; - - void eraseWithPlaceholder(llvm::Instruction *I, llvm::Instruction *orig, - const llvm::Twine &suffix = "_replacementA", - bool erase = true); - - // TODO consider invariant group and/or valueInvariant group - - void setTape(llvm::Value *newtape); - - void dumpPointers(); - - int getIndex( - std::pair idx, - const std::map, int> &mapping, - llvm::IRBuilder<> &); - - int getIndex( - std::pair idx, - std::map, int> &mapping, - llvm::IRBuilder<> &); - - llvm::Value *cacheForReverse(llvm::IRBuilder<> &BuilderQ, llvm::Value *malloc, - int idx, bool replace = true); - - llvm::ArrayRef getTapeValues() const { - return addedTapeVals; - } - -public: - llvm::AAResults *OrigAA; - TypeAnalysis &TA; - TypeResults TR; - bool omp; - bool runtimeActivity; - -private: - unsigned width; - -public: - unsigned getWidth() { return width; } - - bool shadowReturnUsed; - - llvm::ArrayRef ArgDiffeTypes; - -public: - GradientUtils(EnzymeLogic &Logic, llvm::Function *newFunc_, - 
llvm::Function *oldFunc_, llvm::TargetLibraryInfo &TLI_, - TypeAnalysis &TA_, TypeResults TR_, - llvm::ValueToValueMapTy &invertedPointers_, - const llvm::SmallPtrSetImpl &constantvalues_, - const llvm::SmallPtrSetImpl &activevals_, - DIFFE_TYPE ReturnActivity, bool shadowReturnUsed, - llvm::ArrayRef ArgDiffeTypes_, - llvm::ValueMap - &originalToNewFn_, - DerivativeMode mode, bool runtimeActivity, unsigned width, - bool omp); - -public: - DIFFE_TYPE getDiffeType(llvm::Value *v, bool foreignFunction) const; - - DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP, - bool *shadowReturnUsedP, - DerivativeMode cmode) const; - - DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP, - bool *shadowReturnUsedP) const; - - static GradientUtils * - CreateFromClone(EnzymeLogic &Logic, bool runtimeActivity, unsigned width, - llvm::Function *todiff, llvm::TargetLibraryInfo &TLI, - TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, - llvm::ArrayRef constant_args, bool returnUsed, - bool shadowReturnUsed, - std::map &returnMapping, bool omp); - - llvm::ValueMap - differentialAliasScopeDomains; - llvm::ValueMap> - differentialAliasScope; - llvm::MDNode *getDerivativeAliasScope(const llvm::Value *origptr, - ssize_t newptr); - - void setPtrDiffe(llvm::Instruction *orig, llvm::Value *ptr, - llvm::Value *newval, llvm::IRBuilder<> &BuilderM, - llvm::MaybeAlign align, unsigned start, unsigned size, - bool isVolatile, llvm::AtomicOrdering ordering, - llvm::SyncScope::ID syncScope, llvm::Value *mask, - llvm::ArrayRef noAlias, - llvm::ArrayRef scopes); - -private: - llvm::BasicBlock *originalForReverseBlock(llvm::BasicBlock &BB2) const; - -public: - //! This cache stores blocks we may insert as part of getReverseOrLatchMerge - //! to handle inverse iv iteration - // As we don't want to create redundant blocks, we use this convenient cache - std::map, - llvm::BasicBlock *> - newBlocksForLoop_cache; - - //! This cache stores a rematerialized forward pass in the loop - //! specified. The key is the loop header. - std::map - rematerializedLoops_cache; - llvm::BasicBlock *getReverseOrLatchMerge(llvm::BasicBlock *BB, - llvm::BasicBlock *branchingBlock); - -private: - //! Given a loop `lc`, create the rematerialization blocks for the reverse - //! pass, if required, caching if already created. This function will return - //! the new block for the rematerialized loop entry to branch to, if created. - //! Otherwise it will return nullptr. - llvm::BasicBlock *prepRematerializedLoopEntry(LoopContext &lc); - -public: - void forceContexts(); - - void computeMinCache(); - - bool isOriginalBlock(const llvm::BasicBlock &BB) const; - - llvm::SmallVector - rematerializedPrimalOrShadowAllocations; - - void eraseFictiousPHIs(); - - void forceActiveDetection(); - - bool isConstantValue(llvm::Value *val) const; - - bool isConstantInstruction(const llvm::Instruction *inst) const; - - bool getContext(llvm::BasicBlock *BB, LoopContext &lc); - - void forceAugmentedReturns(); - -private: - // For a given value, a list of basic blocks where an unwrap to has already - // produced a warning. 
- std::map> UnwrappedWarnings; - -public: - /// if full unwrap, don't just unwrap this instruction, but also its operands, - /// etc - llvm::Value *unwrapM(llvm::Value *const val, llvm::IRBuilder<> &BuilderM, - const llvm::ValueToValueMapTy &available, - UnwrapMode unwrapMode, llvm::BasicBlock *scope = nullptr, - bool permitCache = true) override final; - - void ensureLookupCached(llvm::Instruction *inst, bool shouldFree = true, - llvm::BasicBlock *scope = nullptr, - llvm::MDNode *TBAA = nullptr); - - std::map> - lcssaFixes; - std::map lcssaPHIToOrig; - llvm::Value *fixLCSSA(llvm::Instruction *inst, llvm::BasicBlock *forwardBlock, - bool legalInBlock = false); - - llvm::Value *lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, - const llvm::ValueToValueMapTy &incoming_availalble = - llvm::ValueToValueMapTy(), - bool tryLegalRecomputeCheck = true, - llvm::BasicBlock *scope = nullptr) override; - - llvm::Value *invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, - bool nullShadow = false); - - static llvm::Constant *GetOrCreateShadowConstant( - RequestContext context, EnzymeLogic &Logic, llvm::TargetLibraryInfo &TLI, - TypeAnalysis &TA, llvm::Constant *F, DerivativeMode mode, - bool runtimeActivity, unsigned width, bool AtomicAdd); - - static llvm::Constant *GetOrCreateShadowFunction( - RequestContext context, EnzymeLogic &Logic, llvm::TargetLibraryInfo &TLI, - TypeAnalysis &TA, llvm::Function *F, DerivativeMode mode, - bool runtimeActivity, unsigned width, bool AtomicAdd); - - void branchToCorrespondingTarget( - llvm::BasicBlock *ctx, llvm::IRBuilder<> &BuilderM, - const std::map>> - &targetToPreds, - const std::map *replacePHIs = - nullptr); - - void getReverseBuilder(llvm::IRBuilder<> &Builder2, bool original = true); - - void getForwardBuilder(llvm::IRBuilder<> &Builder2); - - static llvm::Type *getShadowType(llvm::Type *ty, unsigned width); - - llvm::Type *getShadowType(llvm::Type *ty); - - //! Helper routine to extract a nested element from a struct/array. This is - // a one dimensional special case of the multi-dim extractMeta below. - static llvm::Value *extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, - unsigned off, const llvm::Twine &name = ""); - - //! Helper routine to extract a nested element from a struct/array. Unlike the - // LLVM instruction, this will attempt to re-use the inserted value, if it - // exists, rather than always creating a new instruction. If fallback is - // true (the default), it will create an instruction if it fails to find an - // appropriate existing value, otherwise it returns nullptr. - static llvm::Value *extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, - llvm::ArrayRef off, - const llvm::Twine &name = "", - bool fallback = true); - - //! Helper routine to get the type of an extraction - static llvm::Type *extractMeta(llvm::Type *T, llvm::ArrayRef off); - - static llvm::Value *recursiveFAdd(llvm::IRBuilder<> &B, llvm::Value *lhs, - llvm::Value *rhs, - llvm::ArrayRef lhs_off = {}, - llvm::ArrayRef rhs_off = {}, - llvm::Value *prev = nullptr, - bool vectorLayer = false); - - /// Unwraps a vector derivative from its internal representation and applies a - /// function f to each element. Return values of f are collected and wrapped. - template - llvm::Value *applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder, - Func rule, Args... 
args) { - if (width > 1) { - const int size = sizeof...(args); - llvm::Value *vals[size] = {args...}; - - for (size_t i = 0; i < size; ++i) - if (vals[i]) - assert(llvm::cast(vals[i]->getType()) - ->getNumElements() == width); - - llvm::Type *wrappedType = diffType->isVoidTy() - ? nullptr - : llvm::ArrayType::get(diffType, width); - llvm::Value *res = - diffType->isVoidTy() ? nullptr : llvm::UndefValue::get(wrappedType); - for (unsigned int i = 0; i < getWidth(); ++i) { - auto tup = std::tuple{ - (args ? extractMeta(Builder, args, i) : nullptr)...}; - auto diff = std::apply(rule, std::move(tup)); - if (!diffType->isVoidTy()) - res = Builder.CreateInsertValue(res, diff, {i}); - } - return res; - } else { - return rule(args...); - } - } - - /// Unwraps a vector derivative from its internal representation and applies a - /// function f to each element. Return values of f are collected and wrapped. - template - void applyChainRule(llvm::IRBuilder<> &Builder, Func rule, Args... args) { - if (width > 1) { - const int size = sizeof...(args); - llvm::Value *vals[size] = {args...}; - - for (size_t i = 0; i < size; ++i) - if (vals[i]) - assert(llvm::cast(vals[i]->getType()) - ->getNumElements() == width); - - for (unsigned int i = 0; i < getWidth(); ++i) { - auto tup = std::tuple{ - (args ? extractMeta(Builder, args, i) : nullptr)...}; - std::apply(rule, std::move(tup)); - } - } else { - rule(args...); - } - } - - /// Unwraps an collection of constant vector derivatives from their internal - /// representations and applies a function f to each element. - template - llvm::Value *applyChainRule(llvm::Type *diffType, - llvm::ArrayRef diffs, - llvm::IRBuilder<> &Builder, Func rule) { - if (width > 1) { -#ifndef NDEBUG - for (auto diff : diffs) { - assert(diff); - assert(llvm::cast(diff->getType())->getNumElements() == - width); - } -#endif - llvm::Type *wrappedType = llvm::ArrayType::get(diffType, width); - llvm::Value *res = llvm::UndefValue::get(wrappedType); - for (unsigned int i = 0; i < getWidth(); ++i) { - llvm::SmallVector extracted_diffs; - for (auto diff : diffs) { - extracted_diffs.push_back( - llvm::cast(extractMeta(Builder, diff, i))); - } - auto diff = rule(extracted_diffs); - res = Builder.CreateInsertValue(res, diff, {i}); - } - return res; - } else { - return rule(diffs); - } - } - - bool needsCacheWholeAllocation(const llvm::Value *V) const; -}; - -void SubTransferHelper(GradientUtils *gutils, DerivativeMode Mode, - llvm::Type *secretty, llvm::Intrinsic::ID intrinsic, - unsigned dstalign, unsigned srcalign, unsigned offset, - bool dstConstant, llvm::Value *shadow_dst, - bool srcConstant, llvm::Value *shadow_src, - llvm::Value *length, llvm::Value *isVolatile, - llvm::CallInst *MTI, bool allowForward = true, - bool shadowsLookedUp = false, - bool backwardsShadow = false); -#endif diff --git a/enzyme/Enzyme/InstructionBatcher.cpp b/enzyme/Enzyme/InstructionBatcher.cpp deleted file mode 100644 index 36972f5e79dd..000000000000 --- a/enzyme/Enzyme/InstructionBatcher.cpp +++ /dev/null @@ -1,283 +0,0 @@ -//===- InstructionBatcher.cpp -//--------------------------------------------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. 
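// A minimal standalone sketch (plain C++, no LLVM types) of the lane-wise
// application that GradientUtils::applyChainRule above performs in vector
// mode: each derivative is an aggregate of `width` lanes, every vector
// operand is unwrapped one lane at a time, the scalar rule runs once per
// lane, and the per-lane results are re-wrapped. The names Width, Vec and
// applyLaneWise below are illustrative only, not Enzyme API.
#include <array>
#include <cstdio>

constexpr unsigned Width = 3;
template <typename T> using Vec = std::array<T, Width>;

// Apply a scalar rule lane-wise and collect the results, mirroring the
// extractvalue/insertvalue loop over [Width x T] aggregates above.
template <typename Rule, typename... Args>
Vec<double> applyLaneWise(Rule rule, const Vec<Args> &...args) {
  Vec<double> out{};
  for (unsigned i = 0; i < Width; ++i)
    out[i] = rule(args[i]...); // one scalar application per lane
  return out;
}

int main() {
  Vec<double> dx = {1.0, 0.0, 0.0}; // three independent tangents of x
  Vec<double> dy = {0.0, 1.0, 0.0}; // three independent tangents of y
  // Scalar rule for d(x*y) = y*dx + x*dy at fixed primal values x = 2, y = 5.
  auto rule = [](double dxi, double dyi) { return 5.0 * dxi + 2.0 * dyi; };
  Vec<double> dz = applyLaneWise(rule, dx, dy);
  for (double v : dz)
    std::printf("%g\n", v); // prints 5, 2, 0
}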
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains an instruction visitor InstructionBatcher that generates -// the batches all LLVM instructions. -// -//===----------------------------------------------------------------------===// - -#include "InstructionBatcher.h" - -#include "llvm/IR/InstVisitor.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" - -#include "llvm/Support/Casting.h" - -#include "llvm/IR/Constants.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" - -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/ValueMapper.h" - -#include "DiffeGradientUtils.h" -#include "GradientUtils.h" - -using namespace llvm; - -InstructionBatcher::InstructionBatcher( - Function *oldFunc, Function *newFunc, unsigned width, - ValueMap> &vectorizedValues, - ValueToValueMapTy &originalToNewFn, SmallPtrSetImpl &toVectorize, - EnzymeLogic &Logic) - : hasError(false), vectorizedValues(vectorizedValues), - originalToNewFn(originalToNewFn), toVectorize(toVectorize), width(width), - Logic(Logic) {} - -Value *InstructionBatcher::getNewOperand(unsigned int i, llvm::Value *op) { - if (auto meta = dyn_cast(op)) { - auto md = meta->getMetadata(); - if (auto val = dyn_cast(md)) - return MetadataAsValue::get( - op->getContext(), - ValueAsMetadata::get(getNewOperand(i, val->getValue()))); - } - - if (isa(op)) { - return op; - } else if (isa(op)) { - return op; - } else if (isa(op)) { - llvm::errs() << "unimplelemented GlobalValue!\n"; - llvm_unreachable("unimplelemented GlobalValue!"); - // TODO: !!! - } else if (toVectorize.count(op) != 0) { - auto found = vectorizedValues.find(op); - assert(found != vectorizedValues.end()); - return found->second[i]; - } else { - auto found = originalToNewFn.find(op); - assert(found != originalToNewFn.end()); - return found->second; - } -} - -void InstructionBatcher::visitInstruction(llvm::Instruction &inst) { - auto found = vectorizedValues.find(&inst); - assert(found != vectorizedValues.end()); - auto placeholders = found->second; - Instruction *placeholder = cast(placeholders[0]); - - for (unsigned i = 1; i < width; ++i) { - ValueToValueMapTy vmap; - Instruction *new_inst = placeholder->clone(); - vmap[placeholder] = new_inst; - - for (unsigned j = 0; j < inst.getNumOperands(); ++j) { - Value *op = inst.getOperand(j); - - // Don't allow writing vectors to global memory, loading and splatting a - // global is fine though. 
- if (isa(op) && !isa(op) && - inst.mayWriteToMemory() && toVectorize.count(op) != 0) { - // TODO: handle buffer access - hasError = true; - EmitFailure("GlobalValueCannotBeVectorized", inst.getDebugLoc(), &inst, - "global variables have to be scalar values", inst); - return; - } - - if (auto meta = dyn_cast(op)) - if (!isa(meta->getMetadata())) - continue; - - Value *new_op = getNewOperand(i, op); - vmap[placeholder->getOperand(j)] = new_op; - } - - if (placeholders.size() == width) { - // Instructions which return a value - Instruction *placeholder = cast(placeholders[i]); - assert(!placeholder->getType()->isVoidTy()); - - ReplaceInstWithInst(placeholder, new_inst); - vectorizedValues[&inst][i] = new_inst; - } else if (placeholders.size() == 1) { - // Instructions which don't return a value - assert(placeholder->getType()->isVoidTy()); - - Instruction *insertionPoint = - placeholder->getNextNode() ? placeholder->getNextNode() : placeholder; - IRBuilder<> Builder2(insertionPoint); - Builder2.SetCurrentDebugLocation(DebugLoc()); - Builder2.Insert(new_inst); - vectorizedValues[&inst].push_back(new_inst); - } else { - llvm_unreachable("Unexpected number of values in mapping"); - } - - RemapInstruction(new_inst, vmap, RF_NoModuleLevelChanges); - - if (!inst.getType()->isVoidTy() && inst.hasName()) - new_inst->setName(inst.getName() + Twine(i)); - } -} - -void InstructionBatcher::visitPHINode(PHINode &phi) { - PHINode *placeholder = cast(vectorizedValues[&phi][0]); - - for (unsigned i = 1; i < width; ++i) { - ValueToValueMapTy vmap; - Instruction *new_phi = placeholder->clone(); - vmap[placeholder] = new_phi; - - for (unsigned j = 0; j < phi.getNumIncomingValues(); ++j) { - Value *orig_block = phi.getIncomingBlock(j); - BasicBlock *new_block = cast(originalToNewFn[orig_block]); - Value *orig_val = phi.getIncomingValue(j); - Value *new_val = getNewOperand(i, orig_val); - - vmap[placeholder->getIncomingValue(j)] = new_val; - vmap[new_block] = new_block; - } - - RemapInstruction(new_phi, vmap, RF_NoModuleLevelChanges); - Instruction *placeholder = cast(vectorizedValues[&phi][i]); - ReplaceInstWithInst(placeholder, new_phi); - new_phi->setName(phi.getName()); - vectorizedValues[&phi][i] = new_phi; - } -} - -void InstructionBatcher::visitSwitchInst(llvm::SwitchInst &inst) { - // TODO: runtime check - hasError = true; - EmitFailure("SwitchConditionCannotBeVectorized", inst.getDebugLoc(), &inst, - "switch conditions have to be scalar values", inst); - return; -} - -void InstructionBatcher::visitBranchInst(llvm::BranchInst &branch) { - // TODO: runtime check - hasError = true; - EmitFailure("BranchConditionCannotBeVectorized", branch.getDebugLoc(), - &branch, "branch conditions have to be scalar values", branch); - return; -} - -void InstructionBatcher::visitReturnInst(llvm::ReturnInst &ret) { - auto found = originalToNewFn.find(ret.getParent()); - assert(found != originalToNewFn.end()); - BasicBlock *nBB = dyn_cast(&*found->second); - IRBuilder<> Builder2 = IRBuilder<>(nBB); - Builder2.SetCurrentDebugLocation(DebugLoc()); - ReturnInst *placeholder = cast(nBB->getTerminator()); - SmallVector rets; - - for (unsigned j = 0; j < ret.getNumOperands(); ++j) { - Value *op = ret.getOperand(j); - for (unsigned i = 0; i < width; ++i) { - Value *new_op = getNewOperand(i, op); - rets.push_back(new_op); - } - } - - if (ret.getNumOperands() != 0) { - auto ret = Builder2.CreateAggregateRet(rets.data(), width); - ret->setDebugLoc(placeholder->getDebugLoc()); - placeholder->eraseFromParent(); - } -} - -void 
InstructionBatcher::visitCallInst(llvm::CallInst &call) { - auto found = vectorizedValues.find(&call); - assert(found != vectorizedValues.end()); - auto placeholders = found->second; - Instruction *placeholder = cast(placeholders[0]); - IRBuilder<> Builder2(placeholder); - Builder2.SetCurrentDebugLocation(DebugLoc()); - auto orig_func = getFunctionFromCall(&call); - - bool isDefined = !orig_func->isDeclaration(); - - if (!isDefined) - return visitInstruction(call); - - SmallVector args; - SmallVector arg_types; -#if LLVM_VERSION_MAJOR >= 14 - for (unsigned j = 0; j < call.arg_size(); ++j) { -#else - for (unsigned j = 0; j < call.getNumArgOperands(); ++j) { -#endif - Value *op = call.getArgOperand(j); - - if (toVectorize.count(op) != 0) { - Type *aggTy = GradientUtils::getShadowType(op->getType(), width); - Value *agg = UndefValue::get(aggTy); - for (unsigned i = 0; i < width; i++) { - auto found = vectorizedValues.find(op); - assert(found != vectorizedValues.end()); - Value *new_op = found->second[i]; - Builder2.CreateInsertValue(agg, new_op, {i}); - } - args.push_back(agg); - arg_types.push_back(BATCH_TYPE::VECTOR); - } else if (isa(op)) { - args.push_back(op); - arg_types.push_back(BATCH_TYPE::SCALAR); - } else { - auto found = originalToNewFn.find(op); - assert(found != originalToNewFn.end()); - Value *arg = found->second; - args.push_back(arg); - arg_types.push_back(BATCH_TYPE::SCALAR); - } - } - - BATCH_TYPE ret_type = orig_func->getReturnType()->isVoidTy() - ? BATCH_TYPE::SCALAR - : BATCH_TYPE::VECTOR; - - Function *new_func = Logic.CreateBatch(RequestContext(&call, &Builder2), - orig_func, width, arg_types, ret_type); - CallInst *new_call = Builder2.CreateCall(new_func->getFunctionType(), - new_func, args, call.getName()); - - new_call->setDebugLoc(placeholder->getDebugLoc()); - - if (!call.getType()->isVoidTy()) { - for (unsigned i = 0; i < width; ++i) { - Instruction *placeholder = dyn_cast(placeholders[i]); - ExtractValueInst *ret = ExtractValueInst::Create( - new_call, {i}, - "unwrap" + (call.hasName() ? "." + call.getName() + Twine(i) : "")); - ReplaceInstWithInst(placeholder, ret); - vectorizedValues[&call][i] = ret; - } - } else { - placeholder->replaceAllUsesWith(new_call); - placeholder->eraseFromParent(); - } -} diff --git a/enzyme/Enzyme/InstructionBatcher.h b/enzyme/Enzyme/InstructionBatcher.h deleted file mode 100644 index da0677c7c840..000000000000 --- a/enzyme/Enzyme/InstructionBatcher.h +++ /dev/null @@ -1,86 +0,0 @@ -//===- InstructionBatcher.h -//--------------------------------------------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains an instruction visitor InstructionBatcher that generates -// the batches all LLVM instructions. 
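// A minimal standalone sketch of the shape of code that width-N batching
// (visitCallInst/visitReturnInst above) produces, written in plain C++ for
// illustration; the pass itself rewrites LLVM IR. Operands marked for
// vectorization become aggregates of Width lanes, scalar operands are shared
// across all lanes, and a batched call returns one result per lane. The names
// Width, scale_add and scale_add_batched are illustrative only.
#include <array>
#include <cstdio>

constexpr unsigned Width = 2;

// Original scalar computation.
static double scale_add(double x, double c) { return c * x + 1.0; }

// Batched analogue: the vectorized argument carries Width lanes, the scalar
// argument is broadcast, and the result is an aggregate of Width values.
static std::array<double, Width>
scale_add_batched(const std::array<double, Width> &x, double c) {
  std::array<double, Width> out{};
  for (unsigned i = 0; i < Width; ++i)
    out[i] = scale_add(x[i], c); // each lane re-runs the scalar body
  return out;
}

int main() {
  std::array<double, Width> xs = {3.0, 4.0};
  auto ys = scale_add_batched(xs, 2.0);
  std::printf("%g %g\n", ys[0], ys[1]); // prints 7 9
}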
-// -//===----------------------------------------------------------------------===// - -#ifndef INSTRUCTION_BATCHER_H_ -#define INSTRUCTION_BATCHER_H_ - -#include - -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/IR/InstVisitor.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Value.h" - -#include "llvm/ADT/SmallPtrSet.h" - -#include "EnzymeLogic.h" - -class InstructionBatcher final : public llvm::InstVisitor { -public: - bool hasError; - InstructionBatcher( - llvm::Function *oldFunc, llvm::Function *newFunc, unsigned width, - llvm::ValueMap> - &vectorizedValues, - llvm::ValueMap - &originalToNewFn, - llvm::SmallPtrSetImpl &toVectorize, EnzymeLogic &Logic); - -private: - llvm::ValueMap> - &vectorizedValues; - llvm::ValueMap &originalToNewFn; - llvm::SmallPtrSetImpl &toVectorize; - unsigned width; - EnzymeLogic &Logic; - -private: - llvm::Value *getNewOperand(unsigned int i, llvm::Value *op); - -public: - void visitInstruction(llvm::Instruction &inst); - - void visitPHINode(llvm::PHINode &phi); - - void visitSwitchInst(llvm::SwitchInst &inst); - - void visitBranchInst(llvm::BranchInst &branch); - - void visitReturnInst(llvm::ReturnInst &ret); - - void visitCallInst(llvm::CallInst &call); -}; - -#endif diff --git a/enzyme/Enzyme/MustExitScalarEvolution.cpp b/enzyme/Enzyme/MustExitScalarEvolution.cpp deleted file mode 100644 index c1850fe331b7..000000000000 --- a/enzyme/Enzyme/MustExitScalarEvolution.cpp +++ /dev/null @@ -1,1318 +0,0 @@ - -//===- MustExitScalarEvolution.cpp - ScalarEvolution assuming loops -// terminate-=// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file defines MustExitScalarEvolution, a subclass of ScalarEvolution -// that assumes that all loops terminate (and don't loop forever). 
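// A small plain-C++ illustration of the assumption this class bakes in.
// For the loop below, the exit test is i != n with a step of 2: if n were
// odd, i would step over n and (being unsigned) wrap around indefinitely,
// so the iteration count n / 2 is only meaningful under the assumption that
// the loop terminates. MustExitScalarEvolution takes exactly that assumption
// as given (loopIsFiniteByAssumption returns true), which is what lets it
// answer queries the stock analysis may conservatively refuse. Whether stock
// ScalarEvolution computes this particular count is version-dependent; the
// example only motivates the assumption.
#include <cassert>
#include <cstdio>

static unsigned simulated_iterations(unsigned n) {
  unsigned iters = 0;
  for (unsigned i = 0; i != n; i += 2) // exit: while (X != Y)
    ++iters;
  return iters;
}

int main() {
  unsigned n = 10;          // even, so the loop really is finite
  unsigned assumed = n / 2; // count predicted under the finiteness assumption
  assert(simulated_iterations(n) == assumed);
  std::printf("iterations = %u\n", assumed); // prints iterations = 5
}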
-// -//===----------------------------------------------------------------------===// - -#include "MustExitScalarEvolution.h" -#include "FunctionUtils.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/ScalarEvolution.h" - -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunused-variable" -#else -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-variable" -#endif - -using namespace llvm; - -bool MustExitScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { - return true; -} - -ScalarEvolution::ExitLimit MustExitScalarEvolution::computeExitLimitFromCond( - const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsExit, - bool AllowPredicates) { - ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates); - return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue, - ControlsExit, AllowPredicates); -} - -MustExitScalarEvolution::MustExitScalarEvolution(llvm::Function &F, - llvm::TargetLibraryInfo &TLI, - llvm::AssumptionCache &AC, - llvm::DominatorTree &DT, - llvm::LoopInfo &LI) - : ScalarEvolution(F, TLI, AC, DT, LI), - GuaranteedUnreachable(getGuaranteedUnreachable(&F)) {} - -ScalarEvolution::ExitLimit MustExitScalarEvolution::computeExitLimit( - const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) { - - SmallVector ExitingBlocks; - L->getExitingBlocks(ExitingBlocks); - for (auto &ExitingBlock : ExitingBlocks) { - BasicBlock *Exit = nullptr; - for (auto *SBB : successors(ExitingBlock)) { - if (!L->contains(SBB)) { - if (GuaranteedUnreachable.count(SBB)) - continue; - Exit = SBB; - break; - } - } - if (!Exit) - ExitingBlock = nullptr; - } - ExitingBlocks.erase( - std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr), - ExitingBlocks.end()); - - assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); - // If our exiting block does not dominate the latch, then its connection with - // loop's exit limit may be far from trivial. - const BasicBlock *Latch = L->getLoopLatch(); - if (!Latch || !DT.dominates(ExitingBlock, Latch)) - return getCouldNotCompute(); - - bool IsOnlyExit = ExitingBlocks.size() == 1; - auto *Term = ExitingBlock->getTerminator(); - if (BranchInst *BI = dyn_cast(Term)) { - assert(BI->isConditional() && "If unconditional, it can't be in loop!"); - bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); - assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) && - "It should have one successor in loop and one exit block!"); - // Proceed to the next level to examine the exit condition expression. - return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue, - /*ControlsExit=*/IsOnlyExit, - AllowPredicates); - } - - if (SwitchInst *SI = dyn_cast(Term)) { - // For switch, make sure that there is a single exit from the loop. - BasicBlock *Exit = nullptr; - for (auto *SBB : successors(ExitingBlock)) - if (!L->contains(SBB)) { - if (GuaranteedUnreachable.count(SBB)) - continue; - if (Exit) // Multiple exit successors. 
- return getCouldNotCompute(); - Exit = SBB; - } - assert(Exit && "Exiting block must have at least one exit"); - return computeExitLimitFromSingleExitSwitch(L, SI, Exit, - /*ControlsExit=*/IsOnlyExit); - } - - return getCouldNotCompute(); -} - -ScalarEvolution::ExitLimit -MustExitScalarEvolution::computeExitLimitFromSingleExitSwitch( - const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - bool ControlsOnlyExit) { - assert(!L->contains(ExitingBlock) && "Not an exiting block!"); - - // Give up if the exit is the default dest of a switch. - if (Switch->getDefaultDest() == ExitingBlock) - return getCouldNotCompute(); - - ///! If we're guaranteed unreachable, the default dest does not matter. - if (!GuaranteedUnreachable.count(Switch->getDefaultDest())) - assert(L->contains(Switch->getDefaultDest()) && - "Default case must not exit the loop!"); - const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); - const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); - - // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit); - if (EL.hasAnyInfo()) - return EL; - - return getCouldNotCompute(); -} - -ScalarEvolution::ExitLimit -MustExitScalarEvolution::computeExitLimitFromCondCached( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { - - if (auto MaybeEL = - Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) - return *MaybeEL; - - ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue, - ControlsExit, AllowPredicates); - Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL); - return EL; -} - -ScalarEvolution::ExitLimit -MustExitScalarEvolution::computeExitLimitFromCondImpl( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { - // Check if the controlling expression for this loop is an And or Or. - if (BinaryOperator *BO = dyn_cast(ExitCond)) { - if (BO->getOpcode() == Instruction::And) { - // Recurse on the operands of the and. - bool EitherMayExit = !ExitIfTrue; - ExitLimit EL0 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(0), ExitIfTrue, - ControlsExit && !EitherMayExit, AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(1), ExitIfTrue, - ControlsExit && !EitherMayExit, AllowPredicates); - const SCEV *BECount = getCouldNotCompute(); - const SCEV *MaxBECount = getCouldNotCompute(); - if (EitherMayExit) { - // Both conditions must be true for the loop to continue executing. - // Choose the less conservative count. - if (EL0.ExactNotTaken == getCouldNotCompute() || - EL1.ExactNotTaken == getCouldNotCompute()) - BECount = getCouldNotCompute(); - else - BECount = - getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); -#if LLVM_VERSION_MAJOR >= 16 - if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.ConstantMaxNotTaken; - else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.ConstantMaxNotTaken; - else - MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken, - EL1.ConstantMaxNotTaken); - } else { - // Both conditions must be true at the same time for the loop to exit. - // For now, be conservative. 
- if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) - MaxBECount = EL0.ConstantMaxNotTaken; - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } -#else - if (EL0.MaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.MaxNotTaken; - else if (EL1.MaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.MaxNotTaken; - else - MaxBECount = - getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); - } else { - // Both conditions must be true at the same time for the loop to exit. - // For now, be conservative. - if (EL0.MaxNotTaken == EL1.MaxNotTaken) - MaxBECount = EL0.MaxNotTaken; - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } -#endif - // There are cases (e.g. PR26207) where computeExitLimitFromCond is able - // to be more aggressive when computing BECount than when computing - // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and - // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and - // EL1.ConstantMaxNotTaken to not. - if (isa(MaxBECount) && - !isa(BECount)) - MaxBECount = getConstant(getUnsignedRangeMax(BECount)); - -#if LLVM_VERSION_MAJOR >= 20 - return ExitLimit(BECount, MaxBECount, MaxBECount, false, - {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)}); -#elif LLVM_VERSION_MAJOR >= 16 - return ExitLimit(BECount, MaxBECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); -#else - return ExitLimit(BECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); -#endif - } - if (BO->getOpcode() == Instruction::Or) { - // Recurse on the operands of the or. - bool EitherMayExit = ExitIfTrue; - ExitLimit EL0 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(0), ExitIfTrue, - ControlsExit && !EitherMayExit, AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(1), ExitIfTrue, - ControlsExit && !EitherMayExit, AllowPredicates); - const SCEV *BECount = getCouldNotCompute(); - const SCEV *MaxBECount = getCouldNotCompute(); - if (EitherMayExit) { - // Both conditions must be false for the loop to continue executing. - // Choose the less conservative count. - if (EL0.ExactNotTaken == getCouldNotCompute() || - EL1.ExactNotTaken == getCouldNotCompute()) - BECount = getCouldNotCompute(); - else - BECount = - getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); -#if LLVM_VERSION_MAJOR >= 16 - if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.ConstantMaxNotTaken; - else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.ConstantMaxNotTaken; - else - MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken, - EL1.ConstantMaxNotTaken); - } else { - // Both conditions must be false at the same time for the loop to exit. - // For now, be conservative. 
- if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) - MaxBECount = EL0.ConstantMaxNotTaken; - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } -#if LLVM_VERSION_MAJOR >= 20 - return ExitLimit(BECount, MaxBECount, MaxBECount, false, - {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)}); -#else - return ExitLimit(BECount, MaxBECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); -#endif -#else - if (EL0.MaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.MaxNotTaken; - else if (EL1.MaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.MaxNotTaken; - else - MaxBECount = - getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); - } else { - // Both conditions must be false at the same time for the loop to exit. - // For now, be conservative. - if (EL0.MaxNotTaken == EL1.MaxNotTaken) - MaxBECount = EL0.MaxNotTaken; - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } - return ExitLimit(BECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); -#endif - } - } - - // With an icmp, it may be feasible to compute an exact backedge-taken count. - // Proceed to the next level to examine the icmp. - if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) { - ExitLimit EL = - computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit); - if (EL.hasFullInfo() || !AllowPredicates) - return EL; - - // Try again, but use SCEV predicates this time. - return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit, - /*AllowPredicates=*/true); - } - - // Check for a constant condition. These are normally stripped out by - // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to - // preserve the CFG and is temporarily leaving constant conditions - // in place. - if (ConstantInt *CI = dyn_cast(ExitCond)) { - if (ExitIfTrue == !CI->getZExtValue()) - // The backedge is always taken. - return getCouldNotCompute(); - else - // The backedge is never taken. - return getZero(CI->getType()); - } - - // If it's not an integer or pointer comparison then compute it the hard way. 
- return computeExitCountExhaustively(L, ExitCond, ExitIfTrue); -} - -ScalarEvolution::ExitLimit MustExitScalarEvolution::computeExitLimitFromICmp( - const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsExit, - bool AllowPredicates) { - // If the condition was exit on true, convert the condition to exit on false -#if LLVM_VERSION_MAJOR >= 20 - llvm::CmpPredicate Pred = ExitCond->getPredicate(); -#else - auto Pred = ExitCond->getPredicate(); -#endif - if (ExitIfTrue) - Pred = ExitCond->getInversePredicate(); - const auto OriginalPred = Pred; - -#if LLVM_VERSION_MAJOR < 14 - // Handle common loops like: for (X = "string"; *X; ++X) - if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) - if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { - ExitLimit ItCnt = computeLoadConstantCompareExitLimit(LI, RHS, L, Pred); - if (ItCnt.hasAnyInfo()) - return ItCnt; - } -#endif - - const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); - const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); - -#define PROP_PHI(LHS) \ - if (auto un = dyn_cast(LHS)) { \ - if (auto pn = dyn_cast_or_null(un->getValue())) { \ - const SCEV *sc = nullptr; \ - bool failed = false; \ - for (auto &a : pn->incoming_values()) { \ - auto subsc = getSCEV(a); \ - if (sc == nullptr) { \ - sc = subsc; \ - continue; \ - } \ - if (subsc != sc) { \ - failed = true; \ - break; \ - } \ - } \ - if (!failed) { \ - LHS = sc; \ - } \ - } \ - } - PROP_PHI(LHS) - PROP_PHI(RHS) - - // Try to evaluate any dependencies out of the loop. - LHS = getSCEVAtScope(LHS, L); - RHS = getSCEVAtScope(RHS, L); - - // At this point, we would like to compute how many iterations of the - // loop the predicate will return true for these inputs. - if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { - // If there is a loop-invariant, force it into the RHS. - std::swap(LHS, RHS); - Pred = ICmpInst::getSwappedPredicate(Pred); - } - - // Simplify the operands before analyzing them. - (void)SimplifyICmpOperands(Pred, LHS, RHS); - - // If we have a comparison of a chrec against a constant, try to use value - // ranges to answer this query. - if (const SCEVConstant *RHSC = dyn_cast(RHS)) - if (const SCEVAddRecExpr *AddRec = dyn_cast(LHS)) - if (AddRec->getLoop() == L) { - // Form the constant range. - ConstantRange CompRange = - ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); - - const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); - if (!isa(Ret)) - return Ret; - } - - switch (Pred) { - case ICmpInst::ICMP_NE: { // while (X != Y) - // Convert to: while (X-Y != 0) - ExitLimit EL = - howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - break; - } - case ICmpInst::ICMP_EQ: { // while (X == Y) - // Convert to: while (X-Y == 0) - ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); - if (EL.hasAnyInfo()) - return EL; - break; - } - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLE: - case ICmpInst::ICMP_ULE: { // while (X < Y) - bool IsSigned = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; - - if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) { - if (!isa(RHS->getType())) - break; - SmallVector sv = { - RHS, - getConstant(ConstantInt::get(cast(RHS->getType()), 1))}; - // Since this is not an infinite loop by induction, RHS cannot be - // int_max/uint_max Therefore adding 1 does not wrap. 
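      // Illustrative elaboration of the comment above: for an unsigned test
      // X <= Y the analysis proceeds as X < Y + 1, and Y + 1 is only
      // meaningful if Y != UINT_MAX. But if Y were UINT_MAX, the test X <= Y
      // could never fail and this exit could never be taken, contradicting
      // the finiteness assumed throughout this file; hence Y + 1 can be
      // formed with FlagNUW below. The signed case is the same argument with
      // INT_MAX and FlagNSW.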
- if (IsSigned) - RHS = getAddExpr(sv, SCEV::FlagNSW); - else - RHS = getAddExpr(sv, SCEV::FlagNUW); - } - ExitLimit EL = - howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - break; - } - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGE: - case ICmpInst::ICMP_UGE: { // while (X > Y) - bool IsSigned = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE; - if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { - if (!isa(RHS->getType())) - break; - SmallVector sv = { - RHS, - getConstant(ConstantInt::get(cast(RHS->getType()), -1))}; - // Since this is not an infinite loop by induction, RHS cannot be - // int_min/uint_min Therefore subtracting 1 does not wrap. - if (IsSigned) - RHS = getAddExpr(sv, SCEV::FlagNSW); - else - RHS = getAddExpr(sv, SCEV::FlagNUW); - } - ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, - AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - break; - } - default: - break; - } - - auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); - - if (!isa(ExhaustiveCount)) - return ExhaustiveCount; - - return computeShiftCompareExitLimit(ExitCond->getOperand(0), - ExitCond->getOperand(1), L, OriginalPred); -} - -#if LLVM_VERSION_MAJOR >= 13 -static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); - if (SE->isKnownPositive(Step)) { - *Pred = ICmpInst::ICMP_SLT; - return SE->getConstant(APInt::getSignedMinValue(BitWidth) - - SE->getSignedRangeMax(Step)); - } - if (SE->isKnownNegative(Step)) { - *Pred = ICmpInst::ICMP_SGT; - return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - - SE->getSignedRangeMin(Step)); - } - return nullptr; -} -static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); - *Pred = ICmpInst::ICMP_ULT; - - return SE->getConstant(APInt::getMinValue(BitWidth) - - SE->getUnsignedRangeMax(Step)); -} - -namespace { - -struct ExtendOpTraitsBase { - typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *, - unsigned); -}; - -// Used to make code generic over signed and unsigned overflow. 
-template struct ExtendOpTraits { - // Members present: - // - // static const SCEV::NoWrapFlags WrapType; - // - // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; - // - // static const SCEV *getOverflowLimitForStep(const SCEV *Step, - // ICmpInst::Predicate *Pred, - // ScalarEvolution *SE); -}; - -template <> -struct ExtendOpTraits : public ExtendOpTraitsBase { - static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW; - - static const GetExtendExprTy GetExtendExpr; - - static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - return getSignedOverflowLimitForStep(Step, Pred, SE); - } -}; - -const ExtendOpTraitsBase::GetExtendExprTy - ExtendOpTraits::GetExtendExpr = - &ScalarEvolution::getSignExtendExpr; - -template <> -struct ExtendOpTraits : public ExtendOpTraitsBase { - static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW; - - static const GetExtendExprTy GetExtendExpr; - - static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - return getUnsignedOverflowLimitForStep(Step, Pred, SE); - } -}; - -const ExtendOpTraitsBase::GetExtendExprTy - ExtendOpTraits::GetExtendExpr = - &ScalarEvolution::getZeroExtendExpr; - -} // end anonymous namespace - -static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags) { - return TestFlags == ScalarEvolution::maskFlags(Flags, TestFlags); -}; - -template -static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE, unsigned Depth) { - auto WrapType = ExtendOpTraits::WrapType; - auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; - - const Loop *L = AR->getLoop(); - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*SE); - - // Check for a simple looking step prior to loop entry. - const SCEVAddExpr *SA = dyn_cast(Start); - if (!SA) - return nullptr; - - // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV - // subtraction is expensive. For this purpose, perform a quick and dirty - // difference, by checking for Step in the operand list. - SmallVector DiffOps; - for (const SCEV *Op : SA->operands()) - if (Op != Step) - DiffOps.push_back(Op); - - if (DiffOps.size() == SA->getNumOperands()) - return nullptr; - - // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` + - // `Step`: - - // 1. NSW/NUW flags on the step increment. - auto PreStartFlags = - ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); - const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); - const SCEVAddRecExpr *PreAR = dyn_cast( - SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); - - // "{S,+,X} is /" and "the backedge is taken at least once" implies - // "S+X does not sign/unsign-overflow". - // - - const SCEV *BECount = SE->getBackedgeTakenCount(L); - if (PreAR && PreAR->getNoWrapFlags(WrapType) && - !isa(BECount) && SE->isKnownPositive(BECount)) - return PreStart; - - // 2. Direct overflow check on the step operation's expression. 
- unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); - Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); - const SCEV *OperandExtendedStart = - SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth), - (SE->*GetExtendExpr)(Step, WideTy, Depth)); - if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) { - if (PreAR && AR->getNoWrapFlags(WrapType)) { - // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW - // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then - // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. - SE->setNoWrapFlags(const_cast(PreAR), WrapType); - } - return PreStart; - } - - // 3. Loop precondition. - ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = - ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); - - if (OverflowLimit && - SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) - return PreStart; - - return nullptr; -} - -template -static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE, unsigned Depth) { - auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; - - const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE, Depth); - if (!PreStart) - return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth); - - return SE->getAddExpr( - (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Depth), - (SE->*GetExtendExpr)(PreStart, Ty, Depth)); -} - -static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, - SCEVTypes Type, - const ArrayRef Ops, - SCEV::NoWrapFlags Flags) { - using namespace std::placeholders; - - using OBO = OverflowingBinaryOperator; - - bool CanAnalyze = - Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; - (void)CanAnalyze; - assert(CanAnalyze && "don't call from other places!"); - - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = - ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); - - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - auto IsKnownNonNegative = [&](const SCEV *S) { - return SE->isKnownNonNegative(S); - }; - - if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) - Flags = - ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - - SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); - - if (SignOrUnsignWrap != SignOrUnsignMask && - (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 && - isa(Ops[0])) { - - auto Opcode = [&] { - switch (Type) { - case scAddExpr: - return Instruction::Add; - case scMulExpr: - return Instruction::Mul; - default: - llvm_unreachable("Unexpected SCEV op."); - } - }(); - - const APInt &C = cast(Ops[0])->getAPInt(); - - // (A C) --> (A C) if the op doesn't sign overflow. - if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { - auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( - Opcode, C, OBO::NoSignedWrap); - if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); - } - - // (A C) --> (A C) if the op doesn't unsign overflow. 
- if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { - auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( - Opcode, C, OBO::NoUnsignedWrap); - if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); - } - } - - // <0,+,nonnegative> is also nuw - // TODO: Add corresponding nsw case - if (Type == scAddRecExpr && hasFlags(Flags, SCEV::FlagNW) && - !hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 && Ops[0]->isZero() && - IsKnownNonNegative(Ops[1])) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); - - // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW - if (Type == scMulExpr && !hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2) { - if (auto *UDiv = dyn_cast(Ops[0])) - if (UDiv->getOperand(1) == Ops[1]) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); - if (auto *UDiv = dyn_cast(Ops[1])) - if (UDiv->getOperand(1) == Ops[0]) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); - } - - return Flags; -} - -ScalarEvolution::ExitLimit MustExitScalarEvolution::howManyLessThans( - const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool ControlsExit, bool AllowPredicates) { -#if LLVM_VERSION_MAJOR >= 20 - SmallVector Predicates; -#else - SmallPtrSet Predicates; -#endif - - const SCEVAddRecExpr *IV = dyn_cast(LHS); - bool PredicatedIV = false; - - auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) { - // Can we prove this loop *must* be UB if overflow of IV occurs? - // Reasoning goes as follows: - // * Suppose the IV did self wrap. - // * If Stride evenly divides the iteration space, then once wrap - // occurs, the loop must revisit the same values. - // * We know that RHS is invariant, and that none of those values - // caused this exit to be taken previously. Thus, this exit is - // dynamically dead. - // * If this is the sole exit, then a dead exit implies the loop - // must be infinite if there are no abnormal exits. - // * If the loop were infinite, then it must either not be mustprogress - // or have side effects. Otherwise, it must be UB. - // * It can't (by assumption), be UB so we have contradicted our - // premise and can conclude the IV did not in fact self-wrap. - if (!isLoopInvariant(RHS, L)) - return false; - - auto *StrideC = dyn_cast(AR->getStepRecurrence(*this)); - if (!StrideC || !StrideC->getAPInt().isPowerOf2()) - return false; - - if (!ControlsExit || !loopHasNoAbnormalExits(L)) - return false; - - return loopIsFiniteByAssumption(L); - }; - - if (!IV) { - if (auto *ZExt = dyn_cast(LHS)) { - const SCEVAddRecExpr *AR = dyn_cast(ZExt->getOperand()); - if (AR && AR->getLoop() == L && AR->isAffine()) { - auto Flags = AR->getNoWrapFlags(); - if (!hasFlags(Flags, SCEV::FlagNW) && canAssumeNoSelfWrap(AR)) { - Flags = setFlags(Flags, SCEV::FlagNW); - - SmallVector Operands{AR->operands()}; - Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); - - setNoWrapFlags(const_cast(AR), Flags); - } - if (AR->hasNoUnsignedWrap()) { - // Emulate what getZeroExtendExpr would have done during construction - // if we'd been able to infer the fact just above at that time. 
- const SCEV *Step = AR->getStepRecurrence(*this); - Type *Ty = ZExt->getType(); - auto *S = getAddRecExpr( - getExtendAddRecStart(AR, Ty, this, 0), - getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags()); - IV = dyn_cast(S); - } - } - } - } - - if (!IV && AllowPredicates) { - // Try to make this an AddRec using runtime tests, in the first X - // iterations of this loop, where X is the SCEV expression found by the - // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); - PredicatedIV = true; - } - - // Avoid weird loops - if (!IV || IV->getLoop() != L || !IV->isAffine()) - return getCouldNotCompute(); - - // A precondition of this method is that the condition being analyzed - // reaches an exiting branch which dominates the latch. Given that, we can - // assume that an increment which violates the nowrap specification and - // produces poison must cause undefined behavior when the resulting poison - // value is branched upon and thus we can conclude that the backedge is - // taken no more often than would be required to produce that poison value. - // Note that a well defined loop can exit on the iteration which violates - // the nowrap specification if there is another exit (either explicit or - // implicit/exceptional) which causes the loop to execute before the - // exiting instruction we're analyzing would trigger UB. - auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW; - bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType); - ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; - - const SCEV *Stride = IV->getStepRecurrence(*this); - - bool PositiveStride = isKnownPositive(Stride); - - // Avoid negative or zero stride values. - if (!PositiveStride) { - // We can compute the correct backedge taken count for loops with unknown - // strides if we can prove that the loop is not an infinite loop with side - // effects. Here's the loop structure we are trying to handle - - // - // i = start - // do { - // A[i] = i; - // i += s; - // } while (i < end); - // - // The backedge taken count for such loops is evaluated as - - // (max(end, start + stride) - start - 1) /u stride - // - // The additional preconditions that we need to check to prove correctness - // of the above formula is as follows - - // - // a) IV is either nuw or nsw depending upon signedness (indicated by the - // NoWrap flag). - // b) the loop is guaranteed to be finite (e.g. is mustprogress and has - // no side effects within the loop) - // c) loop has a single static exit (with no abnormal exits) - // - // Precondition a) implies that if the stride is negative, this is a single - // trip loop. The backedge taken count formula reduces to zero in this case. - // - // Precondition b) and c) combine to imply that if rhs is invariant in L, - // then a zero stride means the backedge can't be taken without executing - // undefined behavior. - // - // The positive stride case is the same as isKnownPositive(Stride) returning - // true (original behavior of the function). - // - if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) || - !loopHasNoAbnormalExits(L)) { - return getCouldNotCompute(); - } - - // This bailout is protecting the logic in computeMaxBECountForLT which - // has not yet been sufficiently auditted or tested with negative strides. - // We used to filter out all known-non-positive cases here, we're in the - // process of being less restrictive bit by bit. 
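  // Worked instance of the unknown-positive-stride formula derived above
  // (illustration only): for
  //   i = start; do { A[i] = i; i += s; } while (i < end);
  // with start = 0, end = 10, s = 3 the body runs for i = 0, 3, 6, 9, so the
  // backedge is taken 3 times, and
  //   (max(end, start + s) - start - 1) /u s = (max(10, 3) - 0 - 1) /u 3
  //                                          = 9 /u 3 = 3.
  // If end <= start, e.g. start = 5, end = 3, s = 2, the do-while body runs
  // once and the backedge is taken 0 times:
  //   (max(3, 7) - 5 - 1) /u 2 = 1 /u 2 = 0.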
- if (IsSigned && isKnownNonPositive(Stride)) - return getCouldNotCompute(); - - if (!isKnownNonZero(Stride)) { - // If we have a step of zero, and RHS isn't invariant in L, we don't know - // if it might eventually be greater than start and if so, on which - // iteration. We can't even produce a useful upper bound. - if (!isLoopInvariant(RHS, L)) - return getCouldNotCompute(); - - // We allow a potentially zero stride, but we need to divide by stride - // below. Since the loop can't be infinite and this check must control - // the sole exit, we can infer the exit must be taken on the first - // iteration (e.g. backedge count = 0) if the stride is zero. Given that, - // we know the numerator in the divides below must be zero, so we can - // pick an arbitrary non-zero value for the denominator (e.g. stride) - // and produce the right result. - // FIXME: Handle the case where Stride is poison? - auto wouldZeroStrideBeUB = [&]() { - // Proof by contradiction. Suppose the stride were zero. If we can - // prove that the backedge *is* taken on the first iteration, then since - // we know this condition controls the sole exit, we must have an - // infinite loop. We can't have a (well defined) infinite loop per - // check just above. - // Note: The (Start - Stride) term is used to get the start' term from - // (start' + stride,+,stride). Remember that we only care about the - // result of this expression when stride == 0 at runtime. - auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride); - return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS); - }; - if (!wouldZeroStrideBeUB()) { - Stride = getUMaxExpr(Stride, getOne(Stride->getType())); - } - } - } else if (!Stride->isOne() && !NoWrap) { - auto isUBOnWrap = [&]() { - // From no-self-wrap, we need to then prove no-(un)signed-wrap. This - // follows trivially from the fact that every (un)signed-wrapped, but - // not self-wrapped value must be LT than the last value before - // (un)signed wrap. Since we know that last value didn't exit, nor - // will any smaller one. - return canAssumeNoSelfWrap(IV); - }; - - // Avoid proven overflow cases: this will ensure that the backedge taken - // count will not generate any unsigned overflow. Relaxed no-overflow - // conditions exploit NoWrapFlags, allowing to optimize in presence of - // undefined behaviors like the case of C language. - if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap()) - return getCouldNotCompute(); - } - - // On all paths just preceeding, we established the following invariant: - // IV can be assumed not to overflow up to and including the exiting - // iteration. We proved this in one of two ways: - // 1) We can show overflow doesn't occur before the exiting iteration - // 1a) canIVOverflowOnLT, and b) step of one - // 2) We can show that if overflow occurs, the loop must execute UB - // before any possible exit. - // Note that we have not yet proved RHS invariant (in general). - - const SCEV *Start = IV->getStart(); - - // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond. - // If we convert to integers, isLoopEntryGuardedByCond will miss some cases. - // Use integer-typed versions for actual computation; we can't subtract - // pointers in general. 
- const SCEV *OrigStart = Start; - const SCEV *OrigRHS = RHS; - if (Start->getType()->isPointerTy()) { - Start = getLosslessPtrToIntExpr(Start); - if (isa(Start)) - return Start; - } - if (RHS->getType()->isPointerTy()) { - RHS = getLosslessPtrToIntExpr(RHS); - if (isa(RHS)) - return RHS; - } - - // When the RHS is not invariant, we do not know the end bound of the loop and - // cannot calculate the ExactBECount needed by ExitLimit. However, we can - // calculate the MaxBECount, given the start, stride and max value for the end - // bound of the loop (RHS), and the fact that IV does not overflow (which is - // checked above). - if (!isLoopInvariant(RHS, L)) { - const SCEV *MaxBECount = computeMaxBECountForLT( - Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); -#if LLVM_VERSION_MAJOR >= 16 - return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount, - MaxBECount, false /*MaxOrZero*/, Predicates); -#else - return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount, - false /*MaxOrZero*/, Predicates); -#endif - } - - // We use the expression (max(End,Start)-Start)/Stride to describe the - // backedge count, as if the backedge is taken at least once max(End,Start) - // is End and so the result is as above, and if not max(End,Start) is Start - // so we get a backedge count of zero. - const SCEV *BECount = nullptr; - auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride); - assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!"); - assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!"); - assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!"); - // Can we prove (max(RHS,Start) > Start - Stride? - if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) && - isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) { - // In this case, we can use a refined formula for computing backedge taken - // count. The general formula remains: - // "End-Start /uceiling Stride" where "End = max(RHS,Start)" - // We want to use the alternate formula: - // "((End - 1) - (Start - Stride)) /u Stride" - // Let's do a quick case analysis to show these are equivalent under - // our precondition that max(RHS,Start) > Start - Stride. - // * For RHS <= Start, the backedge-taken count must be zero. - // "((End - 1) - (Start - Stride)) /u Stride" reduces to - // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to - // "Stride - 1 /u Stride" which is indeed zero for all non-zero values - // of Stride. For 0 stride, we've use umin(1,Stride) above, reducing - // this to the stride of 1 case. - // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride". - // "((End - 1) - (Start - Stride)) /u Stride" reduces to - // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to - // "((RHS - (Start - Stride) - 1) /u Stride". - // Our preconditions trivially imply no overflow in that form. - const SCEV *MinusOne = getMinusOne(Stride->getType()); - const SCEV *Numerator = - getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride)); - BECount = getUDivExpr(Numerator, Stride); - } - - const SCEV *BECountIfBackedgeTaken = nullptr; - if (!BECount) { - auto canProveRHSGreaterThanEqualStart = [&]() { - auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) - return true; - - // (RHS > Start - 1) implies RHS >= Start. - // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if - // "Start - 1" doesn't overflow. 
- // * For signed comparison, if Start - 1 does overflow, it's equal - // to INT_MAX, and "RHS >s INT_MAX" is trivially false. - // * For unsigned comparison, if Start - 1 does overflow, it's equal - // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false. - // - // FIXME: Should isLoopEntryGuardedByCond do this for us? - auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; - auto *StartMinusOne = - getAddExpr(OrigStart, getMinusOne(OrigStart->getType())); - return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne); - }; - - // If we know that RHS >= Start in the context of loop, then we know that - // max(RHS, Start) = RHS at this point. - const SCEV *End; - if (canProveRHSGreaterThanEqualStart()) { - End = RHS; - } else { - // If RHS < Start, the backedge will be taken zero times. So in - // general, we can write the backedge-taken count as: - // - // RHS >= Start ? ceil(RHS - Start) / Stride : 0 - // - // We convert it to the following to make it more convenient for SCEV: - // - // ceil(max(RHS, Start) - Start) / Stride - End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); - - // See what would happen if we assume the backedge is taken. This is - // used to compute MaxBECount. - BECountIfBackedgeTaken = - getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride); - } - - // At this point, we know: - // - // 1. If IsSigned, Start <=s End; otherwise, Start <=u End - // 2. The index variable doesn't overflow. - // - // Therefore, we know N exists such that - // (Start + Stride * N) >= End, and computing "(Start + Stride * N)" - // doesn't overflow. - // - // Using this information, try to prove whether the addition in - // "(Start - End) + (Stride - 1)" has unsigned overflow. - const SCEV *One = getOne(Stride->getType()); - bool MayAddOverflow = [&] { - if (auto *StrideC = dyn_cast(Stride)) { - if (StrideC->getAPInt().isPowerOf2()) { - // Suppose Stride is a power of two, and Start/End are unsigned - // integers. Let UMAX be the largest representable unsigned - // integer. - // - // By the preconditions of this function, we know - // "(Start + Stride * N) >= End", and this doesn't overflow. - // As a formula: - // - // End <= (Start + Stride * N) <= UMAX - // - // Subtracting Start from all the terms: - // - // End - Start <= Stride * N <= UMAX - Start - // - // Since Start is unsigned, UMAX - Start <= UMAX. Therefore: - // - // End - Start <= Stride * N <= UMAX - // - // Stride * N is a multiple of Stride. Therefore, - // - // End - Start <= Stride * N <= UMAX - (UMAX mod Stride) - // - // Since Stride is a power of two, UMAX + 1 is divisible by Stride. - // Therefore, UMAX mod Stride == Stride - 1. So we can write: - // - // End - Start <= Stride * N <= UMAX - Stride - 1 - // - // Dropping the middle term: - // - // End - Start <= UMAX - Stride - 1 - // - // Adding Stride - 1 to both sides: - // - // (End - Start) + (Stride - 1) <= UMAX - // - // In other words, the addition doesn't have unsigned overflow. - // - // A similar proof works if we treat Start/End as signed values. - // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to - // use signed max instead of unsigned max. Note that we're trying - // to prove a lack of unsigned overflow in either case. - return false; - } - } - if (Start == Stride || Start == getMinusSCEV(Stride, One)) { - // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1. 
- // If !IsSigned, 0 (BECount)) { - MaxBECount = BECount; - } else if (BECountIfBackedgeTaken && - isa(BECountIfBackedgeTaken)) { - // If we know exactly how many times the backedge will be taken if it's - // taken at least once, then the backedge count will either be that or - // zero. - MaxBECount = BECountIfBackedgeTaken; - MaxOrZero = true; - } else { - MaxBECount = computeMaxBECountForLT( - Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); - } - - if (isa(MaxBECount) && - !isa(BECount)) - MaxBECount = getConstant(getUnsignedRangeMax(BECount)); -#if LLVM_VERSION_MAJOR >= 16 - return ExitLimit(BECount, MaxBECount, MaxBECount, MaxOrZero, Predicates); -#else - return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); -#endif -} -#else -ScalarEvolution::ExitLimit MustExitScalarEvolution::howManyLessThans( - const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool ControlsExit, bool AllowPredicates) { - SmallPtrSet Predicates; - - const SCEVAddRecExpr *IV = dyn_cast(LHS); - - if (!IV && AllowPredicates) { - // Try to make this an AddRec using runtime tests, in the first X - // iterations of this loop, where X is the SCEV expression found by the - // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); - } - - // Avoid weird loops - if (!IV || IV->getLoop() != L || !IV->isAffine()) - return getCouldNotCompute(); - - bool NoWrap = ControlsExit && true; // changed this to assume no wrap for inc - // IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); - - const SCEV *Stride = IV->getStepRecurrence(*this); - - bool PositiveStride = isKnownPositive(Stride); - - // Avoid negative or zero stride values. - if (!PositiveStride) { - // We can compute the correct backedge taken count for loops with unknown - // strides if we can prove that the loop is not an infinite loop with side - // effects. Here's the loop structure we are trying to handle - - // - // i = start - // do { - // A[i] = i; - // i += s; - // } while (i < end); - // - // The backedge taken count for such loops is evaluated as - - // (max(end, start + stride) - start - 1) /u stride - // - // The additional preconditions that we need to check to prove correctness - // of the above formula is as follows - - // - // a) IV is either nuw or nsw depending upon signedness (indicated by the - // NoWrap flag). - // b) loop is single exit with no side effects. // dont need this - // - // - // Precondition a) implies that if the stride is negative, this is a single - // trip loop. The backedge taken count formula reduces to zero in this case. - // - // Precondition b) implies that the unknown stride cannot be zero otherwise - // we have UB. - // - // The positive stride case is the same as isKnownPositive(Stride) returning - // true (original behavior of the function). - // - // We want to make sure that the stride is truly unknown as there are edge - // cases where ScalarEvolution propagates no wrap flags to the - // post-increment/decrement IV even though the increment/decrement operation - // itself is wrapping. The computed backedge taken count may be wrong in - // such cases. This is prevented by checking that the stride is not known to - // be either positive or non-positive. 
For example, no wrap flags are - // propagated to the post-increment IV of this loop with a trip count of 2 - - // - // unsigned char i; - // for(i=127; i<128; i+=129) - // A[i] = i; - // - if (!NoWrap) // THIS LINE CHANGED - return getCouldNotCompute(); - } else if (!Stride->isOne() && - doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) - // Avoid proven overflow cases: this will ensure that the backedge taken - // count will not generate any unsigned overflow. Relaxed no-overflow - // conditions exploit NoWrapFlags, allowing to optimize in presence of - // undefined behaviors like the case of C language. - return getCouldNotCompute(); - - ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; - const SCEV *Start = IV->getStart(); - const SCEV *End = RHS; - // When the RHS is not invariant, we do not know the end bound of the loop and - // cannot calculate the ExactBECount needed by ExitLimit. However, we can - // calculate the MaxBECount, given the start, stride and max value for the end - // bound of the loop (RHS), and the fact that IV does not overflow (which is - // checked above). - if (!isLoopInvariant(RHS, L)) { - const SCEV *MaxBECount = computeMaxBECountForLT( - Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); - return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount, - false /*MaxOrZero*/, Predicates); - } - // If the backedge is taken at least once, then it will be taken - // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start - // is the LHS value of the less-than comparison the first time it is evaluated - // and End is the RHS. - const SCEV *BECountIfBackedgeTaken = - computeBECount(getMinusSCEV(End, Start), Stride, false); - // If the loop entry is guarded by the result of the backedge test of the - // first loop iteration, then we know the backedge will be taken at least - // once and so the backedge taken count is as above. If not then we use the - // expression (max(End,Start)-Start)/Stride to describe the backedge count, - // as if the backedge is taken at least once max(End,Start) is End and so the - // result is as above, and if not max(End,Start) is Start so we get a backedge - // count of zero. - const SCEV *BECount; - if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) - BECount = BECountIfBackedgeTaken; - else { - End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); - BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); - } - - const SCEV *MaxBECount; - bool MaxOrZero = false; - if (isa(BECount)) - MaxBECount = BECount; - else if (isa(BECountIfBackedgeTaken)) { - // If we know exactly how many times the backedge will be taken if it's - // taken at least once, then the backedge count will either be that or - // zero. 
- MaxBECount = BECountIfBackedgeTaken; - MaxOrZero = true; - } else { - MaxBECount = computeMaxBECountForLT( - Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); - } - - if (isa(MaxBECount) && - !isa(BECount)) - MaxBECount = getConstant(getUnsignedRangeMax(BECount)); - - return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); -} - -#ifdef __clang__ -#pragma clang diagnostic pop -#else -#pragma GCC diagnostic pop -#endif - -#endif diff --git a/enzyme/Enzyme/MustExitScalarEvolution.h b/enzyme/Enzyme/MustExitScalarEvolution.h deleted file mode 100644 index 27bba68f8d5e..000000000000 --- a/enzyme/Enzyme/MustExitScalarEvolution.h +++ /dev/null @@ -1,89 +0,0 @@ - -//===- MustExitScalarEvolution.h - ScalarEvolution assuming loops terminate-=// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file declares MustExitScalarEvolution, a subclass of ScalarEvolution -// that assumes that all loops terminate (and don't loop forever). -// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_MUST_EXIT_SCALAR_EVOLUTION_H_ -#define ENZYME_MUST_EXIT_SCALAR_EVOLUTION_H_ - -#include - -#if LLVM_VERSION_MAJOR >= 16 -#define private public -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#undef private -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/IR/Dominators.h" - -class MustExitScalarEvolution final : public llvm::ScalarEvolution { -public: - llvm::SmallPtrSet GuaranteedUnreachable; - using ScalarEvolution::ScalarEvolution; - - MustExitScalarEvolution(llvm::Function &F, llvm::TargetLibraryInfo &TLI, - llvm::AssumptionCache &AC, llvm::DominatorTree &DT, - llvm::LoopInfo &LI); - ScalarEvolution::ExitLimit computeExitLimit(const llvm::Loop *L, - llvm::BasicBlock *ExitingBlock, - bool AllowPredicates); - - ScalarEvolution::ExitLimit computeExitLimitFromCond(const llvm::Loop *L, - llvm::Value *ExitCond, - bool ExitIfTrue, - bool ControlsExit, - bool AllowPredicates); - - ScalarEvolution::ExitLimit - computeExitLimitFromCondCached(ExitLimitCacheTy &Cache, const llvm::Loop *L, - llvm::Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates); - ScalarEvolution::ExitLimit - computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const llvm::Loop *L, - llvm::Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates); - - ScalarEvolution::ExitLimit - computeExitLimitFromICmp(const llvm::Loop *L, llvm::ICmpInst *ExitCond, - bool ExitIfTrue, bool ControlsExit, - bool AllowPredicates = false); - - bool loopIsFiniteByAssumption(const llvm::Loop *L); - - ScalarEvolution::ExitLimit howManyLessThans(const llvm::SCEV *LHS, - const llvm::SCEV *RHS, - const llvm::Loop *L, - bool IsSigned, bool ControlsExit, - bool AllowPredicates); - 
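Illustrative aside (not from the original patch): the overrides declared here recompute ScalarEvolution's exit limits under the assumption that every analyzed loop terminates, which is why the second howManyLessThans definition above replaces the increment's no-wrap check with "ControlsExit && true" (commented "changed this to assume no wrap for inc"). For a loop like the editor's example below, upstream ScalarEvolution cannot report an exact trip count without nuw facts or a runtime predicate (if n == SIZE_MAX the loop never exits), whereas under the must-exit assumption a count of ceil(n / 4) iterations can still be derived:

    for (size_t i = 0; i < n; i += 4)
      sum += a[i];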
- ScalarEvolution::ExitLimit computeExitLimitFromSingleExitSwitch( - const llvm::Loop *L, llvm::SwitchInst *Switch, - llvm::BasicBlock *ExitingBB, bool IsSubExpr); -}; - -#endif diff --git a/enzyme/Enzyme/TraceGenerator.cpp b/enzyme/Enzyme/TraceGenerator.cpp deleted file mode 100644 index bb70dd453f71..000000000000 --- a/enzyme/Enzyme/TraceGenerator.cpp +++ /dev/null @@ -1,424 +0,0 @@ -//===- TraceGenerator.h - Trace sample statements and calls --------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains an instruction visitor that generates probabilistic -// programming traces for call sites and sample statements. -// -//===----------------------------------------------------------------------===// - -#include "TraceGenerator.h" - -#include "llvm/ADT/SmallVector.h" - -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/Transforms/Utils/BasicBlockUtils.h" - -#include "FunctionUtils.h" -#include "TraceInterface.h" -#include "TraceUtils.h" -#include "Utils.h" - -using namespace llvm; - -TraceGenerator::TraceGenerator( - EnzymeLogic &Logic, TraceUtils *tutils, bool autodiff, - ValueMap &originalToNewFn, - const SmallPtrSetImpl &generativeFunctions, - const StringSet<> &activeRandomVariables) - : Logic(Logic), tutils(tutils), autodiff(autodiff), - originalToNewFn(originalToNewFn), - generativeFunctions(generativeFunctions), - activeRandomVariables(activeRandomVariables) { - assert(tutils); -}; - -void TraceGenerator::visitFunction(Function &F) { - if (mode == ProbProgMode::Likelihood) - return; - - auto fn = tutils->newFunc; - auto entry = getFirstNonPHIOrDbgOrLifetime(&fn->getEntryBlock()); - - while (isa(entry) && entry->getNextNode()) { - entry = entry->getNextNode(); - } - - IRBuilder<> Builder(entry); - - tutils->InsertFunction(Builder, tutils->newFunc); - - auto attributes = fn->getAttributes(); - for (size_t i = 0; i < fn->getFunctionType()->getNumParams(); ++i) { - bool shouldSkipParam = - attributes.hasParamAttr(i, TraceUtils::TraceParameterAttribute) || - attributes.hasParamAttr(i, - TraceUtils::ObservationsParameterAttribute) || - attributes.hasParamAttr(i, TraceUtils::LikelihoodParameterAttribute); - if (shouldSkipParam) - continue; - - auto arg = fn->arg_begin() + i; -#if LLVM_VERSION_MAJOR >= 17 - auto name = Builder.CreateGlobalString(arg->getName()); -#else - auto name = Builder.CreateGlobalStringPtr(arg->getName()); -#endif - - auto Outlined = [](IRBuilder<> &OutlineBuilder, TraceUtils *OutlineTutils, - ArrayRef Arguments) { - OutlineTutils->InsertArgument(OutlineBuilder, Arguments[0], Arguments[1]); - 
OutlineBuilder.CreateRetVoid(); - }; - - auto call = tutils->CreateOutlinedFunction( - Builder, Outlined, Builder.getVoidTy(), {name, arg}, false, - "outline_insert_argument"); - - call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(F.getContext(), "enzyme_insert_argument")); - call->addAttributeAtIndex(AttributeList::FunctionIndex, - Attribute::get(F.getContext(), "enzyme_active")); - if (autodiff) { - auto gradient_setter = ValueAsMetadata::get( - tutils->interface->insertArgumentGradient(Builder)); - auto gradient_setter_node = - MDNode::get(F.getContext(), {gradient_setter}); - - call->setMetadata("enzyme_gradient_setter", gradient_setter_node); - } - } -} - -void TraceGenerator::handleObserveCall(CallInst &call, CallInst *new_call) { - IRBuilder<> Builder(new_call); - - SmallVector Args( - make_range(new_call->arg_begin() + 2, new_call->arg_end())); - - Value *observed = new_call->getArgOperand(0); - Function *likelihoodfn = GetFunctionFromValue(new_call->getArgOperand(1)); - Value *address = new_call->getArgOperand(2); - - StringRef const_address; - bool is_address_const = getConstantStringInfo(address, const_address); - bool is_random_var_active = - activeRandomVariables.empty() || - (is_address_const && activeRandomVariables.count(const_address)); - Attribute activity_attribute = Attribute::get( - call.getContext(), - is_random_var_active ? "enzyme_active" : "enzyme_inactive_val"); - - // calculate and accumulate log likelihood - Args.push_back(observed); - - auto score = Builder.CreateCall(likelihoodfn->getFunctionType(), likelihoodfn, - ArrayRef(Args).slice(1), - "likelihood." + call.getName()); - - score->addAttributeAtIndex(AttributeList::FunctionIndex, activity_attribute); - - auto log_prob_sum = Builder.CreateLoad( - Builder.getDoubleTy(), tutils->getLikelihood(), "log_prob_sum"); - auto acc = Builder.CreateFAdd(log_prob_sum, score); - Builder.CreateStore(acc, tutils->getLikelihood()); - - // create outlined trace function - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { - Value *trace_args[] = {address, score, observed}; - - auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder, - TraceUtils *OutlineTutils, - ArrayRef Arguments) { - OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1], - Arguments[2]); - OutlineBuilder.CreateRetVoid(); - }; - - auto trace_call = tutils->CreateOutlinedFunction( - Builder, OutlinedTrace, Builder.getVoidTy(), trace_args, false, - "outline_insert_choice"); - - trace_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_inactive")); - trace_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_notypeanalysis")); - } - - if (!call.getType()->isVoidTy()) { - observed->takeName(new_call); - new_call->replaceAllUsesWith(observed); - } - new_call->eraseFromParent(); -} - -void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) { - // create outlined sample function - SmallVector Args( - make_range(new_call->arg_begin() + 2, new_call->arg_end())); - - Function *samplefn = GetFunctionFromValue(new_call->getArgOperand(0)); - Function *likelihoodfn = GetFunctionFromValue(new_call->getArgOperand(1)); - Value *address = new_call->getArgOperand(2); - - IRBuilder<> Builder(new_call); - - auto OutlinedSample = [samplefn](IRBuilder<> &OutlineBuilder, - TraceUtils *OutlineTutils, - ArrayRef Arguments) { - auto choice = OutlineTutils->SampleOrCondition( - OutlineBuilder, samplefn, 
Arguments.slice(1), Arguments[0], - samplefn->getName()); - OutlineBuilder.CreateRet(choice); - }; - - const char *mode_str; - switch (mode) { - case ProbProgMode::Likelihood: - case ProbProgMode::Trace: - mode_str = "sample"; - break; - case ProbProgMode::Condition: - mode_str = "condition"; - break; - } - - auto sample_call = tutils->CreateOutlinedFunction( - Builder, OutlinedSample, samplefn->getReturnType(), Args, false, - Twine(mode_str) + "_" + samplefn->getName()); - - StringRef const_address; - bool is_address_const = getConstantStringInfo(address, const_address); - bool is_random_var_active = - activeRandomVariables.empty() || - (is_address_const && activeRandomVariables.count(const_address)); - Attribute activity_attribute = Attribute::get( - call.getContext(), - is_random_var_active ? "enzyme_active" : "enzyme_inactive_val"); - - sample_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_sample")); - sample_call->addAttributeAtIndex(AttributeList::FunctionIndex, - activity_attribute); - - if (autodiff && - (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition)) { - auto gradient_setter = - ValueAsMetadata::get(tutils->interface->insertChoiceGradient(Builder)); - auto gradient_setter_node = - MDNode::get(call.getContext(), {gradient_setter}); - - sample_call->setMetadata("enzyme_gradient_setter", gradient_setter_node); - } - - // calculate and accumulate log likelihood - Args.push_back(sample_call); - - auto score = Builder.CreateCall(likelihoodfn->getFunctionType(), likelihoodfn, - ArrayRef(Args).slice(1), - "likelihood." + call.getName()); - - score->addAttributeAtIndex(AttributeList::FunctionIndex, activity_attribute); - - auto log_prob_sum = Builder.CreateLoad( - Builder.getDoubleTy(), tutils->getLikelihood(), "log_prob_sum"); - auto acc = Builder.CreateFAdd(log_prob_sum, score); - Builder.CreateStore(acc, tutils->getLikelihood()); - - // create outlined trace function - - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { - Value *trace_args[] = {address, score, sample_call}; - - auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder, - TraceUtils *OutlineTutils, - ArrayRef Arguments) { - OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1], - Arguments[2]); - OutlineBuilder.CreateRetVoid(); - }; - - auto trace_call = tutils->CreateOutlinedFunction( - Builder, OutlinedTrace, Builder.getVoidTy(), trace_args, false, - "outline_insert_choice"); - - trace_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_inactive")); - trace_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_notypeanalysis")); - } - - sample_call->takeName(new_call); - new_call->replaceAllUsesWith(sample_call); - new_call->eraseFromParent(); -} - -void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) { - IRBuilder<> Builder(new_call); - - SmallVector args; - for (auto it = new_call->arg_begin(); it != new_call->arg_end(); it++) { - args.push_back(*it); - } - - Function *called = getFunctionFromCall(&call); - assert(called); - - Function *samplefn = Logic.CreateTrace( - RequestContext(&call, &Builder), called, tutils->sampleFunctions, - tutils->observeFunctions, activeRandomVariables, mode, autodiff, - tutils->interface); - - Instruction *replacement; - switch (mode) { - case ProbProgMode::Likelihood: { - SmallVector args_and_likelihood(args); - 
args_and_likelihood.push_back(tutils->getLikelihood()); - replacement = - Builder.CreateCall(samplefn->getFunctionType(), samplefn, - args_and_likelihood, "eval." + called->getName()); - break; - } - case ProbProgMode::Trace: { - auto trace = tutils->CreateTrace(Builder); -#if LLVM_VERSION_MAJOR >= 17 - auto address = Builder.CreateGlobalString( - (call.getName() + "." + called->getName()).str()); -#else - auto address = Builder.CreateGlobalStringPtr( - (call.getName() + "." + called->getName()).str()); -#endif - - SmallVector args_and_trace(args); - args_and_trace.push_back(tutils->getLikelihood()); - args_and_trace.push_back(trace); - replacement = - Builder.CreateCall(samplefn->getFunctionType(), samplefn, - args_and_trace, "trace." + called->getName()); - - tutils->InsertCall(Builder, address, trace); - break; - } - case ProbProgMode::Condition: { - auto trace = tutils->CreateTrace(Builder); -#if LLVM_VERSION_MAJOR >= 17 - auto address = Builder.CreateGlobalString( - (call.getName() + "." + called->getName()).str()); -#else - auto address = Builder.CreateGlobalStringPtr( - (call.getName() + "." + called->getName()).str()); -#endif - - Instruction *hasCall = - tutils->HasCall(Builder, address, "has.call." + call.getName()); - Instruction *ThenTerm, *ElseTerm; - Value *ElseTracecall, *ThenTracecall; - SplitBlockAndInsertIfThenElse(hasCall, new_call, &ThenTerm, &ElseTerm); - - new_call->getParent()->setName(hasCall->getParent()->getName() + ".cntd"); - - Builder.SetInsertPoint(ThenTerm); - { - ThenTerm->getParent()->setName("condition." + call.getName() + - ".with.trace"); - SmallVector args_and_cond(args); - auto observations = - tutils->GetTrace(Builder, address, called->getName() + ".subtrace"); - args_and_cond.push_back(tutils->getLikelihood()); - args_and_cond.push_back(observations); - args_and_cond.push_back(trace); - ThenTracecall = - Builder.CreateCall(samplefn->getFunctionType(), samplefn, - args_and_cond, "condition." + called->getName()); - } - - Builder.SetInsertPoint(ElseTerm); - { - ElseTerm->getParent()->setName("condition." + call.getName() + - ".without.trace"); - SmallVector args_and_null(args); - auto observations = ConstantPointerNull::get(cast( - tutils->getTraceInterface()->newTraceTy()->getReturnType())); - args_and_null.push_back(tutils->getLikelihood()); - args_and_null.push_back(observations); - args_and_null.push_back(trace); - ElseTracecall = - Builder.CreateCall(samplefn->getFunctionType(), samplefn, - args_and_null, "trace." 
+ called->getName()); - } - - Builder.SetInsertPoint(new_call); - auto phi = Builder.CreatePHI(samplefn->getFunctionType()->getReturnType(), - 2, call.getName()); - phi->addIncoming(ThenTracecall, ThenTerm->getParent()); - phi->addIncoming(ElseTracecall, ElseTerm->getParent()); - replacement = phi; - - tutils->InsertCall(Builder, address, trace); - break; - } - } - - replacement->takeName(new_call); - new_call->replaceAllUsesWith(replacement); - new_call->eraseFromParent(); -} - -void TraceGenerator::visitCallInst(CallInst &call) { - auto fn = getFunctionFromCall(&call); - - if (!generativeFunctions.count(fn)) - return; - - CallInst *new_call = dyn_cast(originalToNewFn[&call]); - - if (tutils->isSampleCall(&call)) { - handleSampleCall(call, new_call); - } else if (tutils->isObserveCall(&call)) { - handleObserveCall(call, new_call); - } else { - handleArbitraryCall(call, new_call); - } -} - -void TraceGenerator::visitReturnInst(ReturnInst &ret) { - - if (!ret.getReturnValue()) - return; - - ReturnInst *new_ret = dyn_cast(originalToNewFn[&ret]); - - IRBuilder<> Builder(new_ret); - tutils->InsertReturn(Builder, new_ret->getReturnValue()); -} diff --git a/enzyme/Enzyme/TraceGenerator.h b/enzyme/Enzyme/TraceGenerator.h deleted file mode 100644 index 678da82f9a3f..000000000000 --- a/enzyme/Enzyme/TraceGenerator.h +++ /dev/null @@ -1,67 +0,0 @@ -//===- TraceGenerator.h - Trace sample statements and calls --------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains an instruction visitor that generates probabilistic -// programming traces for call sites and sample statements. 
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef TraceGenerator_h
-#define TraceGenerator_h
-
-#include "llvm/IR/InstVisitor.h"
-#include "llvm/IR/Instructions.h"
-
-#include "EnzymeLogic.h"
-#include "TraceUtils.h"
-#include "Utils.h"
-
-class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
-private:
-  EnzymeLogic &Logic;
-  TraceUtils *const tutils;
-  ProbProgMode mode = tutils->mode;
-  bool autodiff;
-  llvm::ValueMap<const llvm::Value *, llvm::WeakTrackingVH> &originalToNewFn;
-  const llvm::SmallPtrSetImpl<llvm::Function *> &generativeFunctions;
-  const llvm::StringSet<> &activeRandomVariables;
-
-public:
-  TraceGenerator(
-      EnzymeLogic &Logic, TraceUtils *tutils, bool autodiff,
-      llvm::ValueMap<const llvm::Value *, llvm::WeakTrackingVH>
-          &originalToNewFn,
-      const llvm::SmallPtrSetImpl<llvm::Function *> &generativeFunctions,
-      const llvm::StringSet<> &activeRandomVariables);
-
-  void visitFunction(llvm::Function &F);
-
-  void handleSampleCall(llvm::CallInst &call, llvm::CallInst *new_call);
-
-  void handleObserveCall(llvm::CallInst &call, llvm::CallInst *new_call);
-
-  void handleArbitraryCall(llvm::CallInst &call, llvm::CallInst *new_call);
-
-  void visitCallInst(llvm::CallInst &call);
-
-  void visitReturnInst(llvm::ReturnInst &ret);
-};
-
-#endif /* TraceGenerator_h */
diff --git a/enzyme/Enzyme/TraceInterface.cpp b/enzyme/Enzyme/TraceInterface.cpp
deleted file mode 100644
index 9d6d6e4563a2..000000000000
--- a/enzyme/Enzyme/TraceInterface.cpp
+++ /dev/null
@@ -1,449 +0,0 @@
-//===- TraceInterface.h - Interact with probabilistic programming traces
-//---===//
-//
-// Enzyme Project
-//
-// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
-// Exceptions. See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-// If using this code in an academic setting, please cite the following:
-// @incollection{enzymeNeurips,
-// title = {Instead of Rewriting Foreign Code for Machine Learning,
-// Automatically Synthesize Fast Gradients},
-// author = {Moses, William S. and Churavy, Valentin},
-// booktitle = {Advances in Neural Information Processing Systems 33},
-// year = {2020},
-// note = {To appear in},
-// }
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains an abstraction for static and dynamic implementations of
-// the probabilistic programming interface.
-// -//===----------------------------------------------------------------------===// - -#include "TraceInterface.h" - -#include "Utils.h" - -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -using namespace llvm; - -TraceInterface::TraceInterface(LLVMContext &C) : C(C){}; - -PointerType *traceType(LLVMContext &C) { - return getDefaultAnonymousTapeType(C); -} - -Type *addressType(LLVMContext &C) { return getInt8PtrTy(C); } - -IntegerType *TraceInterface::sizeType(LLVMContext &C) { - return IntegerType::getInt64Ty(C); -} - -Type *TraceInterface::stringType(LLVMContext &C) { return getInt8PtrTy(C); } - -FunctionType *TraceInterface::getTraceTy() { return getTraceTy(C); } -FunctionType *TraceInterface::getChoiceTy() { return getChoiceTy(C); } -FunctionType *TraceInterface::insertCallTy() { return insertCallTy(C); } -FunctionType *TraceInterface::insertChoiceTy() { return insertChoiceTy(C); } -FunctionType *TraceInterface::insertArgumentTy() { return insertArgumentTy(C); } -FunctionType *TraceInterface::insertReturnTy() { return insertReturnTy(C); } -FunctionType *TraceInterface::insertFunctionTy() { return insertFunctionTy(C); } -FunctionType *TraceInterface::insertChoiceGradientTy() { - return insertChoiceGradientTy(C); -} -FunctionType *TraceInterface::insertArgumentGradientTy() { - return insertArgumentGradientTy(C); -} -FunctionType *TraceInterface::newTraceTy() { return newTraceTy(C); } -FunctionType *TraceInterface::freeTraceTy() { return freeTraceTy(C); } -FunctionType *TraceInterface::hasCallTy() { return hasCallTy(C); } -FunctionType *TraceInterface::hasChoiceTy() { return hasChoiceTy(C); } - -FunctionType *TraceInterface::getTraceTy(LLVMContext &C) { - return FunctionType::get(traceType(C), {traceType(C), stringType(C)}, false); -} - -FunctionType *TraceInterface::getChoiceTy(LLVMContext &C) { - return FunctionType::get( - sizeType(C), {traceType(C), stringType(C), addressType(C), sizeType(C)}, - false); -} - -FunctionType *TraceInterface::insertCallTy(LLVMContext &C) { - return FunctionType::get(Type::getVoidTy(C), - {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C)}, - false); -} - -FunctionType *TraceInterface::insertChoiceTy(LLVMContext &C) { - return FunctionType::get(Type::getVoidTy(C), - {getInt8PtrTy(C), stringType(C), - Type::getDoubleTy(C), getInt8PtrTy(C), sizeType(C)}, - false); -} - -FunctionType *TraceInterface::insertArgumentTy(LLVMContext &C) { - return FunctionType::get( - Type::getVoidTy(C), - {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C), sizeType(C)}, false); -} - -FunctionType *TraceInterface::insertReturnTy(LLVMContext &C) { - return FunctionType::get(Type::getVoidTy(C), - {getInt8PtrTy(C), getInt8PtrTy(C), sizeType(C)}, - false); -} - -FunctionType *TraceInterface::insertFunctionTy(LLVMContext &C) { - return FunctionType::get(Type::getVoidTy(C), - {getInt8PtrTy(C), getInt8PtrTy(C)}, false); -} - -FunctionType *TraceInterface::insertChoiceGradientTy(LLVMContext &C) { - return FunctionType::get( - Type::getVoidTy(C), - {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C), sizeType(C)}, false); -} - -FunctionType *TraceInterface::insertArgumentGradientTy(LLVMContext &C) { - return FunctionType::get( - Type::getVoidTy(C), - {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C), sizeType(C)}, false); -} - -FunctionType *TraceInterface::newTraceTy(LLVMContext &C) { - return FunctionType::get(getInt8PtrTy(C), {}, false); -} - -FunctionType 
*TraceInterface::freeTraceTy(LLVMContext &C) { - return FunctionType::get(Type::getVoidTy(C), {getInt8PtrTy(C)}, false); -} - -FunctionType *TraceInterface::hasCallTy(LLVMContext &C) { - return FunctionType::get(Type::getInt1Ty(C), {getInt8PtrTy(C), stringType(C)}, - false); -} - -FunctionType *TraceInterface::hasChoiceTy(LLVMContext &C) { - return FunctionType::get(Type::getInt1Ty(C), {getInt8PtrTy(C), stringType(C)}, - false); -} - -StaticTraceInterface::StaticTraceInterface(Module *M) - : TraceInterface(M->getContext()) { - for (auto &&F : M->functions()) { - if (F.isIntrinsic()) - continue; - if (F.getName().contains("__enzyme_newtrace")) { - assert(F.getFunctionType() == newTraceTy()); - newTraceFunction = &F; - } else if (F.getName().contains("__enzyme_freetrace")) { - assert(F.getFunctionType() == freeTraceTy()); - freeTraceFunction = &F; - } else if (F.getName().contains("__enzyme_get_trace")) { - assert(F.getFunctionType() == getTraceTy()); - getTraceFunction = &F; - } else if (F.getName().contains("__enzyme_get_choice")) { - assert(F.getFunctionType() == getChoiceTy()); - getChoiceFunction = &F; - } else if (F.getName().contains("__enzyme_insert_call")) { - assert(F.getFunctionType() == insertCallTy()); - insertCallFunction = &F; - } else if (F.getName().contains("__enzyme_insert_choice")) { - assert(F.getFunctionType() == insertChoiceTy()); - insertChoiceFunction = &F; - } else if (F.getName().contains("__enzyme_insert_argument")) { - assert(F.getFunctionType() == insertArgumentTy()); - insertArgumentFunction = &F; - } else if (F.getName().contains("__enzyme_insert_return")) { - assert(F.getFunctionType() == insertReturnTy()); - insertReturnFunction = &F; - } else if (F.getName().contains("__enzyme_insert_function")) { - assert(F.getFunctionType() == insertFunctionTy()); - insertFunctionFunction = &F; - } else if (F.getName().contains("__enzyme_insert_gradient_choice")) { - assert(F.getFunctionType() == insertChoiceGradientTy()); - insertChoiceGradientFunction = &F; - } else if (F.getName().contains("__enzyme_insert_gradient_argument")) { - assert(F.getFunctionType() == insertArgumentGradientTy()); - insertArgumentGradientFunction = &F; - } else if (F.getName().contains("__enzyme_has_call")) { - assert(F.getFunctionType() == hasCallTy()); - hasCallFunction = &F; - } else if (F.getName().contains("__enzyme_has_choice")) { - assert(F.getFunctionType() == hasChoiceTy()); - hasChoiceFunction = &F; - } - } - - assert(newTraceFunction); - assert(freeTraceFunction); - assert(getTraceFunction); - assert(getChoiceFunction); - assert(insertCallFunction); - assert(insertChoiceFunction); - - assert(insertArgumentFunction); - assert(insertReturnFunction); - assert(insertFunctionFunction); - - assert(insertChoiceGradientFunction); - assert(insertArgumentGradientFunction); - - assert(hasCallFunction); - assert(hasChoiceFunction); - - newTraceFunction->addFnAttr("enzyme_notypeanalysis"); - freeTraceFunction->addFnAttr("enzyme_notypeanalysis"); - getTraceFunction->addFnAttr("enzyme_notypeanalysis"); - getChoiceFunction->addFnAttr("enzyme_notypeanalysis"); - insertCallFunction->addFnAttr("enzyme_notypeanalysis"); - insertChoiceFunction->addFnAttr("enzyme_notypeanalysis"); - insertArgumentFunction->addFnAttr("enzyme_notypeanalysis"); - insertReturnFunction->addFnAttr("enzyme_notypeanalysis"); - insertFunctionFunction->addFnAttr("enzyme_notypeanalysis"); - insertChoiceGradientFunction->addFnAttr("enzyme_notypeanalysis"); - insertArgumentGradientFunction->addFnAttr("enzyme_notypeanalysis"); - 
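Illustrative aside (not from the original patch): StaticTraceInterface resolves these entry points by name from the user's module and asserts that their types match the FunctionType builders earlier in this file. A hypothetical stub runtime covering a subset of them is sketched below; the symbol names come from the name checks in the constructor above, while the parameter names and bodies are the editor's placeholders (i8* is rendered as void* and the 64-bit size type as uint64_t).

#include <cstdint>
#include <cstdio>

extern "C" {
void *__enzyme_newtrace() {
  std::puts("newtrace");
  return nullptr; // a real runtime would allocate and return a trace object
}
void __enzyme_freetrace(void *trace) { std::printf("freetrace %p\n", trace); }
void __enzyme_insert_choice(void *trace, const char *name, double score,
                            void *choice, uint64_t size) {
  std::printf("choice %s (score %f, %llu bytes)\n", name, score,
              (unsigned long long)size);
}
bool __enzyme_has_call(void *trace, const char *name) {
  return false; // placeholder: report that no subtrace exists for this name
}
}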
hasCallFunction->addFnAttr("enzyme_notypeanalysis"); - hasChoiceFunction->addFnAttr("enzyme_notypeanalysis"); - - newTraceFunction->addFnAttr("enzyme_inactive"); - freeTraceFunction->addFnAttr("enzyme_inactive"); - getTraceFunction->addFnAttr("enzyme_inactive"); - getChoiceFunction->addFnAttr("enzyme_inactive"); - insertCallFunction->addFnAttr("enzyme_inactive"); - insertChoiceFunction->addFnAttr("enzyme_inactive"); - insertArgumentFunction->addFnAttr("enzyme_inactive"); - insertReturnFunction->addFnAttr("enzyme_inactive"); - insertFunctionFunction->addFnAttr("enzyme_inactive"); - insertChoiceGradientFunction->addFnAttr("enzyme_inactive"); - insertArgumentGradientFunction->addFnAttr("enzyme_inactive"); - hasCallFunction->addFnAttr("enzyme_inactive"); - hasChoiceFunction->addFnAttr("enzyme_inactive"); - - newTraceFunction->addFnAttr(Attribute::NoFree); - getTraceFunction->addFnAttr(Attribute::NoFree); - getChoiceFunction->addFnAttr(Attribute::NoFree); - insertCallFunction->addFnAttr(Attribute::NoFree); - insertChoiceFunction->addFnAttr(Attribute::NoFree); - insertArgumentFunction->addFnAttr(Attribute::NoFree); - insertReturnFunction->addFnAttr(Attribute::NoFree); - insertFunctionFunction->addFnAttr(Attribute::NoFree); - insertChoiceGradientFunction->addFnAttr(Attribute::NoFree); - insertArgumentGradientFunction->addFnAttr(Attribute::NoFree); - hasCallFunction->addFnAttr(Attribute::NoFree); - hasChoiceFunction->addFnAttr(Attribute::NoFree); -} - -StaticTraceInterface::StaticTraceInterface( - LLVMContext &C, Function *getTraceFunction, Function *getChoiceFunction, - Function *insertCallFunction, Function *insertChoiceFunction, - Function *insertArgumentFunction, Function *insertReturnFunction, - Function *insertFunctionFunction, Function *insertChoiceGradientFunction, - Function *insertArgumentGradientFunction, Function *newTraceFunction, - Function *freeTraceFunction, Function *hasCallFunction, - Function *hasChoiceFunction) - : TraceInterface(C), getTraceFunction(getTraceFunction), - getChoiceFunction(getChoiceFunction), - insertCallFunction(insertCallFunction), - insertChoiceFunction(insertChoiceFunction), - insertArgumentFunction(insertArgumentFunction), - insertReturnFunction(insertReturnFunction), - insertFunctionFunction(insertFunctionFunction), - insertChoiceGradientFunction(insertChoiceGradientFunction), - insertArgumentGradientFunction(insertArgumentGradientFunction), - newTraceFunction(newTraceFunction), freeTraceFunction(freeTraceFunction), - hasCallFunction(hasCallFunction), hasChoiceFunction(hasChoiceFunction){}; - -// user implemented -Value *StaticTraceInterface::getTrace(IRBuilder<> &Builder) { - return getTraceFunction; -} -Value *StaticTraceInterface::getChoice(IRBuilder<> &Builder) { - return getChoiceFunction; -} -Value *StaticTraceInterface::insertCall(IRBuilder<> &Builder) { - return insertCallFunction; -} -Value *StaticTraceInterface::insertChoice(IRBuilder<> &Builder) { - return insertChoiceFunction; -} -Value *StaticTraceInterface::insertArgument(IRBuilder<> &Builder) { - return insertArgumentFunction; -} -Value *StaticTraceInterface::insertReturn(IRBuilder<> &Builder) { - return insertReturnFunction; -} -Value *StaticTraceInterface::insertFunction(IRBuilder<> &Builder) { - return insertFunctionFunction; -} -Value *StaticTraceInterface::insertChoiceGradient(IRBuilder<> &Builder) { - return insertChoiceGradientFunction; -} -Value *StaticTraceInterface::insertArgumentGradient(IRBuilder<> &Builder) { - return insertArgumentGradientFunction; -} -Value 
*StaticTraceInterface::newTrace(IRBuilder<> &Builder) { - return newTraceFunction; -} -Value *StaticTraceInterface::freeTrace(IRBuilder<> &Builder) { - return freeTraceFunction; -} -Value *StaticTraceInterface::hasCall(IRBuilder<> &Builder) { - return hasCallFunction; -} -Value *StaticTraceInterface::hasChoice(IRBuilder<> &Builder) { - return hasChoiceFunction; -} - -DynamicTraceInterface::DynamicTraceInterface(Value *dynamicInterface, - Function *F) - : TraceInterface(F->getContext()) { - assert(dynamicInterface); - - auto &M = *F->getParent(); - IRBuilder<> Builder(getFirstNonPHIOrDbg(&F->getEntryBlock())); - - getTraceFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, getTraceTy(), 0, M, "get_trace"); - getChoiceFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, getChoiceTy(), 1, M, "get_choice"); - insertCallFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertCallTy(), 2, M, "insert_call"); - insertChoiceFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertChoiceTy(), 3, M, "insert_choice"); - insertArgumentFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertArgumentTy(), 4, M, "insert_argument"); - insertReturnFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertReturnTy(), 5, M, "insert_return"); - insertFunctionFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertFunctionTy(), 6, M, "insert_function"); - insertChoiceGradientFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertChoiceGradientTy(), 7, M, - "insert_choice_gradient"); - insertArgumentGradientFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, insertArgumentGradientTy(), 8, M, - "insert_argument_gradient"); - newTraceFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, newTraceTy(), 9, M, "new_trace"); - freeTraceFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, freeTraceTy(), 10, M, "free_trace"); - hasCallFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, hasCallTy(), 11, M, "has_call"); - hasChoiceFunction = MaterializeInterfaceFunction( - Builder, dynamicInterface, hasChoiceTy(), 12, M, "has_choice"); - - assert(newTraceFunction); - assert(freeTraceFunction); - assert(getTraceFunction); - assert(getChoiceFunction); - assert(insertCallFunction); - assert(insertChoiceFunction); - - assert(insertArgumentFunction); - assert(insertReturnFunction); - assert(insertFunctionFunction); - - assert(insertChoiceGradientFunction); - assert(insertArgumentGradientFunction); - - assert(hasCallFunction); - assert(hasChoiceFunction); -} - -Function *DynamicTraceInterface::MaterializeInterfaceFunction( - IRBuilder<> &Builder, Value *dynamicInterface, FunctionType *FTy, - unsigned index, Module &M, const Twine &Name) { - auto ptr = - Builder.CreateInBoundsGEP(getInt8PtrTy(dynamicInterface->getContext()), - dynamicInterface, Builder.getInt32(index)); - auto load = - Builder.CreateLoad(getInt8PtrTy(dynamicInterface->getContext()), ptr); - auto pty = PointerType::get(FTy, load->getPointerAddressSpace()); - auto cast = Builder.CreatePointerCast(load, pty); - - auto global = - new GlobalVariable(M, pty, false, GlobalVariable::PrivateLinkage, - ConstantPointerNull::get(pty), Name + "_ptr"); - Builder.CreateStore(cast, global); - - Function *F = Function::Create(FTy, Function::PrivateLinkage, Name, M); - F->addFnAttr(Attribute::AlwaysInline); - BasicBlock *Entry = BasicBlock::Create(M.getContext(), 
"entry", F); - - IRBuilder<> WrapperBuilder(Entry); - - auto ToWrap = WrapperBuilder.CreateLoad(pty, global, Name); - auto Args = SmallVector(make_pointer_range(F->args())); - auto Call = WrapperBuilder.CreateCall(FTy, ToWrap, Args); - - if (!FTy->getReturnType()->isVoidTy()) { - WrapperBuilder.CreateRet(Call); - } else { - WrapperBuilder.CreateRetVoid(); - } - - return F; -} - -// user implemented -Value *DynamicTraceInterface::getTrace(IRBuilder<> &Builder) { - return getTraceFunction; -} - -Value *DynamicTraceInterface::getChoice(IRBuilder<> &Builder) { - return getChoiceFunction; -} - -Value *DynamicTraceInterface::insertCall(IRBuilder<> &Builder) { - return insertCallFunction; -} - -Value *DynamicTraceInterface::insertChoice(IRBuilder<> &Builder) { - return insertChoiceFunction; -} - -Value *DynamicTraceInterface::insertArgument(IRBuilder<> &Builder) { - return insertArgumentFunction; -} - -Value *DynamicTraceInterface::insertReturn(IRBuilder<> &Builder) { - return insertReturnFunction; -} - -Value *DynamicTraceInterface::insertFunction(IRBuilder<> &Builder) { - return insertFunctionFunction; -} - -Value *DynamicTraceInterface::insertChoiceGradient(IRBuilder<> &Builder) { - return insertChoiceGradientFunction; -} - -Value *DynamicTraceInterface::insertArgumentGradient(IRBuilder<> &Builder) { - return insertArgumentGradientFunction; -} - -Value *DynamicTraceInterface::newTrace(IRBuilder<> &Builder) { - return newTraceFunction; -} - -Value *DynamicTraceInterface::freeTrace(IRBuilder<> &Builder) { - return freeTraceFunction; -} - -Value *DynamicTraceInterface::hasCall(IRBuilder<> &Builder) { - return hasCallFunction; -} - -Value *DynamicTraceInterface::hasChoice(IRBuilder<> &Builder) { - return hasChoiceFunction; -} diff --git a/enzyme/Enzyme/TraceInterface.h b/enzyme/Enzyme/TraceInterface.h deleted file mode 100644 index eeca7d1abc35..000000000000 --- a/enzyme/Enzyme/TraceInterface.h +++ /dev/null @@ -1,197 +0,0 @@ -//===- TraceInterface.h - Interact with probabilistic programming traces -//---===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains an abstraction for static and dynamic implementations of -// the probabilistic programming interface. 
-// -//===----------------------------------------------------------------------===//---------------------------------------------------------------------===// - -#ifndef TraceInterface_h -#define TraceInterface_h - -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -class TraceInterface { -private: - llvm::LLVMContext &C; - -public: - TraceInterface(llvm::LLVMContext &C); - - virtual ~TraceInterface() = default; - -public: - // user implemented - virtual llvm::Value *getTrace(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *getChoice(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *insertCall(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *insertChoice(llvm::IRBuilder<> &Builder) = 0; - - virtual llvm::Value *insertArgument(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *insertReturn(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *insertFunction(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *insertChoiceGradient(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *insertArgumentGradient(llvm::IRBuilder<> &Builder) = 0; - - virtual llvm::Value *newTrace(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *freeTrace(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *hasCall(llvm::IRBuilder<> &Builder) = 0; - virtual llvm::Value *hasChoice(llvm::IRBuilder<> &Builder) = 0; - -public: - static llvm::IntegerType *sizeType(llvm::LLVMContext &C); - static llvm::Type *stringType(llvm::LLVMContext &C); - -public: - llvm::FunctionType *getTraceTy(); - llvm::FunctionType *getChoiceTy(); - llvm::FunctionType *insertCallTy(); - llvm::FunctionType *insertChoiceTy(); - - llvm::FunctionType *insertArgumentTy(); - llvm::FunctionType *insertReturnTy(); - llvm::FunctionType *insertFunctionTy(); - llvm::FunctionType *insertChoiceGradientTy(); - llvm::FunctionType *insertArgumentGradientTy(); - - llvm::FunctionType *newTraceTy(); - llvm::FunctionType *freeTraceTy(); - llvm::FunctionType *hasCallTy(); - llvm::FunctionType *hasChoiceTy(); - - static llvm::FunctionType *getTraceTy(llvm::LLVMContext &C); - static llvm::FunctionType *getChoiceTy(llvm::LLVMContext &C); - static llvm::FunctionType *insertCallTy(llvm::LLVMContext &C); - static llvm::FunctionType *insertChoiceTy(llvm::LLVMContext &C); - - static llvm::FunctionType *insertArgumentTy(llvm::LLVMContext &C); - static llvm::FunctionType *insertReturnTy(llvm::LLVMContext &C); - static llvm::FunctionType *insertFunctionTy(llvm::LLVMContext &C); - static llvm::FunctionType *insertChoiceGradientTy(llvm::LLVMContext &C); - static llvm::FunctionType *insertArgumentGradientTy(llvm::LLVMContext &C); - - static llvm::FunctionType *newTraceTy(llvm::LLVMContext &C); - static llvm::FunctionType *freeTraceTy(llvm::LLVMContext &C); - static llvm::FunctionType *hasCallTy(llvm::LLVMContext &C); - static llvm::FunctionType *hasChoiceTy(llvm::LLVMContext &C); -}; - -class StaticTraceInterface final : public TraceInterface { -private: - llvm::Function *getTraceFunction = nullptr; - llvm::Function *getChoiceFunction = nullptr; - llvm::Function *insertCallFunction = nullptr; - llvm::Function *insertChoiceFunction = nullptr; - llvm::Function *insertArgumentFunction = nullptr; - llvm::Function *insertReturnFunction = nullptr; - llvm::Function *insertFunctionFunction = nullptr; - llvm::Function *insertChoiceGradientFunction = nullptr; - llvm::Function *insertArgumentGradientFunction = nullptr; - llvm::Function *newTraceFunction = nullptr; - llvm::Function 
*freeTraceFunction = nullptr; - llvm::Function *hasCallFunction = nullptr; - llvm::Function *hasChoiceFunction = nullptr; - -public: - StaticTraceInterface(llvm::Module *M); - - StaticTraceInterface(llvm::LLVMContext &C, llvm::Function *getTraceFunction, - llvm::Function *getChoiceFunction, - llvm::Function *insertCallFunction, - llvm::Function *insertChoiceFunction, - llvm::Function *insertArgumentFunction, - llvm::Function *insertReturnFunction, - llvm::Function *insertFunctionFunction, - llvm::Function *insertChoiceGradientFunction, - llvm::Function *insertArgumentGradientFunction, - llvm::Function *newTraceFunction, - llvm::Function *freeTraceFunction, - llvm::Function *hasCallFunction, - llvm::Function *hasChoiceFunction); - - ~StaticTraceInterface() = default; - -public: - // user implemented - llvm::Value *getTrace(llvm::IRBuilder<> &Builder); - llvm::Value *getChoice(llvm::IRBuilder<> &Builder); - llvm::Value *insertCall(llvm::IRBuilder<> &Builder); - llvm::Value *insertChoice(llvm::IRBuilder<> &Builder); - llvm::Value *insertArgument(llvm::IRBuilder<> &Builder); - llvm::Value *insertReturn(llvm::IRBuilder<> &Builder); - llvm::Value *insertFunction(llvm::IRBuilder<> &Builder); - llvm::Value *insertChoiceGradient(llvm::IRBuilder<> &Builder); - llvm::Value *insertArgumentGradient(llvm::IRBuilder<> &Builder); - llvm::Value *newTrace(llvm::IRBuilder<> &Builder); - llvm::Value *freeTrace(llvm::IRBuilder<> &Builder); - llvm::Value *hasCall(llvm::IRBuilder<> &Builder); - llvm::Value *hasChoice(llvm::IRBuilder<> &Builder); -}; - -class DynamicTraceInterface final : public TraceInterface { -private: - llvm::Function *getTraceFunction; - llvm::Function *getChoiceFunction; - llvm::Function *insertCallFunction; - llvm::Function *insertChoiceFunction; - llvm::Function *insertArgumentFunction; - llvm::Function *insertReturnFunction; - llvm::Function *insertFunctionFunction; - llvm::Function *insertChoiceGradientFunction; - llvm::Function *insertArgumentGradientFunction; - llvm::Function *newTraceFunction; - llvm::Function *freeTraceFunction; - llvm::Function *hasCallFunction; - llvm::Function *hasChoiceFunction; - -public: - DynamicTraceInterface(llvm::Value *dynamicInterface, llvm::Function *F); - - ~DynamicTraceInterface() = default; - -private: - llvm::Function *MaterializeInterfaceFunction(llvm::IRBuilder<> &Builder, - llvm::Value *, - llvm::FunctionType *, - unsigned index, llvm::Module &M, - const llvm::Twine &Name = ""); - -public: - // user implemented - llvm::Value *getTrace(llvm::IRBuilder<> &Builder); - llvm::Value *getChoice(llvm::IRBuilder<> &Builder); - llvm::Value *insertCall(llvm::IRBuilder<> &Builder); - llvm::Value *insertChoice(llvm::IRBuilder<> &Builder); - llvm::Value *insertArgument(llvm::IRBuilder<> &Builder); - llvm::Value *insertReturn(llvm::IRBuilder<> &Builder); - llvm::Value *insertFunction(llvm::IRBuilder<> &Builder); - llvm::Value *insertChoiceGradient(llvm::IRBuilder<> &Builder); - llvm::Value *insertArgumentGradient(llvm::IRBuilder<> &Builder); - llvm::Value *newTrace(llvm::IRBuilder<> &Builder); - llvm::Value *freeTrace(llvm::IRBuilder<> &Builder); - llvm::Value *hasCall(llvm::IRBuilder<> &Builder); - llvm::Value *hasChoice(llvm::IRBuilder<> &Builder); -}; - -#endif diff --git a/enzyme/Enzyme/TraceUtils.cpp b/enzyme/Enzyme/TraceUtils.cpp deleted file mode 100644 index 9040e37f6eea..000000000000 --- a/enzyme/Enzyme/TraceUtils.cpp +++ /dev/null @@ -1,526 +0,0 @@ -//===- TraceUtils.cpp - Utilites for interacting with traces ------------===// -// -// Enzyme Project 
-// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains utilities for interacting with probabilistic programming -// traces using the probabilistic programming -// trace interface -// -//===----------------------------------------------------------------------===// - -#include "TraceUtils.h" - -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/User.h" -#include "llvm/IR/Value.h" -#include "llvm/IR/ValueMap.h" - -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" - -#include "TraceInterface.h" - -using namespace llvm; - -TraceUtils::TraceUtils(ProbProgMode mode, - const SmallPtrSetImpl &sampleFunctions, - const SmallPtrSetImpl &observeFunctions, - Function *newFunc, Argument *trace, - Argument *observations, Argument *likelihood, - TraceInterface *interface) - : trace(trace), observations(observations), likelihood(likelihood), - interface(interface), mode(mode), newFunc(newFunc), - sampleFunctions(sampleFunctions.begin(), sampleFunctions.end()), - observeFunctions(observeFunctions.begin(), observeFunctions.end()){}; - -TraceUtils * -TraceUtils::FromClone(ProbProgMode mode, - const SmallPtrSetImpl &sampleFunctions, - const SmallPtrSetImpl &observeFunctions, - TraceInterface *interface, Function *oldFunc, - ValueToValueMapTy &originalToNewFn) { - auto &Context = oldFunc->getContext(); - FunctionType *orig_FTy = oldFunc->getFunctionType(); - SmallVector params; - - for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { - params.push_back(orig_FTy->getParamType(i)); - } - - Type *likelihood_acc_type = - PointerType::getUnqual(Type::getDoubleTy(Context)); - params.push_back(likelihood_acc_type); - - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { - Type *traceType = interface->getTraceTy()->getReturnType(); - - if (mode == ProbProgMode::Condition) - params.push_back(traceType); - - params.push_back(traceType); - } - - Type *RetTy = oldFunc->getReturnType(); - FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg()); - - const char *mode_str; - switch (mode) { - case ProbProgMode::Likelihood: - mode_str = "likelihood"; - break; - case ProbProgMode::Trace: - mode_str = "trace"; - break; - case ProbProgMode::Condition: - mode_str = "condition"; - break; - } - - Function *newFunc = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - Twine(mode_str) + "_" + oldFunc->getName(), oldFunc->getParent()); - - auto DestArg = newFunc->arg_begin(); - auto SrcArg = oldFunc->arg_begin(); - - for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { - Argument *arg = SrcArg; - originalToNewFn[arg] = DestArg; - DestArg->setName(arg->getName()); - DestArg++; - SrcArg++; - } - - SmallVector 
Returns; - if (!oldFunc->empty()) { -#if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, - CloneFunctionChangeType::LocalChangesOnly, Returns, "", - nullptr); -#else - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", - nullptr); -#endif - } - if (newFunc->empty()) { - auto entry = BasicBlock::Create(newFunc->getContext(), "entry", newFunc); - IRBuilder<> B(entry); - B.CreateUnreachable(); - } - - newFunc->setLinkage(Function::LinkageTypes::InternalLinkage); - - Argument *trace = nullptr; - Argument *observations = nullptr; - Argument *likelihood = nullptr; - - auto arg = newFunc->arg_end(); - - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { - arg -= 1; - trace = arg; - arg->setName("trace"); - arg->addAttr(Attribute::get(Context, TraceParameterAttribute)); - } - - if (mode == ProbProgMode::Condition) { - arg -= 1; - observations = arg; - arg->setName("observations"); - arg->addAttr(Attribute::get(Context, ObservationsParameterAttribute)); - } - - arg -= 1; - likelihood = arg; - arg->setName("likelihood"); - arg->addAttr(Attribute::get(Context, LikelihoodParameterAttribute)); - - return new TraceUtils(mode, sampleFunctions, observeFunctions, newFunc, trace, - observations, likelihood, interface); -}; - -TraceUtils::~TraceUtils() = default; - -TraceInterface *TraceUtils::getTraceInterface() { return interface; } - -Value *TraceUtils::getTrace() { return trace; } - -Value *TraceUtils::getObservations() { return observations; } - -Value *TraceUtils::getLikelihood() { return likelihood; } - -std::pair -TraceUtils::ValueToVoidPtrAndSize(IRBuilder<> &Builder, Value *val, - Type *size_type) { - auto valsize = val->getType()->getPrimitiveSizeInBits(); - - if (val->getType()->isPointerTy()) { - Value *retval = - Builder.CreatePointerCast(val, getInt8PtrTy(val->getContext())); - return {retval, ConstantInt::get(size_type, valsize / 8)}; - } - - auto M = Builder.GetInsertBlock()->getModule(); - auto &DL = M->getDataLayout(); - auto pointersize = DL.getPointerSizeInBits(); - - if (valsize <= pointersize) { - auto cast = - Builder.CreateBitCast(val, IntegerType::get(M->getContext(), valsize)); - if (valsize != pointersize) - cast = Builder.CreateZExt(cast, Builder.getIntPtrTy(DL)); - - Value *retval = - Builder.CreateIntToPtr(cast, getInt8PtrTy(cast->getContext())); - return {retval, ConstantInt::get(size_type, valsize / 8)}; - } else { - IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime( - &Builder.GetInsertBlock()->getParent()->getEntryBlock())); - auto alloca = AllocaBuilder.CreateAlloca(val->getType(), nullptr, - val->getName() + ".ptr"); - Builder.CreateStore(val, alloca); - return {alloca, ConstantInt::get(size_type, valsize / 8)}; - } -} - -CallInst *TraceUtils::CreateTrace(IRBuilder<> &Builder, const Twine &Name) { - auto call = Builder.CreateCall(interface->newTraceTy(), - interface->newTrace(Builder), {}, Name); -#if LLVM_VERSION_MAJOR >= 14 - call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_newtrace")); -#else - call->addAttribute(AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_newtrace")); - -#endif - return call; -} - -CallInst *TraceUtils::FreeTrace(IRBuilder<> &Builder) { - auto call = Builder.CreateCall(interface->freeTraceTy(), - interface->freeTrace(Builder), {trace}); -#if LLVM_VERSION_MAJOR >= 14 - call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_freetrace")); -#else 
- call->addAttribute(AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_freetrace")); - -#endif - return call; -} - -CallInst *TraceUtils::InsertChoice(IRBuilder<> &Builder, Value *address, - Value *score, Value *choice) { - Type *size_type = interface->insertChoiceTy()->getParamType(4); - auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, choice, size_type); - - Value *args[] = {trace, address, score, retval, sizeval}; - - auto call = Builder.CreateCall(interface->insertChoiceTy(), - interface->insertChoice(Builder), args); - call->addParamAttr(1, Attribute::ReadOnly); - - addCallSiteNoCapture(call, 1); - return call; -} - -CallInst *TraceUtils::InsertCall(IRBuilder<> &Builder, Value *address, - Value *subtrace) { - Value *args[] = {trace, address, subtrace}; - - auto call = Builder.CreateCall(interface->insertCallTy(), - interface->insertCall(Builder), args); - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); -#if LLVM_VERSION_MAJOR >= 14 - call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_insert_call")); -#else - call->addAttribute(AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_insert_call")); - -#endif - return call; -} - -CallInst *TraceUtils::InsertArgument(IRBuilder<> &Builder, Value *name, - Value *argument) { - Type *size_type = interface->insertArgumentTy()->getParamType(3); - auto &&[retval, sizeval] = - ValueToVoidPtrAndSize(Builder, argument, size_type); - - Value *args[] = {trace, name, retval, sizeval}; - - auto call = Builder.CreateCall(interface->insertArgumentTy(), - interface->insertArgument(Builder), args); - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return call; -} - -CallInst *TraceUtils::InsertReturn(IRBuilder<> &Builder, Value *val) { - Type *size_type = interface->insertReturnTy()->getParamType(2); - auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, val, size_type); - - Value *args[] = {trace, retval, sizeval}; - - auto call = Builder.CreateCall(interface->insertReturnTy(), - interface->insertReturn(Builder), args); - return call; -} - -CallInst *TraceUtils::InsertFunction(IRBuilder<> &Builder, Function *function) { - assert(!function->isIntrinsic()); - auto FunctionPtr = - Builder.CreateBitCast(function, getInt8PtrTy(function->getContext())); - - Value *args[] = {trace, FunctionPtr}; - - auto call = Builder.CreateCall(interface->insertFunctionTy(), - interface->insertFunction(Builder), args); - return call; -} - -CallInst *TraceUtils::InsertChoiceGradient(IRBuilder<> &Builder, - FunctionType *interface_type, - Value *interface_function, - Value *address, Value *choice, - Value *trace) { - Type *size_type = interface_type->getParamType(3); - auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, choice, size_type); - - Value *args[] = {trace, address, retval, sizeval}; - - auto call = Builder.CreateCall(interface_type, interface_function, args); - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return call; -} - -CallInst *TraceUtils::InsertArgumentGradient(IRBuilder<> &Builder, - FunctionType *interface_type, - Value *interface_function, - Value *name, Value *argument, - Value *trace) { - Type *size_type = interface_type->getParamType(3); - auto &&[retval, sizeval] = - ValueToVoidPtrAndSize(Builder, argument, size_type); - - Value *args[] = {trace, name, retval, sizeval}; - - auto call = Builder.CreateCall(interface_type, interface_function, args); 
- call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return call; -} - -CallInst *TraceUtils::GetTrace(IRBuilder<> &Builder, Value *address, - const Twine &Name) { - assert(address->getType()->isPointerTy()); - - Value *args[] = {observations, address}; - - auto call = Builder.CreateCall(interface->getTraceTy(), - interface->getTrace(Builder), args, Name); - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return call; -} - -Instruction *TraceUtils::GetChoice(IRBuilder<> &Builder, Value *address, - Type *choiceType, const Twine &Name) { - IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime( - &Builder.GetInsertBlock()->getParent()->getEntryBlock())); - AllocaInst *store_dest = - AllocaBuilder.CreateAlloca(choiceType, nullptr, Name + ".ptr"); - auto preallocated_size = choiceType->getPrimitiveSizeInBits() / 8; - Type *size_type = interface->getChoiceTy()->getParamType(3); - - Value *args[] = {observations, address, - Builder.CreatePointerCast( - store_dest, getInt8PtrTy(store_dest->getContext())), - ConstantInt::get(size_type, preallocated_size)}; - - auto call = - Builder.CreateCall(interface->getChoiceTy(), - interface->getChoice(Builder), args, Name + ".size"); - -#if LLVM_VERSION_MAJOR >= 14 - call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_inactive")); -#else - call->addAttribute(AttributeList::FunctionIndex, - Attribute::get(call->getContext(), "enzyme_inactive")); -#endif - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return Builder.CreateLoad(choiceType, store_dest, "from.trace." + Name); -} - -Instruction *TraceUtils::HasChoice(IRBuilder<> &Builder, Value *address, - const Twine &Name) { - Value *args[]{observations, address}; - - auto call = Builder.CreateCall(interface->hasChoiceTy(), - interface->hasChoice(Builder), args, Name); - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return call; -} - -Instruction *TraceUtils::HasCall(IRBuilder<> &Builder, Value *address, - const Twine &Name) { - Value *args[]{observations, address}; - - auto call = Builder.CreateCall(interface->hasCallTy(), - interface->hasCall(Builder), args, Name); - call->addParamAttr(1, Attribute::ReadOnly); - addCallSiteNoCapture(call, 1); - return call; -} - -Instruction *TraceUtils::SampleOrCondition(IRBuilder<> &Builder, - Function *sample_fn, - ArrayRef sample_args, - Value *address, const Twine &Name) { - auto &Context = Builder.getContext(); - auto parent_fn = Builder.GetInsertBlock()->getParent(); - - switch (mode) { - case ProbProgMode::Likelihood: - case ProbProgMode::Trace: { - auto sample_call = Builder.CreateCall(sample_fn->getFunctionType(), - sample_fn, sample_args); - return sample_call; - } - case ProbProgMode::Condition: { - Instruction *hasChoice = HasChoice(Builder, address, "has.choice." + Name); - - Value *ThenChoice, *ElseChoice; - BasicBlock *ThenBlock = BasicBlock::Create( - Context, "condition." + Name + ".with.trace", parent_fn); - BasicBlock *ElseBlock = BasicBlock::Create( - Context, "condition." 
+ Name + ".without.trace", parent_fn); - BasicBlock *EndBlock = BasicBlock::Create(Context, "end", parent_fn); - - Builder.CreateCondBr(hasChoice, ThenBlock, ElseBlock); - Builder.SetInsertPoint(ThenBlock); - ThenChoice = GetChoice(Builder, address, sample_fn->getReturnType(), Name); - Builder.CreateBr(EndBlock); - - Builder.SetInsertPoint(ElseBlock); - ElseChoice = Builder.CreateCall(sample_fn->getFunctionType(), sample_fn, - sample_args, "sample." + Name); - Builder.CreateBr(EndBlock); - - Builder.SetInsertPoint(EndBlock); - auto phi = Builder.CreatePHI(sample_fn->getReturnType(), 2); - phi->addIncoming(ThenChoice, ThenBlock); - phi->addIncoming(ElseChoice, ElseBlock); - - return phi; - } - } - llvm_unreachable("Invalid sample_or_condition"); -} - -CallInst *TraceUtils::CreateOutlinedFunction( - IRBuilder<> &Builder, - function_ref &, TraceUtils *, ArrayRef)> Outlined, - Type *RetTy, ArrayRef Arguments, bool needsLikelihood, - const Twine &Name) { - SmallVector Tys; - SmallVector Vals; - Module *M = Builder.GetInsertBlock()->getModule(); - - for (auto Arg : Arguments) { - Vals.push_back(Arg); - Tys.push_back(Arg->getType()); - } - - if (needsLikelihood) { - Vals.push_back(likelihood); - Tys.push_back(likelihood->getType()); - } - - if (mode == ProbProgMode::Condition) { - Vals.push_back(observations); - Tys.push_back(observations->getType()); - } - - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { - Vals.push_back(trace); - Tys.push_back(trace->getType()); - } - - FunctionType *FTy = FunctionType::get(RetTy, Tys, false); - Function *F = - Function::Create(FTy, Function::LinkageTypes::InternalLinkage, Name, M); - F->addFnAttr(Attribute::AlwaysInline); - - auto Entry = BasicBlock::Create(M->getContext(), "entry", F); - - auto ArgRange = make_pointer_range( - make_range(F->arg_begin(), F->arg_begin() + Arguments.size())); - SmallVector Rets(ArgRange); - - auto idx = F->arg_begin() + Arguments.size(); - - Argument *likelihood_arg = nullptr; - if (needsLikelihood) - likelihood_arg = idx++; - - Argument *observations_arg = nullptr; - if (mode == ProbProgMode::Condition) - observations_arg = idx++; - - Argument *trace_arg = nullptr; - if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) - trace_arg = idx++; - - TraceUtils OutlineTutils = - TraceUtils(mode, sampleFunctions, observeFunctions, F, trace_arg, - observations_arg, likelihood_arg, interface); - IRBuilder<> OutlineBuilder(Entry); - Outlined(OutlineBuilder, &OutlineTutils, Rets); - - return Builder.CreateCall(FTy, F, Vals); -} - -bool TraceUtils::isSampleCall(CallInst *call) { - auto F = getFunctionFromCall(call); - return sampleFunctions.count(F); -} - -bool TraceUtils::isObserveCall(CallInst *call) { - auto F = getFunctionFromCall(call); - return observeFunctions.count(F); -} diff --git a/enzyme/Enzyme/TraceUtils.h b/enzyme/Enzyme/TraceUtils.h deleted file mode 100644 index b7b437128d5b..000000000000 --- a/enzyme/Enzyme/TraceUtils.h +++ /dev/null @@ -1,157 +0,0 @@ -//===- TraceUtils.h - Utilites for interacting with traces ---------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains utilities for interacting with probabilistic programming -// traces using the probabilistic programming -// trace interface -// -//===----------------------------------------------------------------------===// - -#ifndef TraceUtils_h -#define TraceUtils_h - -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/User.h" -#include "llvm/IR/Value.h" -#include "llvm/IR/ValueMap.h" - -#include "TraceInterface.h" -#include "Utils.h" - -class TraceUtils { - -private: - llvm::Value *trace; - llvm::Value *observations; - llvm::Value *likelihood; - -public: - TraceInterface *interface; - ProbProgMode mode; - llvm::Function *newFunc; - llvm::SmallPtrSet sampleFunctions; - llvm::SmallPtrSet observeFunctions; - - constexpr static const char TraceParameterAttribute[] = "enzyme_trace"; - constexpr static const char ObservationsParameterAttribute[] = - "enzyme_observations"; - constexpr static const char LikelihoodParameterAttribute[] = - "enzyme_likelihood"; - -public: - TraceUtils(ProbProgMode mode, - const llvm::SmallPtrSetImpl &sampleFunctions, - const llvm::SmallPtrSetImpl &observeFunctions, - llvm::Function *newFunc, llvm::Argument *trace, - llvm::Argument *observations, llvm::Argument *likelihood, - TraceInterface *interface); - - static TraceUtils * - FromClone(ProbProgMode mode, - const llvm::SmallPtrSetImpl &sampleFunctions, - const llvm::SmallPtrSetImpl &observeFunctions, - TraceInterface *interface, llvm::Function *oldFunc, - llvm::ValueMap - &originalToNewFn); - - ~TraceUtils(); - -private: - static std::pair - ValueToVoidPtrAndSize(llvm::IRBuilder<> &Builder, llvm::Value *val, - llvm::Type *size_type); - -public: - TraceInterface *getTraceInterface(); - - llvm::Value *getTrace(); - - llvm::Value *getObservations(); - - llvm::Value *getLikelihood(); - - llvm::CallInst *CreateTrace(llvm::IRBuilder<> &Builder, - const llvm::Twine &Name = "trace"); - - llvm::CallInst *FreeTrace(llvm::IRBuilder<> &Builder); - - llvm::CallInst *InsertChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, - llvm::Value *score, llvm::Value *choice); - - llvm::CallInst *InsertCall(llvm::IRBuilder<> &Builder, llvm::Value *address, - llvm::Value *subtrace); - - llvm::CallInst *InsertArgument(llvm::IRBuilder<> &Builder, llvm::Value *name, - llvm::Value *argument); - - llvm::CallInst *InsertReturn(llvm::IRBuilder<> &Builder, llvm::Value *ret); - - llvm::CallInst *InsertFunction(llvm::IRBuilder<> &Builder, - llvm::Function *function); - - static llvm::CallInst * - InsertChoiceGradient(llvm::IRBuilder<> &Builder, - llvm::FunctionType *interface_type, - llvm::Value *interface_function, llvm::Value *address, - llvm::Value *choice, llvm::Value *trace); - - static llvm::CallInst * - InsertArgumentGradient(llvm::IRBuilder<> &Builder, - llvm::FunctionType *interface_type, - llvm::Value 
*interface_function, llvm::Value *name, - llvm::Value *argument, llvm::Value *trace); - - llvm::CallInst *GetTrace(llvm::IRBuilder<> &Builder, llvm::Value *address, - const llvm::Twine &Name = ""); - - llvm::Instruction *GetChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, - llvm::Type *choiceType, - const llvm::Twine &Name = ""); - - llvm::Instruction *HasChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, - const llvm::Twine &Name = ""); - - llvm::Instruction *HasCall(llvm::IRBuilder<> &Builder, llvm::Value *address, - const llvm::Twine &Name = ""); - - llvm::Instruction * - SampleOrCondition(llvm::IRBuilder<> &Builder, llvm::Function *sample_fn, - llvm::ArrayRef sample_args, - llvm::Value *address, const llvm::Twine &Name = ""); - - llvm::CallInst *CreateOutlinedFunction( - llvm::IRBuilder<> &Builder, - llvm::function_ref &, TraceUtils *, - llvm::ArrayRef)> - Outlined, - llvm::Type *RetTy, llvm::ArrayRef Arguments, - bool needsLikelihood = true, const llvm::Twine &Name = ""); - - bool isSampleCall(llvm::CallInst *call); - - bool isObserveCall(llvm::CallInst *call); -}; - -#endif /* TraceUtils_h */ diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h deleted file mode 100644 index 71d6e0910408..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ /dev/null @@ -1,78 +0,0 @@ -//===- BaseType.h - Category of type used in Type Analysis ------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. 
and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation of an enum representing the potential -// types used in Type Analysis -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H -#define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 - -#include "llvm/ADT/StringRef.h" -#include - -/// Categories of potential types -enum class BaseType { - // integral type which doesn't represent a pointer - Integer, - // floating point - Float, - // pointer - Pointer, - // can be anything of users choosing [usually result of a constant such as 0] - Anything, - // insufficient information - Unknown -}; - -/// Convert Basetype to string -static inline std::string to_string(BaseType t) { - switch (t) { - case BaseType::Integer: - return "Integer"; - case BaseType::Float: - return "Float"; - case BaseType::Pointer: - return "Pointer"; - case BaseType::Anything: - return "Anything"; - case BaseType::Unknown: - return "Unknown"; - } - assert(0 && "unknown inttype"); - return ""; -} - -/// Convert string to BaseType -static inline BaseType parseBaseType(llvm::StringRef str) { - if (str == "Integer") - return BaseType::Integer; - if (str == "Float") - return BaseType::Float; - if (str == "Pointer") - return BaseType::Pointer; - if (str == "Anything") - return BaseType::Anything; - if (str == "Unknown") - return BaseType::Unknown; - assert(0 && "Unknown BaseType string"); - return BaseType::Unknown; -} -#endif diff --git a/enzyme/Enzyme/TypeAnalysis/ConcreteType.h b/enzyme/Enzyme/TypeAnalysis/ConcreteType.h deleted file mode 100644 index 8708fbdfcc97..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/ConcreteType.h +++ /dev/null @@ -1,518 +0,0 @@ -//===- ConcreteType.h - Underlying SubType used in Type Analysis -//------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation of an a class representing all -// potential end SubTypes used in Type Analysis. This ``ConcreteType`` contains -// an the SubType category ``BaseType`` as well as the SubType of float, if -// relevant. This also contains several helper utility functions. -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_TYPE_ANALYSIS_CONCRETE_TYPE_H -#define ENZYME_TYPE_ANALYSIS_CONCRETE_TYPE_H 1 - -#include - -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/ErrorHandling.h" - -#include "BaseType.h" - -/// Concrete SubType of a given value. Consists of a category `BaseType` and the -/// particular floating point value, if relevant. 
-class ConcreteType { -public: - /// Category of underlying type - BaseType SubTypeEnum; - /// Floating point type, if relevant, otherwise nullptr - llvm::Type *SubType; - - /// Construct a ConcreteType from an existing FloatingPoint Type - ConcreteType(llvm::Type *SubType) - : SubTypeEnum(BaseType::Float), SubType(SubType) { - assert(SubType != nullptr); - assert(!llvm::isa(SubType)); - if (!SubType->isFloatingPointTy()) { - llvm::errs() << " passing in non FP SubType: " << *SubType << "\n"; - } - assert(SubType->isFloatingPointTy()); - } - - /// Construct a non-floating Concrete type from a BaseType - ConcreteType(BaseType SubTypeEnum) - : SubTypeEnum(SubTypeEnum), SubType(nullptr) { - assert(SubTypeEnum != BaseType::Float); - } - - /// Construct a ConcreteType from a string - /// A Concrete Type's string representation is given by the string of the - /// enum If it is a floating point it is given by Float@ - ConcreteType(llvm::StringRef Str, llvm::LLVMContext &C) { - auto Sep = Str.find('@'); - if (Sep != llvm::StringRef::npos) { - SubTypeEnum = BaseType::Float; - assert(Str.substr(0, Sep) == "Float"); - auto SubName = Str.substr(Sep + 1); - if (SubName == "half") { - SubType = llvm::Type::getHalfTy(C); - } else if (SubName == "float") { - SubType = llvm::Type::getFloatTy(C); - } else if (SubName == "double") { - SubType = llvm::Type::getDoubleTy(C); - } else if (SubName == "fp80") { - SubType = llvm::Type::getX86_FP80Ty(C); - } else if (SubName == "bf16") { - SubType = llvm::Type::getBFloatTy(C); - } else if (SubName == "fp128") { - SubType = llvm::Type::getFP128Ty(C); - } else if (SubName == "ppc128") { - SubType = llvm::Type::getPPC_FP128Ty(C); - } else { - llvm_unreachable("unknown data SubType"); - } - } else { - SubType = nullptr; - SubTypeEnum = parseBaseType(Str); - } - } - - /// Convert the ConcreteType to a string - std::string str() const { - std::string Result = to_string(SubTypeEnum); - if (SubTypeEnum == BaseType::Float) { - if (SubType->isHalfTy()) { - Result += "@half"; - } else if (SubType->isFloatTy()) { - Result += "@float"; - } else if (SubType->isDoubleTy()) { - Result += "@double"; - } else if (SubType->isX86_FP80Ty()) { - Result += "@fp80"; - } else if (SubType->isBFloatTy()) { - Result += "@bf16"; - } else if (SubType->isFP128Ty()) { - Result += "@fp128"; - } else if (SubType->isPPC_FP128Ty()) { - Result += "@ppc128"; - } else { - llvm_unreachable("unknown data SubType"); - } - } - return Result; - } - - /// Whether this ConcreteType has information (is not unknown) - bool isKnown() const { return SubTypeEnum != BaseType::Unknown; } - - /// Whether this ConcreteType must an integer - bool isIntegral() const { return SubTypeEnum == BaseType::Integer; } - - /// Whether this ConcreteType could be a pointer (SubTypeEnum is unknown or a - /// pointer) - bool isPossiblePointer() const { - return SubTypeEnum == BaseType::Pointer || - SubTypeEnum == BaseType::Anything || - SubTypeEnum == BaseType::Unknown; - } - - /// Whether this ConcreteType could be a float (SubTypeEnum is unknown or a - /// float) - bool isPossibleFloat() const { - return SubTypeEnum == BaseType::Float || - SubTypeEnum == BaseType::Anything || - SubTypeEnum == BaseType::Unknown; - } - - /// Return the floating point type, if this is a float - llvm::Type *isFloat() const { return SubType; } - - /// Return if this is known to be the BaseType BT - /// This cannot be called with BaseType::Float as it lacks information - bool operator==(const BaseType BT) const { - if (BT == BaseType::Float) { - assert(0 
&& - "Cannot do comparision between ConcreteType and BaseType::Float"); - llvm_unreachable( - "Cannot do comparision between ConcreteType and BaseType::Float"); - } - return SubTypeEnum == BT; - } - - /// Return if this is known not to be the BaseType BT - /// This cannot be called with BaseType::Float as it lacks information - bool operator!=(const BaseType BT) const { - if (BT == BaseType::Float) { - assert(0 && - "Cannot do comparision between ConcreteType and BaseType::Float"); - llvm_unreachable( - "Cannot do comparision between ConcreteType and BaseType::Float"); - } - return SubTypeEnum != BT; - } - - /// Return if this is known to be the ConcreteType CT - bool operator==(const ConcreteType CT) const { - return SubType == CT.SubType && SubTypeEnum == CT.SubTypeEnum; - } - - /// Return if this is known not to be the ConcreteType CT - bool operator!=(const ConcreteType CT) const { return !(*this == CT); } - - /// Set this to the given ConcreteType, returning true if - /// this ConcreteType has changed - bool operator=(const ConcreteType CT) { - bool changed = false; - if (SubTypeEnum != CT.SubTypeEnum) - changed = true; - SubTypeEnum = CT.SubTypeEnum; - if (SubType != CT.SubType) - changed = true; - SubType = CT.SubType; - return changed; - } - - /// Set this to the given BaseType, returning true if - /// this ConcreteType has changed - bool operator=(const BaseType BT) { - assert(BT != BaseType::Float); - return ConcreteType::operator=(ConcreteType(BT)); - } - - /// Set this to the logical or of itself and CT, returning whether this value - /// changed Setting `PointerIntSame` considers pointers and integers as - /// equivalent If this is an illegal operation, `LegalOr` will be set to false - bool checkedOrIn(const ConcreteType CT, bool PointerIntSame, bool &LegalOr) { - if (SubTypeEnum == BaseType::Anything) { - return false; - } - if (CT.SubTypeEnum == BaseType::Anything) { - return *this = CT; - } - if (SubTypeEnum == BaseType::Unknown) { - return *this = CT; - } - if (CT.SubTypeEnum == BaseType::Unknown) { - return false; - } - if (CT.SubTypeEnum != SubTypeEnum) { - if (PointerIntSame) { - if ((SubTypeEnum == BaseType::Pointer && - CT.SubTypeEnum == BaseType::Integer) || - (SubTypeEnum == BaseType::Integer && - CT.SubTypeEnum == BaseType::Pointer)) { - return false; - } - } - LegalOr = false; - return false; - } - assert(CT.SubTypeEnum == SubTypeEnum); - if (CT.SubType != SubType) { - LegalOr = false; - return false; - } - assert(CT.SubType == SubType); - return false; - } - - /// Set this to the logical or of itself and CT, returning whether this value - /// changed Setting `PointerIntSame` considers pointers and integers as - /// equivalent This function will error if doing an illegal Operation - bool orIn(const ConcreteType CT, bool PointerIntSame) { - bool Legal = true; - bool Result = checkedOrIn(CT, PointerIntSame, Legal); - if (!Legal) { - llvm::errs() << "Illegal orIn: " << str() << " right: " << CT.str() - << " PointerIntSame=" << PointerIntSame << "\n"; - assert(0 && "Performed illegal ConcreteType::orIn"); - llvm_unreachable("Performed illegal ConcreteType::orIn"); - } - return Result; - } - - /// Set this to the logical or of itself and CT, returning whether this value - /// changed This assumes that pointers and integers are distinct This function - /// will error if doing an illegal Operation - bool operator|=(const ConcreteType CT) { - return orIn(CT, /*pointerIntSame*/ false); - } - - /// Set this to the logical and of itself and CT, returning whether this value - 
/// changed If this and CT are incompatible, the result will be - /// BaseType::Unknown - bool andIn(const ConcreteType CT) { - if (SubTypeEnum == BaseType::Anything) { - return *this = CT; - } - if (CT.SubTypeEnum == BaseType::Anything) { - return false; - } - if (SubTypeEnum == BaseType::Unknown) { - return false; - } - if (CT.SubTypeEnum == BaseType::Unknown) { - return *this = CT; - } - - if (CT.SubTypeEnum != SubTypeEnum) { - return *this = BaseType::Unknown; - } - if (CT.SubType != SubType) { - return *this = BaseType::Unknown; - } - return false; - } - - /// Set this to the logical and of itself and CT, returning whether this value - /// changed If this and CT are incompatible, the result will be - /// BaseType::Unknown - bool operator&=(const ConcreteType CT) { return andIn(CT); } - - /// Keep only mappings where the type is not an `Anything` - ConcreteType PurgeAnything() const { - if (SubTypeEnum == BaseType::Anything) - return BaseType::Unknown; - return *this; - } - - /// Set this to the logical `binop` of itself and RHS, using the Binop Op, - /// returning true if this was changed. - /// This function will error on an invalid type combination - bool binopIn(bool &Legal, const ConcreteType RHS, - llvm::BinaryOperator::BinaryOps Op) { - bool Changed = false; - using namespace llvm; - - // Anything op Anything => Anything - if (SubTypeEnum == BaseType::Anything && - RHS.SubTypeEnum == BaseType::Anything) { - return Changed; - } - - // [?] op float => Unknown - if ((((SubTypeEnum == BaseType::Anything || - SubTypeEnum == BaseType::Integer || - SubTypeEnum == BaseType::Unknown) && - RHS.isFloat()) || - (isFloat() && (RHS.SubTypeEnum == BaseType::Anything || - RHS.SubTypeEnum == BaseType::Integer || - RHS.SubTypeEnum == BaseType::Unknown)))) { - SubTypeEnum = BaseType::Unknown; - SubType = nullptr; - Changed = true; - return Changed; - } - - // Unknown op Anything => Unknown - if ((SubTypeEnum == BaseType::Unknown && - RHS.SubTypeEnum == BaseType::Anything) || - (SubTypeEnum == BaseType::Anything && - RHS.SubTypeEnum == BaseType::Unknown)) { - if (SubTypeEnum != BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - } - return Changed; - } - - // Integer op Integer => Integer - if (SubTypeEnum == BaseType::Integer && - RHS.SubTypeEnum == BaseType::Integer) { - return Changed; - } - - // Integer op Anything => {Anything, Integer} - if ((SubTypeEnum == BaseType::Anything && - RHS.SubTypeEnum == BaseType::Integer) || - (SubTypeEnum == BaseType::Integer && - RHS.SubTypeEnum == BaseType::Anything)) { - - switch (Op) { - // The result of these operands mix data between LHS/RHS - // Therefore there is some "anything" data in the result - case BinaryOperator::Add: - case BinaryOperator::Sub: - case BinaryOperator::Mul: - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - if (SubTypeEnum != BaseType::Anything) { - SubTypeEnum = BaseType::Anything; - Changed = true; - } - break; - - // The result of these operands only use data from LHS - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - case BinaryOperator::Shl: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - // No change since we retain data from LHS - break; - default: - Legal = false; - return Changed; - } - return Changed; - } - - // Integer op Unknown => Unknown - // e.g. 
pointer + int = pointer and int + int = int - if ((SubTypeEnum == BaseType::Unknown && - RHS.SubTypeEnum == BaseType::Integer) || - (SubTypeEnum == BaseType::Integer && - RHS.SubTypeEnum == BaseType::Unknown)) { - if (SubTypeEnum != BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - } - return Changed; - } - - // Pointer op Pointer => {Integer, Illegal} - if (SubTypeEnum == BaseType::Pointer && - RHS.SubTypeEnum == BaseType::Pointer) { - switch (Op) { - case BinaryOperator::Sub: - SubTypeEnum = BaseType::Integer; - Changed = true; - break; - case BinaryOperator::Add: - case BinaryOperator::Mul: - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - case BinaryOperator::Shl: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - default: - Legal = false; - return Changed; - } - return Changed; - } - - // Pointer - Unknown => Unknown - // This is because Pointer - Pointer => Integer - // and Pointer - Integer => Pointer - if (Op == BinaryOperator::Sub && SubTypeEnum == BaseType::Pointer && - RHS.SubTypeEnum == BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - return Changed; - } - - // Pointer op ? => {Pointer, Unknown} - if ((SubTypeEnum == BaseType::Integer && - RHS.SubTypeEnum == BaseType::Pointer) || - (SubTypeEnum == BaseType::Pointer && - RHS.SubTypeEnum == BaseType::Integer) || - (SubTypeEnum == BaseType::Integer && - RHS.SubTypeEnum == BaseType::Pointer) || - (SubTypeEnum == BaseType::Pointer && - RHS.SubTypeEnum == BaseType::Unknown) || - (SubTypeEnum == BaseType::Unknown && - RHS.SubTypeEnum == BaseType::Pointer) || - (SubTypeEnum == BaseType::Pointer && - RHS.SubTypeEnum == BaseType::Anything) || - (SubTypeEnum == BaseType::Anything && - RHS.SubTypeEnum == BaseType::Pointer)) { - - switch (Op) { - case BinaryOperator::Sub: - if (SubTypeEnum == BaseType::Anything || - RHS.SubTypeEnum == BaseType::Anything) { - if (SubTypeEnum != BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - } - break; - } - if (RHS.SubTypeEnum == BaseType::Pointer) { - if (SubTypeEnum != BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - } - break; - } - [[fallthrough]]; - case BinaryOperator::Add: - case BinaryOperator::Mul: - if (SubTypeEnum != BaseType::Pointer) { - SubTypeEnum = BaseType::Pointer; - Changed = true; - } - break; - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - if (RHS.SubTypeEnum == BaseType::Pointer) { - Legal = false; - return Changed; - } else if (SubTypeEnum != BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - } - break; - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - case BinaryOperator::Shl: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - if (SubTypeEnum != BaseType::Unknown) { - SubTypeEnum = BaseType::Unknown; - Changed = true; - } - break; - default: - Legal = false; - return Changed; - } - return Changed; - } - - Legal = false; - return Changed; - } - - /// Compare concrete types for use in map's - bool operator<(const ConcreteType dt) const { - if (SubTypeEnum == dt.SubTypeEnum) { - return SubType < dt.SubType; - } else { - return SubTypeEnum < dt.SubTypeEnum; - } - } -}; - -// Convert ConcreteType to string -static inline std::string to_string(const ConcreteType dt) { return dt.str(); } - 
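The pointer/integer rules that `binopIn` encodes above are easier to see in isolation. The following is a minimal standalone sketch in plain C++ with no LLVM or Enzyme dependency; the names `BT`, `Op`, and `baseBinop` are invented for illustration, only the Integer/Pointer/Unknown corner of the lattice is modelled, and the `Anything` state, float subtypes, and legality reporting of the real `ConcreteType` are deliberately ignored. It is not part of the deleted header or of Enzyme's API.

// Minimal sketch of the BaseType-level rules documented in binopIn above.
// Assumption: only Integer/Pointer/Unknown are modelled; Anything, float
// subtypes, and the explicit legality flag of the real code are omitted.
#include <cassert>
#include <iostream>

enum class BT { Integer, Pointer, Unknown };
enum class Op { Add, Sub };

static BT baseBinop(BT lhs, BT rhs, Op op) {
  // Integer op Integer => Integer
  if (lhs == BT::Integer && rhs == BT::Integer)
    return BT::Integer;
  // Pointer - Pointer => Integer (pointer difference); other pointer/pointer
  // combinations are rejected as illegal upstream and collapse to Unknown here.
  if (lhs == BT::Pointer && rhs == BT::Pointer)
    return op == Op::Sub ? BT::Integer : BT::Unknown;
  // Pointer + Integer, Pointer - Integer, Integer + Pointer => Pointer
  if ((lhs == BT::Pointer && rhs == BT::Integer) ||
      (lhs == BT::Integer && rhs == BT::Pointer && op == Op::Add))
    return BT::Pointer;
  // Everything else (e.g. Integer - Pointer, Pointer - Unknown) is ambiguous.
  return BT::Unknown;
}

int main() {
  // p - q: the difference of two pointers is an integer.
  assert(baseBinop(BT::Pointer, BT::Pointer, Op::Sub) == BT::Integer);
  // p + i and p - i: pointer arithmetic stays a pointer.
  assert(baseBinop(BT::Pointer, BT::Integer, Op::Add) == BT::Pointer);
  assert(baseBinop(BT::Pointer, BT::Integer, Op::Sub) == BT::Pointer);
  // i + i: plain integer arithmetic.
  assert(baseBinop(BT::Integer, BT::Integer, Op::Add) == BT::Integer);
  // i - p: ambiguous, stays Unknown (the real code reaches the same result).
  assert(baseBinop(BT::Integer, BT::Pointer, Op::Sub) == BT::Unknown);
  std::cout << "BaseType binop rules behave as documented\n";
  return 0;
}

The real implementation additionally reports disallowed combinations through its `Legal` out-parameter rather than silently returning Unknown, which lets callers distinguish "insufficient information" from "type error".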
-#endif diff --git a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp deleted file mode 100644 index 2376c9b23353..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp +++ /dev/null @@ -1,181 +0,0 @@ -//===- RustDebugInfo.cpp - Implementaion of Rust Debug Info Parser ---===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===-------------------------------------------------------------------===// -// -// This file implement the Rust debug info parsing function. It will get the -// description of types from debug info of an instruction and pass it to -// concrete functions according to the kind of a description and construct -// the type tree recursively. -// -//===-------------------------------------------------------------------===// -#include "llvm/IR/DIBuilder.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" -#include "llvm/Support/CommandLine.h" - -#include "RustDebugInfo.h" - -using namespace llvm; - -TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL); - -TypeTree parseDIType(DIBasicType &Type, Instruction &I, DataLayout &DL) { - auto TypeName = Type.getName(); - TypeTree Result; - if (TypeName == "f64") { - Result = TypeTree(Type::getDoubleTy(I.getContext())).Only(0, &I); - } else if (TypeName == "f32") { - Result = TypeTree(Type::getFloatTy(I.getContext())).Only(0, &I); - } else if (TypeName == "i8" || TypeName == "i16" || TypeName == "i32" || - TypeName == "i64" || TypeName == "isize" || TypeName == "u8" || - TypeName == "u16" || TypeName == "u32" || TypeName == "u64" || - TypeName == "usize" || TypeName == "i128" || TypeName == "u128") { - Result = TypeTree(ConcreteType(BaseType::Integer)).Only(0, &I); - } else { - Result = TypeTree(ConcreteType(BaseType::Unknown)).Only(0, &I); - } - return Result; -} - -TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { - TypeTree Result; - if (Type.getTag() == dwarf::DW_TAG_array_type) { - DIType *SubType = Type.getBaseType(); - TypeTree SubTT = parseDIType(*SubType, I, DL); - size_t Align = Type.getAlignInBytes(); - size_t SubSize = SubType->getSizeInBits() / 8; - size_t Size = Type.getSizeInBits() / 8; - DINodeArray Subranges = Type.getElements(); - size_t pos = 0; - for (auto r : Subranges) { - DISubrange *Subrange = dyn_cast(r); - if (auto Count = Subrange->getCount().get()) { - int64_t count = Count->getSExtValue(); - if (count == -1) { - break; - } - for (int64_t i = 0; i < count; i++) { - Result |= SubTT.ShiftIndices(DL, 0, Size, pos); - size_t tmp = pos + SubSize; - if (tmp % Align != 0) { - pos = (tmp / Align + 1) * Align; - } else { - pos = tmp; - } - } - } else { - assert(0 && "There shouldn't be non-constant-size arrays in Rust"); - } - } - } else if (Type.getTag() == dwarf::DW_TAG_structure_type || - Type.getTag() == dwarf::DW_TAG_union_type) { - DINodeArray Elements = Type.getElements(); - size_t Size = Type.getSizeInBits() / 8; 
- bool firstSubTT = true; - for (auto e : Elements) { - DIType *SubType = dyn_cast(e); - assert(SubType->getTag() == dwarf::DW_TAG_member); - TypeTree SubTT = parseDIType(*SubType, I, DL); - size_t Offset = SubType->getOffsetInBits() / 8; - SubTT = SubTT.ShiftIndices(DL, 0, Size, Offset); - if (Type.getTag() == dwarf::DW_TAG_structure_type) { - Result |= SubTT; - } else { - if (firstSubTT) { - Result = SubTT; - } else { - Result &= SubTT; - } - } - if (firstSubTT) { - firstSubTT = !firstSubTT; - } - } - } else { - assert(0 && "Composite types other than arrays, structs and unions are not " - "supported by Rust debug info parser"); - } - return Result; -} - -TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { - if (Type.getTag() == dwarf::DW_TAG_pointer_type) { - TypeTree Result(BaseType::Pointer); - DIType *SubType = Type.getBaseType(); - TypeTree SubTT = parseDIType(*SubType, I, DL); - if (isa(SubType)) { - Result |= SubTT.ShiftIndices(DL, 0, 1, -1); - } else { - Result |= SubTT; - } - return Result.Only(0, &I); - } else if (Type.getTag() == dwarf::DW_TAG_member) { - DIType *SubType = Type.getBaseType(); - TypeTree Result = parseDIType(*SubType, I, DL); - return Result; - } else { - assert(0 && "Derived types other than pointers and members are not " - "supported by Rust debug info parser"); - } - return {}; -} - -TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { - if (Type.getSizeInBits() == 0) { - return TypeTree(); - } - - if (auto BT = dyn_cast(&Type)) { - return parseDIType(*BT, I, DL); - } else if (auto CT = dyn_cast(&Type)) { - return parseDIType(*CT, I, DL); - } else if (auto DT = dyn_cast(&Type)) { - return parseDIType(*DT, I, DL); - } else { - assert(0 && "Types other than floating-points, integers, arrays, pointers, " - "slices, and structs are not supported by debug info parser"); - } - return {}; -} - -bool isU8PointerType(DIType &type) { - if (type.getTag() == dwarf::DW_TAG_pointer_type) { - auto PTy = dyn_cast(&type); - DIType *SubType = PTy->getBaseType(); - if (auto BTy = dyn_cast(SubType)) { - std::string name = BTy->getName().str(); - if (name == "u8") { - return true; - } - } - } - return false; -} - -TypeTree parseDIType(DbgDeclareInst &I, DataLayout &DL) { - DIType *type = I.getVariable()->getType(); - - // If the type is *u8, do nothing, since the underlying type of data pointed - // by a *u8 can be anything - if (isU8PointerType(*type)) { - return TypeTree(); - } - TypeTree Result = parseDIType(*type, I, DL); - return Result; -} diff --git a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.h b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.h deleted file mode 100644 index d579b6e02d60..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.h +++ /dev/null @@ -1,38 +0,0 @@ -//===- RustDebugInfo.h - Declaration of Rust Debug Info Parser -------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. 
and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===-------------------------------------------------------------------===// -// -// This file contains the declaration of the Rust debug info parsing function -// which parses the debug info appended to LLVM IR generated by rustc and -// extracts useful type info from it. The type info will be used to initialize -// the following type analysis. -// -//===-------------------------------------------------------------------===// -#ifndef ENZYME_RUSTDEBUGINFO_H -#define ENZYME_RUSTDEBUGINFO_H 1 - -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" - -#include "TypeTree.h" - -/// Construct the type tree from debug info of an instruction -TypeTree parseDIType(llvm::DbgDeclareInst &I, llvm::DataLayout &DL); - -#endif // ENZYME_RUSTDEBUGINFO_H diff --git a/enzyme/Enzyme/TypeAnalysis/TBAA.h b/enzyme/Enzyme/TypeAnalysis/TBAA.h deleted file mode 100644 index d476d4715b0e..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TBAA.h +++ /dev/null @@ -1,519 +0,0 @@ -//===- TBAA.h - Helpers for llvm::Type-based alias analysis ------------===// -// -// Enzyme Project and The LLVM Project -// First section modified from: TypeBasedAliasAnalysis.cpp in LLVM -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation of several utilities for understanding -// TBAA metadata and converting that metadata into corresponding TypeAnalysis -// representations. -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_TYPE_ANALYSIS_TBAA_H -#define ENZYME_TYPE_ANALYSIS_TBAA_H 1 - -#include "BaseType.h" -#include "ConcreteType.h" -#include "TypeTree.h" - -/// isNewFormatTypeNode - Return true iff the given type node is in the new -/// size-aware format. -static bool isNewFormatTypeNode(const llvm::MDNode *N) { - if (N->getNumOperands() < 3) - return false; - // In the old format the first operand is a string. - if (!llvm::isa(N->getOperand(0))) - return false; - return true; -} - -/// This is a simple wrapper around an llvm::MDNode which provides a -/// higher-level interface by hiding the details of how alias analysis -/// information is encoded in its operands. -template class TBAANodeImpl { - MDNodeTy *Node = nullptr; - -public: - TBAANodeImpl() = default; - explicit TBAANodeImpl(MDNodeTy *N) : Node(N) {} - - /// getNode - Get the llvm::MDNode for this TBAANode. - MDNodeTy *getNode() const { return Node; } - - /// isNewFormat - Return true iff the wrapped type node is in the new - /// size-aware format. - bool isNewFormat() const { return isNewFormatTypeNode(Node); } - - /// getParent - Get this TBAANode's Alias tree parent. 
- TBAANodeImpl getParent() const { - if (isNewFormat()) - return TBAANodeImpl(llvm::cast(Node->getOperand(0))); - - if (Node->getNumOperands() < 2) - return TBAANodeImpl(); - MDNodeTy *P = llvm::dyn_cast_or_null(Node->getOperand(1)); - if (!P) - return TBAANodeImpl(); - // Ok, this node has a valid parent. Return it. - return TBAANodeImpl(P); - } - - /// Test if this TBAANode represents a type for objects which are - /// not modified (by any means) in the context where this - /// AliasAnalysis is relevant. - bool isTypeImmutable() const { - if (Node->getNumOperands() < 3) - return false; - llvm::ConstantInt *CI = - llvm::mdconst::dyn_extract(Node->getOperand(2)); - if (!CI) - return false; - return CI->getValue()[0]; - } -}; - -/// \name Specializations of \c TBAANodeImpl for const and non const qualified -/// \c MDNode. -/// @{ -using TBAANode = TBAANodeImpl; -using MutableTBAANode = TBAANodeImpl; -/// @} - -/// This is a simple wrapper around an llvm::MDNode which provides a -/// higher-level interface by hiding the details of how alias analysis -/// information is encoded in its operands. -template class TBAAStructTagNodeImpl { - /// This node should be created with createTBAAAccessTag(). - MDNodeTy *Node; - -public: - explicit TBAAStructTagNodeImpl(MDNodeTy *N) : Node(N) {} - - /// Get the llvm::MDNode for this TBAAStructTagNode. - MDNodeTy *getNode() const { return Node; } - - /// isNewFormat - Return true iff the wrapped access tag is in the new - /// size-aware format. - bool isNewFormat() const { - if (Node->getNumOperands() < 4) - return false; - if (MDNodeTy *AccessType = getAccessType()) - if (!TBAANodeImpl(AccessType).isNewFormat()) - return false; - return true; - } - - MDNodeTy *getBaseType() const { - return llvm::dyn_cast_or_null(Node->getOperand(0)); - } - - MDNodeTy *getAccessType() const { - return llvm::dyn_cast_or_null(Node->getOperand(1)); - } - - uint64_t getOffset() const { - return llvm::mdconst::extract(Node->getOperand(2)) - ->getZExtValue(); - } - - uint64_t getSize() const { - if (!isNewFormat()) - return UINT64_MAX; - return llvm::mdconst::extract(Node->getOperand(3)) - ->getZExtValue(); - } - - /// Test if this TBAAStructTagNode represents a type for objects - /// which are not modified (by any means) in the context where this - /// AliasAnalysis is relevant. - bool isTypeImmutable() const { - unsigned OpNo = isNewFormat() ? 4 : 3; - if (Node->getNumOperands() < OpNo + 1) - return false; - llvm::ConstantInt *CI = - llvm::mdconst::dyn_extract(Node->getOperand(OpNo)); - if (!CI) - return false; - return CI->getValue()[0]; - } -}; - -/// \name Specializations of \c TBAAStructTagNodeImpl for const and non const -/// qualified \c MDNods. -/// @{ -using TBAAStructTagNode = TBAAStructTagNodeImpl; -using MutableTBAAStructTagNode = TBAAStructTagNodeImpl; -/// @} - -/// This is a simple wrapper around an llvm::MDNode which provides a -/// higher-level interface by hiding the details of how alias analysis -/// information is encoded in its operands. -class TBAAStructTypeNode { - /// This node should be created with createTBAATypeNode(). - const llvm::MDNode *Node = nullptr; - -public: - TBAAStructTypeNode() = default; - explicit TBAAStructTypeNode(const llvm::MDNode *N) : Node(N) {} - - /// Get the llvm::MDNode for this TBAAStructTypeNode. - const llvm::MDNode *getNode() const { return Node; } - - /// isNewFormat - Return true iff the wrapped type node is in the new - /// size-aware format. 
- bool isNewFormat() const { return isNewFormatTypeNode(Node); } - - bool operator==(const TBAAStructTypeNode &Other) const { - return getNode() == Other.getNode(); - } - - /// getId - Return type identifier. - llvm::Metadata *getId() const { - return Node->getOperand(isNewFormat() ? 2 : 0); - } - - unsigned getNumFields() const { - unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; - unsigned NumOpsPerField = isNewFormat() ? 3 : 2; - return (getNode()->getNumOperands() - FirstFieldOpNo) / NumOpsPerField; - } - - uint64_t getFieldOffset(unsigned FieldIndex) const { - unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; - unsigned NumOpsPerField = isNewFormat() ? 3 : 2; - unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField; - - uint64_t Cur = - llvm::mdconst::extract(Node->getOperand(OpIndex + 1)) - ->getZExtValue(); - return Cur; - } - - TBAAStructTypeNode getFieldType(unsigned FieldIndex) const { - unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; - unsigned NumOpsPerField = isNewFormat() ? 3 : 2; - unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField; - auto *TypeNode = llvm::cast(getNode()->getOperand(OpIndex)); - return TBAAStructTypeNode(TypeNode); - } - - /// Get this TBAAStructTypeNode's field in the type DAG with - /// given offset. Update the offset to be relative to the field type. - TBAAStructTypeNode getField(uint64_t &Offset) const { - bool NewFormat = isNewFormat(); - if (NewFormat) { - // New-format root and scalar type nodes have no fields. - if (Node->getNumOperands() < 6) - return TBAAStructTypeNode(); - } else { - // Parent can be omitted for the root node. - if (Node->getNumOperands() < 2) - return TBAAStructTypeNode(); - - // Fast path for a scalar type node and a struct type node with a single - // field. - if (Node->getNumOperands() <= 3) { - uint64_t Cur = - Node->getNumOperands() == 2 - ? 0 - : llvm::mdconst::extract(Node->getOperand(2)) - ->getZExtValue(); - Offset -= Cur; - llvm::MDNode *P = - llvm::dyn_cast_or_null(Node->getOperand(1)); - if (!P) - return TBAAStructTypeNode(); - return TBAAStructTypeNode(P); - } - } - - // Assume the offsets are in order. We return the previous field if - // the current offset is bigger than the given offset. - unsigned FirstFieldOpNo = NewFormat ? 3 : 1; - unsigned NumOpsPerField = NewFormat ? 3 : 2; - unsigned TheIdx = 0; - for (unsigned Idx = FirstFieldOpNo; Idx < Node->getNumOperands(); - Idx += NumOpsPerField) { - uint64_t Cur = - llvm::mdconst::extract(Node->getOperand(Idx + 1)) - ->getZExtValue(); - if (Cur > Offset) { - assert(Idx >= FirstFieldOpNo + NumOpsPerField && - "TBAAStructTypeNode::getField should have an offset match!"); - TheIdx = Idx - NumOpsPerField; - break; - } - } - // Move along the last field. - if (TheIdx == 0) - TheIdx = Node->getNumOperands() - NumOpsPerField; - uint64_t Cur = - llvm::mdconst::extract(Node->getOperand(TheIdx + 1)) - ->getZExtValue(); - Offset -= Cur; - llvm::MDNode *P = - llvm::dyn_cast_or_null(Node->getOperand(TheIdx)); - if (!P) - return TBAAStructTypeNode(); - return TBAAStructTypeNode(P); - } -}; - -/// Check the first operand of the tbaa tag node, if it is a llvm::MDNode, we -/// treat it as struct-path aware TBAA format, otherwise, we treat it as scalar -/// TBAA format. -static inline bool isStructPathTBAA(const llvm::MDNode *MD) { - // Anonymous TBAA root starts with a llvm::MDNode and dragonegg uses it as - // a TBAA tag. 
- return llvm::isa(MD->getOperand(0)) && - MD->getNumOperands() >= 3; -} - -static inline const llvm::MDNode * -createAccessTag(const llvm::MDNode *AccessType) { - // If there is no access type or the access type is the root node, then - // we don't have any useful access tag to return. - if (!AccessType || AccessType->getNumOperands() < 2) - return nullptr; - - llvm::Type *Int64 = llvm::IntegerType::get(AccessType->getContext(), 64); - auto *OffsetNode = - llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int64, 0)); - - if (TBAAStructTypeNode(AccessType).isNewFormat()) { - // TODO: Take access ranges into account when matching access tags and - // fix this code to generate actual access sizes for generic tags. - uint64_t AccessSize = UINT64_MAX; - auto *SizeNode = llvm::ConstantAsMetadata::get( - llvm::ConstantInt::get(Int64, AccessSize)); - llvm::Metadata *Ops[] = {const_cast(AccessType), - const_cast(AccessType), OffsetNode, - SizeNode}; - return llvm::MDNode::get(AccessType->getContext(), Ops); - } - - llvm::Metadata *Ops[] = {const_cast(AccessType), - const_cast(AccessType), OffsetNode}; - return llvm::MDNode::get(AccessType->getContext(), Ops); -} - -// Modified from llvm::MDNode::isTBAAVtableAccess() - -static inline std::string -getAccessNameTBAA(const llvm::MDNode *M, - const std::set &legalnames) { - if (!isStructPathTBAA(M)) { - if (M->getNumOperands() < 1) - return ""; - if (const llvm::MDString *Tag1 = - llvm::dyn_cast(M->getOperand(0))) { - return Tag1->getString().str(); - } - return ""; - } - - // For struct-path aware TBAA, we use the access type of the tag. - // llvm::errs() << "M: " << *M << "\n"; - TBAAStructTagNode Tag(M); - // llvm::errs() << "AT: " << *Tag.getAccessType() << "\n"; - TBAAStructTypeNode AccessType(Tag.getAccessType()); - - // llvm::errs() << "numfields: " << AccessType.getNumFields() << "\n"; - while (AccessType.getNumFields() > 0) { - - if (auto *Id = llvm::dyn_cast(AccessType.getId())) { - // llvm::errs() << "cur access type: " << Id->getString() << "\n"; - if (legalnames.count(Id->getString().str())) { - return Id->getString().str(); - } - } - - AccessType = AccessType.getFieldType(0); - // llvm::errs() << "numfields: " << AccessType.getNumFields() << "\n"; - } - - if (auto *Id = llvm::dyn_cast(AccessType.getId())) { - // llvm::errs() << "access type: " << Id->getString() << "\n"; - return Id->getString().str(); - } - return ""; -} - -static inline std::string -getAccessNameTBAA(llvm::Instruction *Inst, - const std::set &legalnames) { - if (const llvm::MDNode *M = - Inst->getMetadata(llvm::LLVMContext::MD_tbaa_struct)) { - for (unsigned i = 2; i < M->getNumOperands(); i += 3) { - if (const llvm::MDNode *M2 = - llvm::dyn_cast(M->getOperand(i))) { - auto res = getAccessNameTBAA(M2, legalnames); - if (res != "") - return res; - } - } - } - if (const llvm::MDNode *M = Inst->getMetadata(llvm::LLVMContext::MD_tbaa)) { - return getAccessNameTBAA(M, legalnames); - } - return ""; -} - -//! 
The following is not taken from LLVM - -extern "C" { -/// Flag to print llvm::Type Analysis results as they are derived -extern llvm::cl::opt EnzymePrintType; -} - -/// Derive the ConcreteType corresponding to the string TypeName -/// The llvm::Instruction I denotes the context in which this was found -static inline ConcreteType -getTypeFromTBAAString(std::string TypeName, llvm::Instruction &I, - std::shared_ptr MST) { - if (TypeName == "long long" || TypeName == "long" || TypeName == "int" || - TypeName == "bool" || TypeName == "jtbaa_arraysize" || - TypeName == "jtbaa_arraylen") { - if (EnzymePrintType) { - llvm::errs() << "known tbaa "; - if (MST) - I.print(llvm::errs(), *MST); - else - llvm::errs() << I; - llvm::errs() << " " << TypeName << "\n"; - } - return ConcreteType(BaseType::Integer); - } else if (TypeName == "any pointer" || TypeName == "vtable pointer" || - TypeName == "jtbaa_arrayptr" || TypeName == "jtbaa_tag") { - if (EnzymePrintType) { - llvm::errs() << "known tbaa "; - if (MST) - I.print(llvm::errs(), *MST); - else - llvm::errs() << I; - llvm::errs() << " " << TypeName << "\n"; - } - return ConcreteType(BaseType::Pointer); - } else if (TypeName == "float") { - if (EnzymePrintType) { - llvm::errs() << "known tbaa "; - if (MST) - I.print(llvm::errs(), *MST); - else - llvm::errs() << I; - llvm::errs() << " " << TypeName << "\n"; - } - return llvm::Type::getFloatTy(I.getContext()); - } else if (TypeName == "double") { - if (EnzymePrintType) { - llvm::errs() << "known tbaa "; - if (MST) - I.print(llvm::errs(), *MST); - else - llvm::errs() << I; - llvm::errs() << " " << TypeName << "\n"; - } - return llvm::Type::getDoubleTy(I.getContext()); - } - return ConcreteType(BaseType::Unknown); -} - -/// Given a TBAA access node return the corresponding TypeTree -/// This includes recursively parsing the access nodes, with -/// corresponding offsets in the result -static inline TypeTree parseTBAA(TBAAStructTypeNode AccessType, - llvm::Instruction &I, - const llvm::DataLayout &DL, - std::shared_ptr MST) { - - if (auto *Id = llvm::dyn_cast(AccessType.getId())) { - auto CT = getTypeFromTBAAString(Id->getString().str(), I, MST); - if (CT.isKnown()) { - return TypeTree(CT).Only(-1, &I); - } - } - - TypeTree Result(BaseType::Pointer); - for (unsigned i = 0, size = AccessType.getNumFields(); i < size; ++i) { - auto SubAccess = AccessType.getFieldType(i); - auto Offset = AccessType.getFieldOffset(i); - auto SubResult = parseTBAA(SubAccess, I, DL, MST); - Result |= SubResult.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1, - /*addOffset*/ Offset); - } - - return Result; -} - -/// Given a TBAA metadata node return the corresponding TypeTree -/// Modified from llvm::MDNode::isTBAAVtableAccess() -static inline TypeTree parseTBAA(const llvm::MDNode *M, llvm::Instruction &I, - const llvm::DataLayout &DL, - std::shared_ptr MST) { - if (!isStructPathTBAA(M)) { - if (M->getNumOperands() < 1) - return TypeTree(); - if (const llvm::MDString *Tag1 = - llvm::dyn_cast(M->getOperand(0))) { - return TypeTree(getTypeFromTBAAString(Tag1->getString().str(), I, MST)) - .Only(0, &I); - } - return TypeTree(); - } - - // For struct-path aware TBAA, we use the access type of the tag. 
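To make the recursion concrete: for a C struct { double x; long y; }, clang's struct-path type node lists a "double" field at offset 0 and a "long" field at offset 8, so parseTBAA maps each name through getTypeFromTBAAString and shifts the resulting subtree to its byte offset, yielding roughly Float@double at offset 0 and Integer at offset 8 beneath the pointer. A hedged sketch of calling the MDNode-level overload defined here (the wrapper name is illustrative; passing a null slot tracker is fine since it is only used for printing):

// Sketch: derive a TypeTree from an instruction's !tbaa metadata.
static TypeTree typeFromTBAA(llvm::Instruction &I) {
  const auto &DL = I.getModule()->getDataLayout();
  if (auto *M = I.getMetadata(llvm::LLVMContext::MD_tbaa))
    return parseTBAA(M, I, DL, /*MST*/ nullptr);
  return TypeTree();
}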
- TBAAStructTagNode Tag(M); - TBAAStructTypeNode AccessType(Tag.getAccessType()); - return parseTBAA(AccessType, I, DL, MST); -} - -/// Given an llvm::Instruction, return a TypeTree representing any -/// types that can be derived from TBAA metadata attached -static inline TypeTree parseTBAA(llvm::Instruction &I, - const llvm::DataLayout &DL, - std::shared_ptr MST) { - TypeTree Result; - if (const llvm::MDNode *M = - I.getMetadata(llvm::LLVMContext::MD_tbaa_struct)) { - for (unsigned i = 0, size = M->getNumOperands(); i < size; i += 3) { - if (const llvm::MDNode *M2 = - llvm::dyn_cast(M->getOperand(i + 2))) { - auto SubResult = parseTBAA(M2, I, DL, MST); - auto Start = llvm::cast( - llvm::cast(M->getOperand(i)) - ->getValue()) - ->getLimitedValue(); - auto Len = - llvm::cast( - llvm::cast(M->getOperand(i + 1)) - ->getValue()) - ->getLimitedValue(); - Result |= - SubResult.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ Len, - /*add offset*/ Start); - } - } - } - if (const llvm::MDNode *M = I.getMetadata(llvm::LLVMContext::MD_tbaa)) { - Result |= parseTBAA(M, I, DL, MST); - } - Result |= TypeTree(BaseType::Pointer); - return Result; -} - -#endif diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp deleted file mode 100644 index 85050315a87d..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ /dev/null @@ -1,6467 +0,0 @@ -//===- TypeAnalysis.cpp - Implementation of Type Analysis ------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation of Type Analysis, a utility for -// computing the underlying data type of LLVM values. 
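For readers less familiar with the pass, a small illustration of what it computes (a sketch, not taken from the patch): given the function below, type analysis concludes that the argument is a pointer whose first 8 bytes hold a double and whose next bytes hold an integer, and that the return value is a double. Downstream differentiation uses exactly this information to decide which bytes carry active floating-point data.

// Illustrative input and (roughly) the deduced type trees:
struct Pair { double x; int n; };
double getX(const Pair *p) { return p->x; }
// p      : {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Integer}   (approximate print format)
// return : Float@double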
-// -//===----------------------------------------------------------------------===// -#include -#include - -#include - -#include "llvm/Demangle/Demangle.h" -#include "llvm/Demangle/ItaniumDemangle.h" - -#include "llvm/IR/Constants.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/ModuleSlotTracker.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/IR/InstIterator.h" - -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/TimeProfiler.h" -#include "llvm/Support/raw_ostream.h" - -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" - -#include "llvm/IR/InlineAsm.h" - -#include "../Utils.h" -#include "TypeAnalysis.h" - -#include "../FunctionUtils.h" -#include "../LibraryFuncs.h" - -#include "RustDebugInfo.h" -#include "TBAA.h" - -#include - -#if LLVM_VERSION_MAJOR >= 14 -#define getAttribute getAttributeAtIndex -#define hasAttribute hasAttributeAtIndex -#define addAttribute addAttributeAtIndex -#endif - -using namespace llvm; - -extern "C" { -/// Maximum offset for type trees to keep -llvm::cl::opt MaxIntOffset("enzyme-max-int-offset", cl::init(100), - cl::Hidden, - cl::desc("Maximum type tree offset")); - -llvm::cl::opt EnzymeMaxTypeDepth("enzyme-max-type-depth", cl::init(6), - cl::Hidden, - cl::desc("Maximum type tree depth")); - -llvm::cl::opt EnzymePrintType("enzyme-print-type", cl::init(false), - cl::Hidden, - cl::desc("Print type analysis algorithm")); - -llvm::cl::opt RustTypeRules("enzyme-rust-type", cl::init(false), - cl::Hidden, - cl::desc("Enable rust-specific type rules")); - -llvm::cl::opt EnzymeStrictAliasing( - "enzyme-strict-aliasing", cl::init(true), cl::Hidden, - cl::desc("Assume strict aliasing of types / type stability")); -} - -const llvm::StringMap LIBM_FUNCTIONS = { - {"sinc", Intrinsic::not_intrinsic}, - {"sincn", Intrinsic::not_intrinsic}, - {"cos", Intrinsic::cos}, - {"sin", Intrinsic::sin}, - {"tan", Intrinsic::not_intrinsic}, - {"acos", Intrinsic::not_intrinsic}, - {"__nv_frcp_rd", Intrinsic::not_intrinsic}, - {"__nv_frcp_rn", Intrinsic::not_intrinsic}, - {"__nv_frcp_ru", Intrinsic::not_intrinsic}, - {"__nv_frcp_rz", Intrinsic::not_intrinsic}, - {"__nv_drcp_rd", Intrinsic::not_intrinsic}, - {"__nv_drcp_rn", Intrinsic::not_intrinsic}, - {"__nv_drcp_ru", Intrinsic::not_intrinsic}, - {"__nv_drcp_rz", Intrinsic::not_intrinsic}, - {"asin", Intrinsic::not_intrinsic}, - {"__nv_asin", Intrinsic::not_intrinsic}, - {"atan", Intrinsic::not_intrinsic}, - {"atan2", Intrinsic::not_intrinsic}, - {"__nv_atan2", Intrinsic::not_intrinsic}, -#if LLVM_VERSION_MAJOR >= 19 - {"cosh", Intrinsic::cosh}, - {"sinh", Intrinsic::sinh}, - {"tanh", Intrinsic::tanh}, -#else - {"cosh", Intrinsic::not_intrinsic}, - {"sinh", Intrinsic::not_intrinsic}, - {"tanh", Intrinsic::not_intrinsic}, -#endif - {"acosh", Intrinsic::not_intrinsic}, - {"asinh", Intrinsic::not_intrinsic}, - {"atanh", Intrinsic::not_intrinsic}, - {"exp", Intrinsic::exp}, - {"exp2", Intrinsic::exp2}, - {"exp10", Intrinsic::not_intrinsic}, - {"log", Intrinsic::log}, - {"log10", Intrinsic::log10}, - {"expm1", Intrinsic::not_intrinsic}, - {"log1p", Intrinsic::not_intrinsic}, - {"log2", Intrinsic::log2}, - {"logb", Intrinsic::not_intrinsic}, - {"pow", Intrinsic::pow}, - {"sqrt", Intrinsic::sqrt}, - {"cbrt", Intrinsic::not_intrinsic}, - {"hypot", Intrinsic::not_intrinsic}, - - {"__mulsc3", Intrinsic::not_intrinsic}, - {"__muldc3", 
Intrinsic::not_intrinsic}, - {"__multc3", Intrinsic::not_intrinsic}, - {"__mulxc3", Intrinsic::not_intrinsic}, - - {"__divsc3", Intrinsic::not_intrinsic}, - {"__divdc3", Intrinsic::not_intrinsic}, - {"__divtc3", Intrinsic::not_intrinsic}, - {"__divxc3", Intrinsic::not_intrinsic}, - - {"Faddeeva_erf", Intrinsic::not_intrinsic}, - {"Faddeeva_erfc", Intrinsic::not_intrinsic}, - {"Faddeeva_erfcx", Intrinsic::not_intrinsic}, - {"Faddeeva_erfi", Intrinsic::not_intrinsic}, - {"Faddeeva_dawson", Intrinsic::not_intrinsic}, - {"Faddeeva_erf_re", Intrinsic::not_intrinsic}, - {"Faddeeva_erfc_re", Intrinsic::not_intrinsic}, - {"Faddeeva_erfcx_re", Intrinsic::not_intrinsic}, - {"Faddeeva_erfi_re", Intrinsic::not_intrinsic}, - {"Faddeeva_dawson_re", Intrinsic::not_intrinsic}, - {"erf", Intrinsic::not_intrinsic}, - {"erfi", Intrinsic::not_intrinsic}, - {"erfc", Intrinsic::not_intrinsic}, - {"erfinv", Intrinsic::not_intrinsic}, - - {"__fd_sincos_1", Intrinsic::not_intrinsic}, - {"sincospi", Intrinsic::not_intrinsic}, - {"cmplx_inv", Intrinsic::not_intrinsic}, - - // bessel functions - {"j0", Intrinsic::not_intrinsic}, - {"j1", Intrinsic::not_intrinsic}, - {"jn", Intrinsic::not_intrinsic}, - {"y0", Intrinsic::not_intrinsic}, - {"y1", Intrinsic::not_intrinsic}, - {"yn", Intrinsic::not_intrinsic}, - {"tgamma", Intrinsic::not_intrinsic}, - {"lgamma", Intrinsic::not_intrinsic}, - {"logabsgamma", Intrinsic::not_intrinsic}, - {"ceil", Intrinsic::ceil}, - {"__nv_ceil", Intrinsic::ceil}, - {"floor", Intrinsic::floor}, - {"fmod", Intrinsic::not_intrinsic}, - {"trunc", Intrinsic::trunc}, - {"round", Intrinsic::round}, - {"rint", Intrinsic::rint}, - {"nearbyint", Intrinsic::nearbyint}, - {"remainder", Intrinsic::not_intrinsic}, - {"copysign", Intrinsic::copysign}, - {"nextafter", Intrinsic::not_intrinsic}, - {"nexttoward", Intrinsic::not_intrinsic}, - {"fdim", Intrinsic::not_intrinsic}, - {"fmax", Intrinsic::maxnum}, - {"fmin", Intrinsic::minnum}, - {"fabs", Intrinsic::fabs}, - {"fma", Intrinsic::fma}, - {"ilogb", Intrinsic::not_intrinsic}, - {"scalbn", Intrinsic::not_intrinsic}, - {"scalbln", Intrinsic::not_intrinsic}, - {"powi", Intrinsic::powi}, - {"cabs", Intrinsic::not_intrinsic}, - {"ldexp", Intrinsic::not_intrinsic}, - {"fmod", Intrinsic::not_intrinsic}, - {"finite", Intrinsic::not_intrinsic}, - {"isinf", Intrinsic::not_intrinsic}, - {"isnan", Intrinsic::not_intrinsic}, - {"lround", Intrinsic::lround}, - {"llround", Intrinsic::llround}, - {"lrint", Intrinsic::lrint}, - {"llrint", Intrinsic::llrint}}; - -static bool isItaniumEncoding(StringRef S) { - // Itanium encoding requires 1 or 3 leading underscores, followed by 'Z'. - return startsWith(S, "_Z") || startsWith(S, "___Z"); -} - -bool dontAnalyze(StringRef str) { - if (isItaniumEncoding(str)) { - if (str.empty()) - return false; - - ItaniumPartialDemangler Parser; - char *data = (char *)malloc(str.size() + 1); - memcpy(data, str.data(), str.size()); - data[str.size()] = 0; - bool hasError = Parser.partialDemangle(data); - if (hasError) { - free(data); - return false; - } - - // auto basename = Parser.getFunctionBaseName(0, 0); - // auto base = Parser.getFunctionDeclContextName(0, 0); - // auto fn = Parser.getFunctionName(0, 0); - // llvm::errs() << " err: " << base << " - " << basename << " fn - " << fn - // << "\n"; - free(data); - } - return false; -} - -TypeAnalyzer::TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA, - uint8_t direction) - : MST(EnzymePrintType ? 
new ModuleSlotTracker(fn.Function->getParent()) - : nullptr), - notForAnalysis(getGuaranteedUnreachable(fn.Function)), intseen(), - fntypeinfo(fn), interprocedural(TA), direction(direction), Invalid(false), - PHIRecur(false), - TLI(TA.FAM.getResult(*fn.Function)), - DT(TA.FAM.getResult(*fn.Function)), - PDT(TA.FAM.getResult(*fn.Function)), - LI(TA.FAM.getResult(*fn.Function)), - SE(TA.FAM.getResult(*fn.Function)) { - - assert(fntypeinfo.KnownValues.size() == - fntypeinfo.Function->getFunctionType()->getNumParams()); - - // Add all instructions in the function - for (BasicBlock &BB : *fntypeinfo.Function) { - if (notForAnalysis.count(&BB)) - continue; - for (Instruction &I : BB) { - workList.insert(&I); - } - } - // Add all operands referenced in the function - // This is done to investigate any referenced globals/etc - for (BasicBlock &BB : *fntypeinfo.Function) { - for (Instruction &I : BB) { - for (auto &Op : I.operands()) { - addToWorkList(Op); - } - } - } -} - -TypeAnalyzer::TypeAnalyzer( - const FnTypeInfo &fn, TypeAnalysis &TA, - const llvm::SmallPtrSetImpl ¬ForAnalysis, - const TypeAnalyzer &Prev, uint8_t direction, bool PHIRecur) - : MST(Prev.MST), - notForAnalysis(notForAnalysis.begin(), notForAnalysis.end()), intseen(), - fntypeinfo(fn), interprocedural(TA), direction(direction), Invalid(false), - PHIRecur(PHIRecur), TLI(Prev.TLI), DT(Prev.DT), PDT(Prev.PDT), - LI(Prev.LI), SE(Prev.SE) { - assert(fntypeinfo.KnownValues.size() == - fntypeinfo.Function->getFunctionType()->getNumParams()); -} - -static SmallPtrSet -findLoopIndices(llvm::Value *val, LoopInfo &LI, DominatorTree &DT) { - if (isa(val)) - return {}; - if (auto CI = dyn_cast(val)) - return findLoopIndices(CI->getOperand(0), LI, DT); - if (auto CI = dyn_cast(val)) - return findLoopIndices(CI->getOperand(0), LI, DT); - if (auto bo = dyn_cast(val)) { - auto inset0 = findLoopIndices(bo->getOperand(0), LI, DT); - auto inset1 = findLoopIndices(bo->getOperand(1), LI, DT); - inset0.insert(inset1.begin(), inset1.end()); - return inset0; - } - if (auto LDI = dyn_cast(val)) { - if (auto AI = dyn_cast(LDI->getPointerOperand())) { - StoreInst *SI = nullptr; - bool failed = false; - for (auto u : AI->users()) { - if (auto SIu = dyn_cast(u)) { - if (SI && SIu->getValueOperand() == AI) { - failed = true; - break; - } - SI = SIu; - } else if (!isa(u)) { - if (!cast(u)->mayReadOrWriteMemory() && - cast(u)->use_empty()) - continue; - if (auto CI = dyn_cast(u)) { - if (auto F = CI->getCalledFunction()) { - auto funcName = F->getName(); - if (funcName == "__kmpc_for_static_init_4" || - funcName == "__kmpc_for_static_init_4u" || - funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") { - continue; - } - } - } - failed = true; - break; - } - } - if (SI && !failed && DT.dominates(SI, LDI)) { - return findLoopIndices(SI->getValueOperand(), LI, DT); - } - } - } - if (auto pn = dyn_cast(val)) { - auto L = LI.getLoopFor(pn->getParent()); - if (L && L->getHeader() == pn->getParent()) - return {pn->getParent()}; - SmallPtrSet ops; - for (unsigned i = 0; i < pn->getNumIncomingValues(); ++i) { - auto a = pn->getIncomingValue(i); - auto seti = findLoopIndices(a, LI, DT); - ops.insert(seti.begin(), seti.end()); - } - return ops; - } - return {}; -} - -std::set -FnTypeInfo::knownIntegralValues(llvm::Value *val, const DominatorTree &DT, - std::map> &intseen, - ScalarEvolution &SE) const { - if (auto constant = dyn_cast(val)) { -#if LLVM_VERSION_MAJOR > 14 - if (constant->getValue().getSignificantBits() > 64) - return {}; 
-#else - if (constant->getValue().getMinSignedBits() > 64) - return {}; -#endif - return {constant->getSExtValue()}; - } - - if (isa(val)) { - return {0}; - } - - assert(KnownValues.size() == Function->getFunctionType()->getNumParams()); - - if (auto arg = dyn_cast(val)) { - auto found = KnownValues.find(arg); - if (found == KnownValues.end()) { - for (const auto &pair : KnownValues) { - llvm::errs() << " KnownValues[" << *pair.first << "] - " - << pair.first->getParent()->getName() << "\n"; - } - llvm::errs() << " arg: " << *arg << " - " << arg->getParent()->getName() - << "\n"; - } - assert(found != KnownValues.end()); - return found->second; - } - - if (intseen.find(val) != intseen.end()) - return intseen[val]; - intseen[val] = {}; - - if (auto ci = dyn_cast(val)) { - intseen[val] = knownIntegralValues(ci->getOperand(0), DT, intseen, SE); - } - - auto insert = [&](int64_t v) { - if (intseen[val].size() == 0) { - intseen[val].insert(v); - } else { - if (intseen[val].size() == 1) { - if (abs(*intseen[val].begin()) > MaxIntOffset) { - if (abs(*intseen[val].begin()) > abs(v)) { - intseen[val].clear(); - intseen[val].insert(v); - } else { - return; - } - } else { - if (abs(v) > MaxIntOffset) { - return; - } else { - intseen[val].insert(v); - } - } - } else { - if (abs(v) > MaxIntOffset) { - return; - } else { - intseen[val].insert(v); - } - } - } - }; - if (auto II = dyn_cast(val)) { - switch (II->getIntrinsicID()) { -#if LLVM_VERSION_MAJOR >= 12 - case Intrinsic::abs: - for (auto val : - knownIntegralValues(II->getArgOperand(0), DT, intseen, SE)) - insert(abs(val)); - break; -#endif - case Intrinsic::nvvm_read_ptx_sreg_tid_x: - case Intrinsic::nvvm_read_ptx_sreg_tid_y: - case Intrinsic::nvvm_read_ptx_sreg_tid_z: - case Intrinsic::nvvm_read_ptx_sreg_ctaid_x: - case Intrinsic::nvvm_read_ptx_sreg_ctaid_y: - case Intrinsic::nvvm_read_ptx_sreg_ctaid_z: - case Intrinsic::amdgcn_workitem_id_x: - case Intrinsic::amdgcn_workitem_id_y: - case Intrinsic::amdgcn_workitem_id_z: - insert(0); - break; - default: - break; - } - } - if (auto LI = dyn_cast(val)) { - if (auto AI = dyn_cast(LI->getPointerOperand())) { - StoreInst *SI = nullptr; - bool failed = false; - for (auto u : AI->users()) { - if (auto SIu = dyn_cast(u)) { - if (SI && SIu->getValueOperand() == AI) { - failed = true; - break; - } - SI = SIu; - } else if (!isa(u)) { - if (!cast(u)->mayReadOrWriteMemory() && - cast(u)->use_empty()) - continue; - if (auto CI = dyn_cast(u)) { - if (auto F = CI->getCalledFunction()) { - auto funcName = F->getName(); - if (funcName == "__kmpc_for_static_init_4" || - funcName == "__kmpc_for_static_init_4u" || - funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") { - continue; - } - } - } - failed = true; - break; - } - } - if (SI && !failed && DT.dominates(SI, LI)) { - for (auto val : - knownIntegralValues(SI->getValueOperand(), DT, intseen, SE)) { - insert(val); - } - } - } - } - if (auto pn = dyn_cast(val)) { - if (SE.isSCEVable(pn->getType())) - if (auto S = dyn_cast(SE.getSCEV(pn))) { - if (auto StartC = dyn_cast(S->getStart())) { - auto L = S->getLoop(); - auto BE = SE.getBackedgeTakenCount(L); - if (BE != SE.getCouldNotCompute()) { - if (auto Iters = dyn_cast(BE)) { - uint64_t ival = Iters->getAPInt().getZExtValue(); - // If strict aliasing and the loop header does not dominate all - // blocks at low optimization levels the last "iteration" will - // actually exit leading to one extra backedge that would be wise - // to ignore. 
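The insert() lambda above keeps the candidate set small: once a value's magnitude exceeds MaxIntOffset (default 100, per the flag above) it is retained only while it is the set's sole element, and a smaller candidate evicts a lone oversized one. A standalone sketch of that rule, with illustrative names (not Enzyme API):

#include <cstdint>
#include <set>

static void insertClamped(std::set<int64_t> &Vals, int64_t V,
                          int64_t MaxOff = 100) {
  auto mag = [](int64_t x) { return x < 0 ? -x : x; };
  if (Vals.empty()) {
    Vals.insert(V);
  } else if (Vals.size() == 1 && mag(*Vals.begin()) > MaxOff) {
    // A single oversized value may be replaced by a smaller candidate.
    if (mag(*Vals.begin()) > mag(V)) {
      Vals.clear();
      Vals.insert(V);
    }
  } else if (mag(V) <= MaxOff) {
    Vals.insert(V);
  }
}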
- if (EnzymeStrictAliasing) { - bool rotated = false; - BasicBlock *Latch = L->getLoopLatch(); - rotated = Latch && L->isLoopExiting(Latch); - if (!rotated) { - if (ival > 0) - ival--; - } - } - - uint64_t istart = 0; - - if (S->isAffine()) { - if (auto StepC = dyn_cast(S->getOperand(1))) { - APInt StartI = StartC->getAPInt(); - APInt A = StepC->getAPInt(); - - if (A.sle(-1)) { - A = -A; - StartI = -StartI; - } - - if (A.sge(1)) { - if (StartI.sge(MaxIntOffset)) { - ival = std::min(ival, (uint64_t)0); - } else { - ival = std::min( - ival, - (MaxIntOffset - StartI + A).udiv(A).getZExtValue()); - } - - if (StartI.slt(-MaxIntOffset)) { - istart = std::max( - istart, - (-MaxIntOffset - StartI).udiv(A).getZExtValue()); - } - - } else { - ival = std::min(ival, (uint64_t)0); - } - } else { - ival = std::min(ival, (uint64_t)0); - } - } - - for (uint64_t i = istart; i <= ival; i++) { - if (auto Val = dyn_cast(S->evaluateAtIteration( - SE.getConstant(Iters->getType(), i, /*signed*/ false), - SE))) { - insert(Val->getAPInt().getSExtValue()); - } - } - return intseen[val]; - } - } - } - } - - for (unsigned i = 0; i < pn->getNumIncomingValues(); ++i) { - auto a = pn->getIncomingValue(i); - auto b = pn->getIncomingBlock(i); - - // do not consider loop incoming edges - if (pn->getParent() == b || DT.dominates(pn, b)) { - continue; - } - - auto inset = knownIntegralValues(a, DT, intseen, SE); - - // TODO this here is not fully justified yet - for (auto pval : inset) { - if (pval < 20 && pval > -20) { - insert(pval); - } - } - - // if we are an iteration variable, suppose that it could be zero in that - // range - // TODO: could actually check the range intercepts 0 - if (auto bo = dyn_cast(a)) { - if (bo->getOperand(0) == pn || bo->getOperand(1) == pn) { - if (bo->getOpcode() == BinaryOperator::Add || - bo->getOpcode() == BinaryOperator::Sub) { - insert(0); - } - } - } - } - return intseen[val]; - } - - if (auto bo = dyn_cast(val)) { - auto inset0 = knownIntegralValues(bo->getOperand(0), DT, intseen, SE); - auto inset1 = knownIntegralValues(bo->getOperand(1), DT, intseen, SE); - if (bo->getOpcode() == BinaryOperator::Mul) { - - if (inset0.size() == 1 || inset1.size() == 1) { - for (auto val0 : inset0) { - for (auto val1 : inset1) { - - insert(val0 * val1); - } - } - } - if (inset0.count(0) || inset1.count(0)) { - intseen[val].insert(0); - } - } - - if (bo->getOpcode() == BinaryOperator::Add) { - if (inset0.size() == 1 || inset1.size() == 1) { - for (auto val0 : inset0) { - for (auto val1 : inset1) { - insert(val0 + val1); - } - } - } - } - if (bo->getOpcode() == BinaryOperator::Sub) { - if (inset0.size() == 1 || inset1.size() == 1) { - for (auto val0 : inset0) { - for (auto val1 : inset1) { - insert(val0 - val1); - } - } - } - } - - if (bo->getOpcode() == BinaryOperator::SDiv) { - if (inset0.size() == 1 || inset1.size() == 1) { - for (auto val0 : inset0) { - for (auto val1 : inset1) { - insert(val0 / val1); - } - } - } - } - - if (bo->getOpcode() == BinaryOperator::Shl) { - if (inset0.size() == 1 || inset1.size() == 1) { - for (auto val0 : inset0) { - for (auto val1 : inset1) { - insert(val0 << val1); - } - } - } - } - - // TODO note C++ doesnt guarantee behavior of >> being arithmetic or logical - // and should replace with llvm apint internal - if (bo->getOpcode() == BinaryOperator::AShr || - bo->getOpcode() == BinaryOperator::LShr) { - if (inset0.size() == 1 || inset1.size() == 1) { - for (auto val0 : inset0) { - for (auto val1 : inset1) { - insert(val0 >> val1); - } - } - } - } - } - - return 
intseen[val]; -} - -/// Given a constant value, deduce any type information applicable -void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA, - std::map &analysis) { - auto found = analysis.find(Val); - if (found != analysis.end()) - return; - - auto &DL = TA.fntypeinfo.Function->getParent()->getDataLayout(); - - // Undefined value is an anything everywhere - if (isa(Val) || isa(Val)) { - analysis[Val].insert({-1}, BaseType::Anything); - return; - } - - // Null pointer is a pointer to anything, everywhere - if (isa(Val)) { - TypeTree &Result = analysis[Val]; - Result.insert({-1}, BaseType::Pointer); - Result.insert({-1, -1}, BaseType::Anything); - return; - } - - // Known pointers are pointers at offset 0 - if (isa(Val) || isa(Val)) { - analysis[Val].insert({-1}, BaseType::Pointer); - return; - } - - // Any constants == 0 are considered Anything - // other floats are assumed to be that type - if (auto FP = dyn_cast(Val)) { - if (FP->isExactlyValue(0.0)) { - analysis[Val].insert({-1}, BaseType::Anything); - return; - } - analysis[Val].insert({-1}, ConcreteType(FP->getType()->getScalarType())); - return; - } - - if (auto ci = dyn_cast(Val)) { - // Constants in range [1, 4096] are assumed to be integral since - // any float or pointers they may represent are ill-formed - if (!ci->isNegative() && ci->getLimitedValue() >= 1 && - ci->getLimitedValue() <= 4096) { - analysis[Val].insert({-1}, BaseType::Integer); - return; - } - - // Constants explicitly marked as negative that aren't -1 are considered - // integral - if (ci->isNegative() && !ci->isMinusOne()) { - analysis[Val].insert({-1}, BaseType::Integer); - return; - } - - // Values of size < 16 (half size) are considered integral - // since they cannot possibly represent a float or pointer - if (cast(ci->getType())->getBitWidth() < 16) { - analysis[Val].insert({-1}, BaseType::Integer); - return; - } - // All other constant-ints could be any type - analysis[Val].insert({-1}, BaseType::Anything); - return; - } - - // Type of an aggregate is the aggregation of - // the subtypes - if (auto CA = dyn_cast(Val)) { - TypeTree &Result = analysis[Val]; - for (unsigned i = 0, size = CA->getNumOperands(); i < size; ++i) { - assert(TA.fntypeinfo.Function); - auto Op = CA->getOperand(i); - // TODO check this for i1 constant aggregates packing/etc - auto ObjSize = (TA.fntypeinfo.Function->getParent() - ->getDataLayout() - .getTypeSizeInBits(Op->getType()) + - 7) / - 8; - - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(Val->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(Val->getContext()), i), - }; - auto g2 = GetElementPtrInst::Create( - Val->getType(), - UndefValue::get(PointerType::getUnqual(Val->getType())), vec); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - int Off = (int)ai.getLimitedValue(); - if (auto VT = dyn_cast(Val->getType())) - if (VT->getElementType()->isIntegerTy(1)) - Off = i / 8; - - getConstantAnalysis(Op, TA, analysis); - auto mid = analysis[Op]; - if (TA.fntypeinfo.Function->getParent() - ->getDataLayout() - .getTypeSizeInBits(CA->getType()) >= 16) { - mid.ReplaceIntWithAnything(); - } - - Result |= mid.ShiftIndices(DL, /*init offset*/ 0, - /*maxSize*/ ObjSize, - /*addOffset*/ Off); - } - Result.CanonicalizeInPlace( - (TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits( - CA->getType()) + - 7) / - 8, - DL); - return; - } - - // Type of an 
sequence is the aggregation of - // the subtypes - if (auto CD = dyn_cast(Val)) { - TypeTree &Result = analysis[Val]; - for (unsigned i = 0, size = CD->getNumElements(); i < size; ++i) { - assert(TA.fntypeinfo.Function); - auto Op = CD->getElementAsConstant(i); - // TODO check this for i1 constant aggregates packing/etc - auto ObjSize = (TA.fntypeinfo.Function->getParent() - ->getDataLayout() - .getTypeSizeInBits(Op->getType()) + - 7) / - 8; - - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(Val->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(Val->getContext()), i), - }; - auto g2 = GetElementPtrInst::Create( - Val->getType(), - UndefValue::get(PointerType::getUnqual(Val->getType())), vec); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - int Off = (int)ai.getLimitedValue(); - - getConstantAnalysis(Op, TA, analysis); - auto mid = analysis[Op]; - if (TA.fntypeinfo.Function->getParent() - ->getDataLayout() - .getTypeSizeInBits(CD->getType()) >= 16) { - mid.ReplaceIntWithAnything(); - } - Result |= mid.ShiftIndices(DL, /*init offset*/ 0, - /*maxSize*/ ObjSize, - /*addOffset*/ Off); - - Result |= mid; - } - Result.CanonicalizeInPlace( - (TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits( - CD->getType()) + - 7) / - 8, - DL); - return; - } - - // ConstantExprs are handled by considering the - // equivalent instruction - if (auto CE = dyn_cast(Val)) { - if (CE->isCast()) { - if (CE->getType()->isPointerTy() && isa(CE->getOperand(0))) { - analysis[Val] = TypeTree(BaseType::Anything).Only(-1, nullptr); - return; - } - getConstantAnalysis(CE->getOperand(0), TA, analysis); - analysis[Val] = analysis[CE->getOperand(0)]; - return; - } - if (CE->getOpcode() == Instruction::GetElementPtr) { - TA.visitGEPOperator(*cast(CE)); - return; - } - - auto I = CE->getAsInstruction(); - I->insertBefore(TA.fntypeinfo.Function->getEntryBlock().getTerminator()); - - // Just analyze this new "instruction" and none of the others - { - TypeAnalyzer tmpAnalysis(TA.fntypeinfo, TA.interprocedural, - TA.notForAnalysis, TA); - tmpAnalysis.visit(*I); - analysis[Val] = tmpAnalysis.getAnalysis(I); - - if (tmpAnalysis.workList.remove(I)) { - TA.workList.insert(CE); - } - } - - I->eraseFromParent(); - return; - } - - if (auto GV = dyn_cast(Val)) { - - if (GV->getName() == "__cxa_thread_atexit_impl") { - analysis[Val] = TypeTree(BaseType::Pointer).Only(-1, nullptr); - return; - } - - // from julia code - if (GV->getName() == "small_typeof" || GV->getName() == "jl_small_typeof") { - TypeTree T; - T.insert({-1}, BaseType::Pointer); - T.insert({-1, -1}, BaseType::Pointer); - analysis[Val] = T; - return; - } - - TypeTree &Result = analysis[Val]; - Result.insert({-1}, ConcreteType(BaseType::Pointer)); - - // A fixed constant global is a pointer to its initializer - if (GV->isConstant() && GV->hasInitializer()) { - getConstantAnalysis(GV->getInitializer(), TA, analysis); - Result |= analysis[GV->getInitializer()].Only(-1, nullptr); - return; - } - if (!isa(GV->getValueType()) || - !cast(GV->getValueType())->isOpaque()) { - auto globalSize = (DL.getTypeSizeInBits(GV->getValueType()) + 7) / 8; - // Since halfs are 16bit (2 byte) and pointers are >=32bit (4 byte) any - // Single byte object must be integral - if (globalSize == 1) { - Result.insert({-1, -1}, ConcreteType(BaseType::Integer)); - return; - } - } - - // Otherwise, we simply know that this is a 
pointer, and - // not what it is a pointer to - return; - } - - // No other information can be ascertained - analysis[Val] = TypeTree(); - return; -} - -TypeTree TypeAnalyzer::getAnalysis(Value *Val) { - // Integers with fewer than 16 bits (size of half) - // must be integral, since it cannot possibly represent a float or pointer - if (!isa(Val) && Val->getType()->isIntegerTy() && - cast(Val->getType())->getBitWidth() < 16) - return TypeTree(BaseType::Integer).Only(-1, nullptr); - if (auto C = dyn_cast(Val)) { - getConstantAnalysis(C, *this, analysis); - return analysis[Val]; - } - - // Check that this value is from the function being analyzed - if (auto I = dyn_cast(Val)) { - if (I->getParent()->getParent() != fntypeinfo.Function) { - llvm::errs() << " function: " << *fntypeinfo.Function << "\n"; - llvm::errs() << " instParent: " << *I->getParent()->getParent() << "\n"; - llvm::errs() << " inst: " << *I << "\n"; - } - assert(I->getParent()->getParent() == fntypeinfo.Function); - } - if (auto Arg = dyn_cast(Val)) { - if (Arg->getParent() != fntypeinfo.Function) { - llvm::errs() << " function: " << *fntypeinfo.Function << "\n"; - llvm::errs() << " argParent: " << *Arg->getParent() << "\n"; - llvm::errs() << " arg: " << *Arg << "\n"; - } - assert(Arg->getParent() == fntypeinfo.Function); - } - - // Return current results - if (isa(Val) || isa(Val)) - return analysis[Val]; - - // Unhandled/unknown Value - llvm::errs() << "Error Unknown Value: " << *Val << "\n"; - assert(0 && "Error Unknown Value: "); - llvm_unreachable("Error Unknown Value: "); - // return TypeTree(); -} - -void TypeAnalyzer::updateAnalysis(Value *Val, ConcreteType Data, - Value *Origin) { - updateAnalysis(Val, TypeTree(Data), Origin); -} - -void TypeAnalyzer::updateAnalysis(Value *Val, BaseType Data, Value *Origin) { - updateAnalysis(Val, TypeTree(ConcreteType(Data)), Origin); -} - -void TypeAnalyzer::addToWorkList(Value *Val) { - // Only consider instructions/arguments - if (!isa(Val) && !isa(Val) && - !isa(Val) && !isa(Val)) - return; - - // Verify this value comes from the function being analyzed - if (auto I = dyn_cast(Val)) { - if (fntypeinfo.Function != I->getParent()->getParent()) - return; - if (notForAnalysis.count(I->getParent())) - return; - if (fntypeinfo.Function != I->getParent()->getParent()) { - llvm::errs() << "function: " << *fntypeinfo.Function << "\n"; - llvm::errs() << "instf: " << *I->getParent()->getParent() << "\n"; - llvm::errs() << "inst: " << *I << "\n"; - } - assert(fntypeinfo.Function == I->getParent()->getParent()); - } else if (auto Arg = dyn_cast(Val)) { - if (fntypeinfo.Function != Arg->getParent()) { - llvm::errs() << "fn: " << *fntypeinfo.Function << "\n"; - llvm::errs() << "argparen: " << *Arg->getParent() << "\n"; - llvm::errs() << "val: " << *Arg << "\n"; - } - assert(fntypeinfo.Function == Arg->getParent()); - } - - // Add to workList - workList.insert(Val); -} - -void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) { - if (Val->getType()->isVoidTy()) - return; - // ConstantData's and Functions don't have analysis updated - // We don't do "Constant" as globals are "Constant" types - if (isa(Val) || isa(Val)) { - return; - } - - if (auto GV = dyn_cast(Val)) { - if (hasMetadata(GV, "enzyme_ta_norecur")) - return; - } - - if (auto CE = dyn_cast(Val)) { - if (CE->isCast() && isa(CE->getOperand(0))) { - return; - } - if (CE->getOpcode() == Instruction::GetElementPtr && - isa(CE->getOperand(0))) - return; - } - - if (auto I = dyn_cast(Val)) { - if (fntypeinfo.Function 
!= I->getParent()->getParent()) { - llvm::errs() << "function: " << *fntypeinfo.Function << "\n"; - llvm::errs() << "instf: " << *I->getParent()->getParent() << "\n"; - llvm::errs() << "inst: " << *I << "\n"; - } - assert(fntypeinfo.Function == I->getParent()->getParent()); - assert(Origin); - if (!EnzymeStrictAliasing) { - if (auto OI = dyn_cast(Origin)) { - if (OI->getParent() != I->getParent() && - !PDT.dominates(OI->getParent(), I->getParent())) { - bool allocationWithAllUsersInBlock = false; - if (auto AI = dyn_cast(I)) { - allocationWithAllUsersInBlock = true; - for (auto U : AI->users()) { - auto P = cast(U)->getParent(); - if (P == OI->getParent()) - continue; - if (PDT.dominates(OI->getParent(), P)) - continue; - allocationWithAllUsersInBlock = false; - break; - } - } - if (!allocationWithAllUsersInBlock) { - if (EnzymePrintType) { - llvm::errs() << " skipping update into "; - I->print(llvm::errs(), *MST); - llvm::errs() << " of " << Data.str() << " from "; - OI->print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - return; - } - } - } - } - } else if (auto Arg = dyn_cast(Val)) { - assert(fntypeinfo.Function == Arg->getParent()); - if (!EnzymeStrictAliasing) - if (auto OI = dyn_cast(Origin)) { - auto I = &*fntypeinfo.Function->getEntryBlock().begin(); - if (OI->getParent() != I->getParent() && - !PDT.dominates(OI->getParent(), I->getParent())) { - if (EnzymePrintType) { - llvm::errs() << " skipping update into "; - Arg->print(llvm::errs(), *MST); - llvm::errs() << " of " << Data.str() << " from "; - OI->print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - return; - } - } - } - - // Attempt to update the underlying analysis - bool LegalOr = true; - if (analysis.find(Val) == analysis.end() && isa(Val)) { - if (!isa(Val) || - cast(Val)->getOpcode() != Instruction::GetElementPtr) - getConstantAnalysis(cast(Val), *this, analysis); - } - - TypeTree prev = analysis[Val]; - - auto &DL = fntypeinfo.Function->getParent()->getDataLayout(); - auto RegSize = (DL.getTypeSizeInBits(Val->getType()) + 7) / 8; - Data.CanonicalizeInPlace(RegSize, DL); - bool Changed = - analysis[Val].checkedOrIn(Data, /*PointerIntSame*/ false, LegalOr); - - // Print the update being made, if requested - if (EnzymePrintType) { - llvm::errs() << "updating analysis of val: "; - Val->print(llvm::errs(), *MST); - llvm::errs() << " current: " << prev.str() << " new " << Data.str(); - if (Origin) { - llvm::errs() << " from "; - Origin->print(llvm::errs(), *MST); - } - llvm::errs() << " Changed=" << Changed << " legal=" << LegalOr << "\n"; - } - - if (!LegalOr) { - if (direction != BOTH) { - Invalid = true; - return; - } - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateAnalysis prev:" << prev.str() << " new: " << Data.str() - << "\n"; - ss << "val: " << *Val; - if (Origin) - ss << " origin=" << *Origin; - - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(Val), ErrorType::IllegalTypeAnalysis, - (void *)this, wrap(Origin), nullptr); - } - if (auto I = dyn_cast(Val)) { - EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str()); - exit(1); - } else if (auto I = dyn_cast_or_null(Origin)) { - EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str()); - exit(1); - } else { - llvm::errs() << ss.str() << "\n"; - } - report_fatal_error("Performed illegal updateAnalysis"); - } - - if (Changed) { - - if (auto GV = 
dyn_cast(Val)) { - if (GV->getValueType()->isSized()) { - auto Size = (DL.getTypeSizeInBits(GV->getValueType()) + 7) / 8; - Data = analysis[Val].Lookup(Size, DL).Only(-1, nullptr); - Data.insert({-1}, BaseType::Pointer); - analysis[Val] = Data; - Origin = Val; - } - } - // Add val so it can explicitly propagate this new info, if able to - if (Val != Origin) - addToWorkList(Val); - - // Add users and operands of the value so they can update from the new - // operand/use - for (User *U : Val->users()) { - if (U != Origin) { - - if (auto I = dyn_cast(U)) { - if (fntypeinfo.Function != I->getParent()->getParent()) { - continue; - } - } - - addToWorkList(U); - - // per the handling of phi's - if (auto BO = dyn_cast(U)) { - for (User *U2 : BO->users()) { - if (isa(U2) && U2 != Origin) { - addToWorkList(U2); - } - } - } - } - } - - if (User *US = dyn_cast(Val)) { - for (Value *Op : US->operands()) { - if (Op != Origin) { - addToWorkList(Op); - } - } - } - } -} - -/// Analyze type info given by the arguments, possibly adding to work queue -void TypeAnalyzer::prepareArgs() { - // Propagate input type information for arguments - for (auto &pair : fntypeinfo.Arguments) { - assert(pair.first->getParent() == fntypeinfo.Function); - updateAnalysis(pair.first, pair.second, pair.first); - } - - // Get type and other information about argument - // getAnalysis may add more information so this - // is necessary/useful - for (Argument &Arg : fntypeinfo.Function->args()) { - updateAnalysis(&Arg, getAnalysis(&Arg), &Arg); - } - - // Propagate return value type information - for (BasicBlock &BB : *fntypeinfo.Function) { - for (Instruction &I : BB) { - if (ReturnInst *RI = dyn_cast(&I)) { - if (Value *RV = RI->getReturnValue()) { - updateAnalysis(RV, fntypeinfo.Return, RV); - updateAnalysis(RV, getAnalysis(RV), RV); - } - } - } - } -} - -/// Analyze type info given by the TBAA, possibly adding to work queue -void TypeAnalyzer::considerTBAA() { - auto &DL = fntypeinfo.Function->getParent()->getDataLayout(); - - for (BasicBlock &BB : *fntypeinfo.Function) { - for (Instruction &I : BB) { - if (auto MD = I.getMetadata("enzyme_type")) { - auto TT = TypeTree::fromMD(MD); - - auto RegSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - for (const auto &pair : TT.getMapping()) { - if (pair.first[0] != -1) { - if ((size_t)pair.first[0] >= RegSize) { - llvm::errs() << " bad enzyme_type " << TT.str() - << " RegSize=" << RegSize << " I:" << I << "\n"; - llvm::report_fatal_error("Canonicalization failed"); - } - } - } - updateAnalysis(&I, TT, &I); - } - - if (CallBase *call = dyn_cast(&I)) { -#if LLVM_VERSION_MAJOR >= 14 - size_t num_args = call->arg_size(); -#else - size_t num_args = call->getNumArgOperands(); -#endif - - if (call->getAttributes().hasAttribute(AttributeList::ReturnIndex, - "enzyme_type")) { - auto attr = call->getAttributes().getAttribute( - AttributeList::ReturnIndex, "enzyme_type"); - auto TT = - TypeTree::parse(attr.getValueAsString(), call->getContext()); - - auto RegSize = I.getType()->isVoidTy() - ? 
0 - : (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - for (const auto &pair : TT.getMapping()) { - if (pair.first[0] != -1) { - if ((size_t)pair.first[0] >= RegSize) { - llvm::errs() << " bad enzyme_type " << TT.str() - << " RegSize=" << RegSize << " I:" << I << "\n"; - llvm::report_fatal_error("Canonicalization failed"); - } - } - } - updateAnalysis(call, TT, call); - } - for (size_t i = 0; i < num_args; i++) { - if (call->getAttributes().hasParamAttr(i, "enzyme_type")) { - auto attr = call->getAttributes().getParamAttr(i, "enzyme_type"); - auto TT = - TypeTree::parse(attr.getValueAsString(), call->getContext()); - auto RegSize = I.getType()->isVoidTy() - ? 0 - : (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - for (const auto &pair : TT.getMapping()) { - if (pair.first[0] != -1) { - if ((size_t)pair.first[0] >= RegSize) { - llvm::errs() << " bad enzyme_type " << TT.str() - << " RegSize=" << RegSize << " I:" << I << "\n"; - llvm::report_fatal_error("Canonicalization failed"); - } - } - } - updateAnalysis(call->getArgOperand(i), TT, call); - } - } - - Function *F = call->getCalledFunction(); - - if (F) { - if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, - "enzyme_type")) { - auto attr = F->getAttributes().getAttribute( - AttributeList::ReturnIndex, "enzyme_type"); - auto TT = - TypeTree::parse(attr.getValueAsString(), call->getContext()); - auto RegSize = I.getType()->isVoidTy() - ? 0 - : (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - for (const auto &pair : TT.getMapping()) { - if (pair.first[0] != -1) { - if ((size_t)pair.first[0] >= RegSize) { - llvm::errs() << " bad enzyme_type " << TT.str() - << " RegSize=" << RegSize << " I:" << I << "\n"; - llvm::report_fatal_error("Canonicalization failed"); - } - } - } - updateAnalysis(call, TT, call); - } - size_t f_num_args = F->arg_size(); - for (size_t i = 0; i < f_num_args; i++) { - if (F->getAttributes().hasParamAttr(i, "enzyme_type")) { - auto attr = F->getAttributes().getParamAttr(i, "enzyme_type"); - auto TT = - TypeTree::parse(attr.getValueAsString(), call->getContext()); - auto RegSize = I.getType()->isVoidTy() - ? 
0 - : (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - for (const auto &pair : TT.getMapping()) { - if (pair.first[0] != -1) { - if ((size_t)pair.first[0] >= RegSize) { - llvm::errs() - << " bad enzyme_type " << TT.str() - << " RegSize=" << RegSize << " I:" << I << "\n"; - llvm::report_fatal_error("Canonicalization failed"); - } - } - } - updateAnalysis(call->getArgOperand(i), TT, call); - } - } - } - - if (auto castinst = dyn_cast(call->getCalledOperand())) { - if (castinst->isCast()) - if (auto fn = dyn_cast(castinst->getOperand(0))) { - F = fn; - } - } - if (F && F->getName().contains("__enzyme_float")) { - assert(num_args == 1 || num_args == 2); - assert(call->getArgOperand(0)->getType()->isPointerTy()); - TypeTree TT; - ssize_t num = 1; - if (num_args == 2) { - assert(isa(call->getArgOperand(1))); - auto CI = cast(call->getArgOperand(1)); - if (CI->isNegative()) - num = -1; - else - num = CI->getLimitedValue(); - } - if (num == -1) - TT.insert({(int)num}, Type::getFloatTy(call->getContext())); - else - for (size_t i = 0; i < (size_t)num; i += 4) - TT.insert({(int)i}, Type::getFloatTy(call->getContext())); - TT.insert({}, BaseType::Pointer); - updateAnalysis(call->getOperand(0), TT.Only(-1, call), call); - } - if (F && F->getName().contains("__enzyme_double")) { - assert(num_args == 1 || num_args == 2); - assert(call->getArgOperand(0)->getType()->isPointerTy()); - TypeTree TT; - size_t num = 1; - if (num_args == 2) { - assert(isa(call->getArgOperand(1))); - num = cast(call->getArgOperand(1))->getLimitedValue(); - } - for (size_t i = 0; i < num; i += 8) - TT.insert({(int)i}, Type::getDoubleTy(call->getContext())); - TT.insert({}, BaseType::Pointer); - updateAnalysis(call->getOperand(0), TT.Only(-1, call), call); - } - if (F && F->getName().contains("__enzyme_integer")) { - assert(num_args == 1 || num_args == 2); - assert(call->getArgOperand(0)->getType()->isPointerTy()); - size_t num = 1; - if (num_args == 2) { - assert(isa(call->getArgOperand(1))); - num = cast(call->getArgOperand(1))->getLimitedValue(); - } - TypeTree TT; - for (size_t i = 0; i < num; i++) - TT.insert({(int)i}, BaseType::Integer); - TT.insert({}, BaseType::Pointer); - updateAnalysis(call->getOperand(0), TT.Only(-1, call), call); - } - if (F && F->getName().contains("__enzyme_pointer")) { - assert(num_args == 1 || num_args == 2); - assert(call->getArgOperand(0)->getType()->isPointerTy()); - TypeTree TT; - size_t num = 1; - if (num_args == 2) { - assert(isa(call->getArgOperand(1))); - num = cast(call->getArgOperand(1))->getLimitedValue(); - } - for (size_t i = 0; i < num; - i += ((DL.getPointerSizeInBits() + 7) / 8)) - TT.insert({(int)i}, BaseType::Pointer); - TT.insert({}, BaseType::Pointer); - updateAnalysis(call->getOperand(0), TT.Only(-1, call), call); - } - if (F) { - StringSet<> JuliaKnownTypes = {"julia.gc_alloc_obj", - "jl_alloc_array_1d", - "jl_alloc_array_2d", - "jl_alloc_array_3d", - "ijl_alloc_array_1d", - "ijl_alloc_array_2d", - "ijl_alloc_array_3d", - "jl_gc_alloc_typed", - "ijl_gc_alloc_typed", - "jl_alloc_genericmemory", - "ijl_alloc_genericmemory", - "jl_new_array", - "ijl_new_array"}; - if (JuliaKnownTypes.count(F->getName())) { - visitCallBase(*call); - continue; - } - } - } - - TypeTree vdptr = parseTBAA(I, DL, MST); - - // If we don't have any useful information, - // don't bother updating - if (!vdptr.isKnownPastPointer()) - continue; - - if (CallBase *call = dyn_cast(&I)) { - if (call->getCalledFunction() && - (call->getCalledFunction()->getIntrinsicID() == Intrinsic::memcpy || - 
call->getCalledFunction()->getIntrinsicID() == - Intrinsic::memmove)) { - int64_t copySize = 1; - for (auto val : fntypeinfo.knownIntegralValues(call->getOperand(2), - DT, intseen, SE)) { - copySize = max(copySize, val); - } - TypeTree update = - vdptr - .ShiftIndices(DL, /*init offset*/ 0, - /*max size*/ copySize, /*new offset*/ 0) - .Only(-1, call); - - updateAnalysis(call->getOperand(0), update, call); - updateAnalysis(call->getOperand(1), update, call); - continue; - } else if (call->getCalledFunction() && - (call->getCalledFunction()->getIntrinsicID() == - Intrinsic::memset || - call->getCalledFunction()->getName() == - "memset_pattern16")) { - int64_t copySize = 1; - for (auto val : fntypeinfo.knownIntegralValues(call->getOperand(2), - DT, intseen, SE)) { - copySize = max(copySize, val); - } - TypeTree update = - vdptr - .ShiftIndices(DL, /*init offset*/ 0, - /*max size*/ copySize, /*new offset*/ 0) - .Only(-1, call); - - updateAnalysis(call->getOperand(0), update, call); - continue; - } else if (call->getCalledFunction() && - call->getCalledFunction()->getIntrinsicID() == - Intrinsic::masked_gather) { - auto VT = cast(call->getType()); - auto LoadSize = (DL.getTypeSizeInBits(VT) + 7) / 8; - TypeTree req = vdptr.Only(-1, call); - updateAnalysis(call, req.Lookup(LoadSize, DL), call); - // TODO use mask to propagate up to relevant pointer - } else if (call->getCalledFunction() && - call->getCalledFunction()->getIntrinsicID() == - Intrinsic::masked_scatter) { - // TODO use mask to propagate up to relevant pointer - } else if (call->getCalledFunction() && - call->getCalledFunction()->getIntrinsicID() == - Intrinsic::masked_load) { - auto VT = cast(call->getType()); - auto LoadSize = (DL.getTypeSizeInBits(VT) + 7) / 8; - TypeTree req = vdptr.Only(-1, call); - updateAnalysis(call, req.Lookup(LoadSize, DL), call); - // TODO use mask to propagate up to relevant pointer - } else if (call->getCalledFunction() && - call->getCalledFunction()->getIntrinsicID() == - Intrinsic::masked_store) { - // TODO use mask to propagate up to relevant pointer - } else if (call->getType()->isPointerTy()) { - updateAnalysis(call, vdptr.Only(-1, call), call); - } else { - llvm::errs() << " unknown tbaa call instruction user inst: " << I - << " vdptr: " << vdptr.str() << "\n"; - } - } else if (auto SI = dyn_cast(&I)) { - auto StoreSize = - (DL.getTypeSizeInBits(SI->getValueOperand()->getType()) + 7) / 8; - updateAnalysis(SI->getPointerOperand(), - vdptr - // Don't propagate "Anything" into ptr - .PurgeAnything() - // Cut off any values outside of store - .ShiftIndices(DL, /*init offset*/ 0, - /*max size*/ StoreSize, - /*new offset*/ 0) - .Only(-1, SI), - SI); - TypeTree req = vdptr.Only(-1, SI); - updateAnalysis(SI->getValueOperand(), req.Lookup(StoreSize, DL), SI); - } else if (auto LI = dyn_cast(&I)) { - auto LoadSize = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8; - updateAnalysis(LI->getPointerOperand(), - vdptr - // Don't propagate "Anything" into ptr - .PurgeAnything() - // Cut off any values outside of load - .ShiftIndices(DL, /*init offset*/ 0, - /*max size*/ LoadSize, - /*new offset*/ 0) - .Only(-1, LI), - LI); - TypeTree req = vdptr.Only(-1, LI); - updateAnalysis(LI, req.Lookup(LoadSize, DL), LI); - } else { - llvm::errs() << " inst: " << I << " vdptr: " << vdptr.str() << "\n"; - assert(0 && "unknown tbaa instruction user"); - llvm_unreachable("unknown tbaa instruction user"); - } - } - } -} - -void TypeAnalyzer::runPHIHypotheses() { - if (PHIRecur) - return; - bool Changed; - do { - Changed = false; - 
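runPHIHypotheses speculates a type for an otherwise-unknown phi (Integer for int-typed phis, the scalar float type for FP-typed ones) and commits the result only if re-running the analysis on a scratch copy stays consistent with the assumption. The shape of that speculate-and-verify step as a generic standalone sketch; the State abstraction and names are illustrative, not Enzyme API:

#include <utility>

// Seed a hypothesis on a scratch copy, propagate, and commit only if it holds.
template <typename State, typename Seed, typename Propagate, typename Holds>
bool speculate(State &S, Seed seed, Propagate propagate, Holds holds) {
  State Scratch = S;        // copy the current analysis state
  seed(Scratch);            // e.g. "assume this phi is an integer"
  propagate(Scratch);       // run to a fixed point
  if (!holds(Scratch))      // contradiction: drop the hypothesis
    return false;
  S = std::move(Scratch);   // consistent: commit the refined state
  return true;
}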
for (BasicBlock &BB : *fntypeinfo.Function) { - for (Instruction &inst : BB) { - if (PHINode *phi = dyn_cast(&inst)) { - if (direction & DOWN && phi->getType()->isIntOrIntVectorTy() && - !getAnalysis(phi).isKnown()) { - // Assume that this is an integer, does that mean we can prove that - // the incoming operands are integral - - TypeAnalyzer tmpAnalysis(fntypeinfo, interprocedural, - notForAnalysis, *this, DOWN, - /*PHIRecur*/ true); - tmpAnalysis.intseen = intseen; - tmpAnalysis.analysis = analysis; - tmpAnalysis.analysis[phi] = - TypeTree(BaseType::Integer).Only(-1, phi); - for (auto U : phi->users()) { - if (auto I = dyn_cast(U)) { - tmpAnalysis.visit(*I); - } - } - tmpAnalysis.run(); - if (!tmpAnalysis.Invalid) { - TypeTree Result = tmpAnalysis.getAnalysis(phi); - for (auto &op : phi->incoming_values()) { - Result &= tmpAnalysis.getAnalysis(op); - } - if (Result == TypeTree(BaseType::Integer).Only(-1, phi) || - Result == TypeTree(BaseType::Anything).Only(-1, phi)) { - updateAnalysis(phi, Result, phi); - for (auto &pair : tmpAnalysis.analysis) { - updateAnalysis(pair.first, pair.second, phi); - } - Changed = true; - } - } - } - - if (direction & DOWN && phi->getType()->isFPOrFPVectorTy() && - !getAnalysis(phi).isKnown()) { - // Assume that this is an integer, does that mean we can prove that - // the incoming operands are integral - TypeAnalyzer tmpAnalysis(fntypeinfo, interprocedural, - notForAnalysis, *this, DOWN, - /*PHIRecur*/ true); - tmpAnalysis.intseen = intseen; - tmpAnalysis.analysis = analysis; - tmpAnalysis.analysis[phi] = - TypeTree(phi->getType()->getScalarType()).Only(-1, phi); - for (auto U : phi->users()) { - if (auto I = dyn_cast(U)) { - tmpAnalysis.visit(*I); - } - } - tmpAnalysis.run(); - if (!tmpAnalysis.Invalid) { - TypeTree Result = tmpAnalysis.getAnalysis(phi); - for (auto &op : phi->incoming_values()) { - Result &= tmpAnalysis.getAnalysis(op); - } - if (Result == - TypeTree(phi->getType()->getScalarType()).Only(-1, phi) || - Result == TypeTree(BaseType::Anything).Only(-1, phi)) { - updateAnalysis(phi, Result, phi); - for (auto &pair : tmpAnalysis.analysis) { - updateAnalysis(pair.first, pair.second, phi); - } - Changed = true; - } - } - } - } - } - } - } while (Changed); - return; -} - -void TypeAnalyzer::run() { - - TimeTraceScope timeScope("Type Analysis", fntypeinfo.Function->getName()); - - // This function runs a full round of type analysis. - // This works by doing two stages of analysis, - // with a "deduced integer types for unused" values - // sandwiched in-between. This is done because we only - // perform that check for values without types. - // - // For performance reasons in each round of type analysis - // only analyze any call instances after all other potential - // updates have been done. 
This is to minimize the number - // of expensive interprocedural analyses - std::deque pendingCalls; - - do { - while (!Invalid && workList.size()) { - auto todo = *workList.begin(); - workList.erase(workList.begin()); - if (auto call = dyn_cast(todo)) { - StringRef funcName = getFuncNameFromCall(call); - auto ci = getFunctionFromCall(call); - if (ci && !ci->empty()) { - if (interprocedural.CustomRules.find(funcName) == - interprocedural.CustomRules.end()) { - pendingCalls.push_back(call); - continue; - } - } - } - visitValue(*todo); - } - - if (pendingCalls.size() > 0) { - auto todo = pendingCalls.front(); - pendingCalls.pop_front(); - visitValue(*todo); - continue; - } else - break; - - } while (1); - - runPHIHypotheses(); - - do { - - while (!Invalid && workList.size()) { - auto todo = *workList.begin(); - workList.erase(workList.begin()); - if (auto ci = dyn_cast(todo)) { - pendingCalls.push_back(ci); - continue; - } - visitValue(*todo); - } - - if (pendingCalls.size() > 0) { - auto todo = pendingCalls.front(); - pendingCalls.pop_front(); - visitValue(*todo); - continue; - } else - break; - - } while (1); -} - -void TypeAnalyzer::visitValue(Value &val) { - if (auto CE = dyn_cast(&val)) { - visitConstantExpr(*CE); - } - - if (isa(&val)) { - return; - } - - if (!isa(&val) && !isa(&val)) - return; - - if (auto *FPMO = dyn_cast(&val)) { - if (FPMO->getOpcode() == Instruction::FNeg) { - Value *op = FPMO->getOperand(0); - auto ty = op->getType()->getScalarType(); - assert(ty->isFloatingPointTy()); - ConcreteType dt(ty); - updateAnalysis(op, TypeTree(ty).Only(-1, nullptr), - cast(&val)); - updateAnalysis(FPMO, TypeTree(ty).Only(-1, nullptr), - cast(&val)); - return; - } - } - - if (auto inst = dyn_cast(&val)) { - visit(*inst); - } -} - -void TypeAnalyzer::visitConstantExpr(ConstantExpr &CE) { - if (CE.isCast()) { - if (direction & DOWN) - updateAnalysis(&CE, getAnalysis(CE.getOperand(0)), &CE); - if (direction & UP) - updateAnalysis(CE.getOperand(0), getAnalysis(&CE), &CE); - return; - } - if (CE.getOpcode() == Instruction::GetElementPtr) { - visitGEPOperator(*cast(&CE)); - return; - } - auto I = CE.getAsInstruction(); - I->insertBefore(fntypeinfo.Function->getEntryBlock().getTerminator()); - analysis[I] = analysis[&CE]; - visit(*I); - updateAnalysis(&CE, analysis[I], &CE); - analysis.erase(I); - if (workList.remove(I)) { - workList.insert(&CE); - } - I->eraseFromParent(); -} - -void TypeAnalyzer::visitCmpInst(CmpInst &cmp) { - // No directionality check needed as always true - updateAnalysis(&cmp, TypeTree(BaseType::Integer).Only(-1, &cmp), &cmp); - if (direction & UP) { - updateAnalysis( - cmp.getOperand(0), - TypeTree(getAnalysis(cmp.getOperand(1)).Inner0().PurgeAnything()) - .Only(-1, &cmp), - &cmp); - updateAnalysis( - cmp.getOperand(1), - TypeTree(getAnalysis(cmp.getOperand(0)).Inner0().PurgeAnything()) - .Only(-1, &cmp), - &cmp); - } -} - -void TypeAnalyzer::visitAllocaInst(AllocaInst &I) { - // No directionality check needed as always true - updateAnalysis(I.getArraySize(), TypeTree(BaseType::Integer).Only(-1, &I), - &I); - - auto ptr = TypeTree(BaseType::Pointer); - - if (auto CI = dyn_cast(I.getArraySize())) { - auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = CI->getZExtValue() * - (DL.getTypeSizeInBits(I.getAllocatedType()) + 7) / 8; - // Only propagate mappings in range that aren't "Anything" into the pointer - ptr |= getAnalysis(&I).Lookup(LoadSize, DL); - } - updateAnalysis(&I, ptr.Only(-1, &I), &I); -} - -void 
TypeAnalyzer::visitLoadInst(LoadInst &I) { - auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - - if (direction & UP) { - // Only propagate mappings in range that aren't "Anything" into the pointer - auto ptr = getAnalysis(&I).PurgeAnything().ShiftIndices( - DL, /*start*/ 0, LoadSize, /*addOffset*/ 0); - ptr |= TypeTree(BaseType::Pointer); - updateAnalysis(I.getOperand(0), ptr.Only(-1, &I), &I); - } - if (direction & DOWN) - updateAnalysis(&I, getAnalysis(I.getOperand(0)).Lookup(LoadSize, DL), &I); -} - -void TypeAnalyzer::visitStoreInst(StoreInst &I) { - auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); - auto StoreSize = - (DL.getTypeSizeInBits(I.getValueOperand()->getType()) + 7) / 8; - - // Rust specific rule, if storing an integer equal to the alignment - // of a store, assuming nothing (or assume it is a pointer) - // https://doc.rust-lang.org/src/core/ptr/non_null.rs.html#70-78 - if (RustTypeRules) - if (auto CI = dyn_cast(I.getValueOperand())) { - auto alignment = I.getAlign().value(); - - if (CI->getLimitedValue() == alignment) { - return; - } - } - - // Only propagate mappings in range that aren't "Anything" into the pointer - auto ptr = TypeTree(BaseType::Pointer); - auto purged = getAnalysis(I.getValueOperand()) - .PurgeAnything() - .ShiftIndices(DL, /*start*/ 0, StoreSize, /*addOffset*/ 0) - .ReplaceMinus(); - ptr |= purged; - - if (direction & UP) { - updateAnalysis(I.getPointerOperand(), ptr.Only(-1, &I), &I); - - // Note that we also must purge anything from ptr => value in case we store - // to a nullptr which has type [-1, -1]: Anything. While storing to a - // nullptr is obviously bad, this doesn't mean the value we're storing is an - // Anything - updateAnalysis(I.getValueOperand(), - getAnalysis(I.getPointerOperand()) - .PurgeAnything() - .Lookup(StoreSize, DL), - &I); - } -} - -// Give a list of sets representing the legal set of values at a given index -// return a set of all possible combinations of those values -template -std::set> getSet(ArrayRef> todo, size_t idx) { - assert(idx < todo.size()); - std::set> out; - if (idx == 0) { - for (auto val : todo[0]) { - out.insert({val}); - } - return out; - } - - auto old = getSet(todo, idx - 1); - for (const auto &oldv : old) { - for (auto val : todo[idx]) { - auto nex = oldv; - nex.push_back(val); - out.insert(nex); - } - } - return out; -} - -void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { - visitGEPOperator(*cast(&gep)); -} - -void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) { - auto inst = dyn_cast(&gep); - if (isa(gep.getPointerOperand())) { - updateAnalysis(&gep, TypeTree(BaseType::Anything).Only(-1, inst), &gep); - return; - } - if (isa(gep.getPointerOperand())) { - bool nonZero = false; - bool legal = true; - for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { - auto ind = I->get(); - if (auto CI = dyn_cast(ind)) { - if (!CI->isZero()) { - nonZero = true; - continue; - } - } - auto CT = getAnalysis(ind).Inner0(); - if (CT == BaseType::Integer) { - continue; - } - legal = false; - break; - } - if (legal && nonZero) { - updateAnalysis(&gep, TypeTree(BaseType::Integer).Only(-1, inst), &gep); - return; - } - } - if (auto GV = dyn_cast(gep.getPointerOperand())) { - // from julia code, do not propagate int to operands - if (GV->getName() == "small_typeof" || GV->getName() == "jl_small_typeof") { - TypeTree T; - T.insert({-1}, BaseType::Pointer); - T.insert({-1, -1}, 
BaseType::Pointer); - updateAnalysis(&gep, T, &gep); - return; - } - } - - if (gep.idx_begin() == gep.idx_end()) { - if (direction & DOWN) - updateAnalysis(&gep, getAnalysis(gep.getPointerOperand()), &gep); - if (direction & UP) - updateAnalysis(gep.getPointerOperand(), getAnalysis(&gep), &gep); - return; - } - - auto &DL = fntypeinfo.Function->getParent()->getDataLayout(); - - auto pointerAnalysis = getAnalysis(gep.getPointerOperand()); - - // If we know that the pointer operand is indeed a pointer, then the indicies - // must be integers Note that we can't do this if we don't know the pointer - // operand is a pointer since doing 1[pointer] is legal - // sadly this still may not work since (nullptr)[fn] => fn where fn is - // pointer and not int (whereas nullptr is a pointer) However if we are - // inbounds you are only allowed to have nullptr[0] or nullptr[nullptr], - // making this valid - // Assuming nullptr[nullptr] doesn't occur in practice, the following - // is valid. We could make it always valid by checking the pointer - // operand explicitly is a pointer. - if (direction & UP) { - if (gep.isInBounds() || (!EnzymeStrictAliasing && - pointerAnalysis.Inner0() == BaseType::Pointer && - getAnalysis(&gep).Inner0() == BaseType::Pointer)) { - for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { - auto ind = I->get(); - updateAnalysis(ind, TypeTree(BaseType::Integer).Only(-1, inst), &gep); - } - } - } - - // If one of these is known to be a pointer, propagate it if either in bounds - // or all operands are integral/unknown - bool pointerPropagate = gep.isInBounds(); - if (!pointerPropagate) { - bool allIntegral = true; - for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { - auto ind = I->get(); - auto CT = getAnalysis(ind).Inner0(); - if (CT != BaseType::Integer && CT != BaseType::Anything) { - allIntegral = false; - break; - } - } - if (allIntegral) - pointerPropagate = true; - } - - if (!pointerPropagate) - return; - - if (direction & DOWN) { - bool legal = true; - auto keepMinus = pointerAnalysis.KeepMinusOne(legal); - if (!legal) { - if (CustomErrorHandler) - CustomErrorHandler("Could not keep minus one", wrap(&gep), - ErrorType::IllegalTypeAnalysis, this, nullptr, - nullptr); - else { - dump(); - llvm::errs() << " could not perform minus one for gep'd: " << gep - << "\n"; - } - } - updateAnalysis(&gep, keepMinus, &gep); - updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1, inst), - &gep); - } - if (direction & UP) - updateAnalysis(gep.getPointerOperand(), - TypeTree(getAnalysis(&gep).Inner0()).Only(-1, inst), &gep); - - TypeTree upTree; - TypeTree downTree; - - TypeTree gepData0; - TypeTree pointerData0; - if (direction & UP) - gepData0 = getAnalysis(&gep).Data0(); - if (direction & DOWN) - pointerData0 = pointerAnalysis.Data0(); - - auto BitWidth = DL.getIndexSizeInBits(gep.getPointerAddressSpace()); - - APInt constOffset(BitWidth, 0); - -#if LLVM_VERSION_MAJOR >= 20 - SmallMapVector VariableOffsets; -#else - MapVector VariableOffsets; -#endif - bool legalOffset = - collectOffset(&gep, DL, BitWidth, VariableOffsets, constOffset); - (void)legalOffset; - assert(legalOffset); - - SmallVector, 4> idnext; - - SmallPtrSet previousLoopInductionHeaders; - { - Value *ptr = gep.getPointerOperand(); - while (true) { - if (auto gepop = dyn_cast(ptr)) { - for (auto I = gepop->idx_begin(), E = gepop->idx_end(); I != E; I++) { - for (auto loopInd : findLoopIndices(*I, LI, DT)) { - previousLoopInductionHeaders.insert(loopInd); - } - } - ptr = 
gepop->getPointerOperand(); - continue; - } - if (auto CI = dyn_cast(ptr)) { - ptr = CI->getOperand(0); - continue; - } - break; - } - } - - for (auto &pair : VariableOffsets) { - auto a = pair.first; - auto iset = fntypeinfo.knownIntegralValues(a, DT, intseen, SE); - std::set vset; - for (auto i : iset) { - // Don't consider negative indices of gep - if (i < 0) - continue; - vset.insert(i); - } - if (vset.size() == 0) - return; - - // If seen the same variable before with > 1 option, we will accidentally - // do an offset for [option1, option2] * oldOffset + [option1, option2] * - // newOffset - // instead of [option1, option2] * (oldOffset + newOffset). - // In this case abort - // TODO, in the future, mutually compute the offset together. - if (vset.size() != 1) { - for (auto loopInd : findLoopIndices(pair.first, LI, DT)) - if (previousLoopInductionHeaders.count(loopInd)) - return; - } - idnext.push_back(vset); - } - - // Stores pair ([whether first offset is zero], offset) - std::vector> offsets; - Value *firstIdx = *gep.idx_begin(); - if (VariableOffsets.size() == 0) { - bool firstIsZero = cast(firstIdx)->getLimitedValue() == 0; - offsets.emplace_back(firstIsZero, (int)constOffset.getLimitedValue()); - } else { - bool firstIsZero = false; - if (auto CI = dyn_cast(firstIdx)) - firstIsZero = CI->getLimitedValue() == 0; - for (auto vec : getSet(idnext, idnext.size() - 1)) { - APInt nextOffset = constOffset; - for (auto [varpair, const_value] : llvm::zip(VariableOffsets, vec)) { - nextOffset += varpair.second * const_value; - if (varpair.first == firstIdx) - firstIsZero = const_value == 0; - } - offsets.emplace_back(firstIsZero, (int)nextOffset.getLimitedValue()); - } - } - - bool seenIdx = false; - - for (auto [firstIsZero, off] : offsets) { - // TODO also allow negative offsets - if (off < 0) - continue; - - int maxSize = -1; - if (firstIsZero) { - maxSize = DL.getTypeAllocSizeInBits(gep.getResultElementType()) / 8; - } - - if (direction & DOWN) { - auto shft = - pointerData0.ShiftIndices(DL, /*init offset*/ off, - /*max size*/ maxSize, /*newoffset*/ 0); - if (seenIdx) - downTree &= shft; - else - downTree = shft; - } - - if (direction & UP) { - auto shft = gepData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1, - /*new offset*/ off); - if (seenIdx) - upTree |= shft; - else - upTree = shft; - } - seenIdx = true; - } - if (direction & DOWN) - updateAnalysis(&gep, downTree.Only(-1, inst), &gep); - if (direction & UP) - updateAnalysis(gep.getPointerOperand(), upTree.Only(-1, inst), &gep); -} - -void TypeAnalyzer::visitPHINode(PHINode &phi) { - if (direction & UP) { - TypeTree upVal = getAnalysis(&phi); - // only propagate anything's up if there is one - // incoming value - Value *seen = phi.getIncomingValue(0); - for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) { - if (seen != phi.getIncomingValue(i)) { - seen = nullptr; - break; - } - } - - if (!seen) { - upVal = upVal.PurgeAnything(); - } - - if (EnzymeStrictAliasing || seen) { - auto L = LI.getLoopFor(phi.getParent()); - bool isHeader = L && L->getHeader() == phi.getParent(); - for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) { - if (!isHeader || !L->contains(phi.getIncomingBlock(i))) { - updateAnalysis(phi.getIncomingValue(i), upVal, &phi); - } - } - } else { - if (EnzymePrintType) { - for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) { - llvm::errs() << " skipping update into "; - phi.getIncomingValue(i)->print(llvm::errs(), *MST); - llvm::errs() << " of " << upVal.str() << " 
from "; - phi.print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - } - } - } - - assert(phi.getNumIncomingValues() > 0); - - // TODO generalize this (and for recursive, etc) - - for (int i = 0; i < 2; i++) { - - std::deque vals; - std::set seen{&phi}; - for (auto &op : phi.incoming_values()) { - vals.push_back(op); - } - SmallVector bos; - - // Unique values that propagate into this phi - SmallVector UniqueValues; - - while (vals.size()) { - Value *todo = vals.front(); - vals.pop_front(); - - if (auto bo = dyn_cast(todo)) { - if (bo->getOpcode() == BinaryOperator::Add) { - if (isa(bo->getOperand(0))) { - bos.push_back(bo); - todo = bo->getOperand(1); - } - if (isa(bo->getOperand(1))) { - bos.push_back(bo); - todo = bo->getOperand(0); - } - } - } - - if (seen.count(todo)) - continue; - seen.insert(todo); - - if (auto nphi = dyn_cast(todo)) { - if (i == 0) { - for (auto &op : nphi->incoming_values()) { - vals.push_back(op); - } - continue; - } - } - if (auto sel = dyn_cast(todo)) { - vals.push_back(sel->getOperand(1)); - vals.push_back(sel->getOperand(2)); - continue; - } - UniqueValues.push_back(todo); - } - - TypeTree PhiTypes; - bool set = false; - - for (size_t i = 0, size = UniqueValues.size(); i < size; ++i) { - TypeTree newData = getAnalysis(UniqueValues[i]); - if (UniqueValues.size() == 2) { - if (auto BO = dyn_cast(UniqueValues[i])) { - if (BO->getOpcode() == BinaryOperator::Add || - BO->getOpcode() == BinaryOperator::Mul) { - TypeTree otherData = getAnalysis(UniqueValues[1 - i]); - // If we are adding/muling to a constant to derive this, we can - // assume it to be an integer rather than Anything - if (isa(UniqueValues[1 - i])) { - otherData = TypeTree(BaseType::Integer).Only(-1, &phi); - } - if (BO->getOperand(0) == &phi) { - set = true; - PhiTypes = otherData; - bool Legal = true; - PhiTypes.binopIn(Legal, getAnalysis(BO->getOperand(1)), - BO->getOpcode()); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateBinop Analysis " << *BO << "\n"; - ss << "Illegal binopIn(0): " << *BO - << " lhs: " << PhiTypes.str() - << " rhs: " << getAnalysis(BO->getOperand(0)).str() << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(BO), - ErrorType::IllegalTypeAnalysis, - (void *)this, wrap(BO), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", BO->getDebugLoc(), BO, - ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - break; - } else if (BO->getOperand(1) == &phi) { - set = true; - PhiTypes = getAnalysis(BO->getOperand(0)); - bool Legal = true; - PhiTypes.binopIn(Legal, otherData, BO->getOpcode()); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateBinop Analysis " << *BO << "\n"; - ss << "Illegal binopIn(1): " << *BO - << " lhs: " << PhiTypes.str() << " rhs: " << otherData.str() - << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(BO), - ErrorType::IllegalTypeAnalysis, - (void *)this, wrap(BO), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", BO->getDebugLoc(), BO, - ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - break; - } - } else if (BO->getOpcode() == BinaryOperator::Sub) { - // Repeated subtraction from a 
type X yields the type X back - TypeTree otherData = getAnalysis(UniqueValues[1 - i]); - // If we are subtracting from a constant to derive this, we can - // assume it to be an integer rather than Anything - if (isa(UniqueValues[1 - i])) { - otherData = TypeTree(BaseType::Integer).Only(-1, &phi); - } - if (BO->getOperand(0) == &phi) { - set = true; - PhiTypes = otherData; - break; - } - } - } - } - if (set) { - PhiTypes &= newData; - // TODO consider the or of anything (see selectinst) - // however, this cannot be done yet for risk of turning - // phi's that add floats into anything - // PhiTypes |= newData.JustAnything(); - } else { - set = true; - PhiTypes = newData; - } - } - - assert(set); - // If we are only add / sub / etc to derive a value based off 0 - // we can start by assuming the type of 0 is integer rather - // than assuming it could be anything (per null) - if (bos.size() > 0 && UniqueValues.size() == 1 && - isa(UniqueValues[0]) && - (cast(UniqueValues[0])->isZero() || - cast(UniqueValues[0])->isOne())) { - PhiTypes = TypeTree(BaseType::Integer).Only(-1, &phi); - } - for (BinaryOperator *bo : bos) { - TypeTree vd1 = isa(bo->getOperand(0)) - ? getAnalysis(bo->getOperand(0)).Data0() - : PhiTypes.Data0(); - TypeTree vd2 = isa(bo->getOperand(1)) - ? getAnalysis(bo->getOperand(1)).Data0() - : PhiTypes.Data0(); - bool Legal = true; - vd1.binopIn(Legal, vd2, bo->getOpcode()); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateBinop Analysis " << *bo << "\n"; - ss << "Illegal binopIn(consts): " << *bo << " lhs: " << vd1.str() - << " rhs: " << vd2.str() << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(bo), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(bo), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", bo->getDebugLoc(), bo, ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - PhiTypes &= vd1.Only(bo->getType()->isIntegerTy() ? 
-1 : 0, &phi); - } - - if (direction & DOWN) { - if (phi.getType()->isIntOrIntVectorTy() && - PhiTypes.Inner0() == BaseType::Anything) { - if (mustRemainInteger(&phi)) { - PhiTypes = TypeTree(BaseType::Integer).Only(-1, &phi); - } - } - updateAnalysis(&phi, PhiTypes, &phi); - } - } -} - -void TypeAnalyzer::visitTruncInst(TruncInst &I) { - auto &DL = fntypeinfo.Function->getParent()->getDataLayout(); - size_t inSize = (DL.getTypeSizeInBits(I.getOperand(0)->getType()) + 7) / 8; - size_t outSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - if (direction & DOWN) - if (outSize != 1) - updateAnalysis(&I, - getAnalysis(I.getOperand(0)) - .ShiftIndices(DL, /*off*/ 0, inSize, /*addOffset*/ 0) - .ShiftIndices(DL, /*off*/ 0, outSize, /*addOffset*/ 0), - &I); - // Don't propagate up a trunc float -> i8 - if (direction & UP) - if (outSize != 1 || inSize == 1) - updateAnalysis( - I.getOperand(0), - getAnalysis(&I).ShiftIndices(DL, /*off*/ 0, outSize, /*addOffset*/ 0), - &I); -} - -void TypeAnalyzer::visitZExtInst(ZExtInst &I) { - if (direction & DOWN) { - TypeTree Result; - if (cast(I.getOperand(0)->getType()->getScalarType()) - ->getBitWidth() == 1) { - Result = TypeTree(BaseType::Anything).Only(-1, &I); - } else { - Result = getAnalysis(I.getOperand(0)); - } - - if (I.getType()->isIntOrIntVectorTy() && - Result.Inner0() == BaseType::Anything) { - if (mustRemainInteger(&I)) { - Result = TypeTree(BaseType::Integer).Only(-1, &I); - } - } - updateAnalysis(&I, Result, &I); - } - if (direction & UP) { - updateAnalysis(I.getOperand(0), getAnalysis(&I), &I); - } -} - -void TypeAnalyzer::visitSExtInst(SExtInst &I) { - // This is only legal on integer types [not pointers per sign] - // nor floatings points. Likewise, there's no direction check - // necessary since this is always valid. 
- updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), &I); -} - -void TypeAnalyzer::visitAddrSpaceCastInst(AddrSpaceCastInst &I) { - if (direction & DOWN) - updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I); - if (direction & UP) - updateAnalysis(I.getOperand(0), getAnalysis(&I), &I); -} - -void TypeAnalyzer::visitFPExtInst(FPExtInst &I) { - // No direction check as always true - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); -} - -void TypeAnalyzer::visitFPTruncInst(FPTruncInst &I) { - // No direction check as always true - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); -} - -void TypeAnalyzer::visitFPToUIInst(FPToUIInst &I) { - // No direction check as always true - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); -} - -void TypeAnalyzer::visitFPToSIInst(FPToSIInst &I) { - // No direction check as always true - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); -} - -void TypeAnalyzer::visitUIToFPInst(UIToFPInst &I) { - // No direction check as always true - updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), &I); - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); -} - -void TypeAnalyzer::visitSIToFPInst(SIToFPInst &I) { - // No direction check as always true - updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), &I); - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); -} - -void TypeAnalyzer::visitPtrToIntInst(PtrToIntInst &I) { - // Note it is illegal to assume here that either is a pointer or an int - if (direction & DOWN) - updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I); - if (direction & UP) - updateAnalysis(I.getOperand(0), getAnalysis(&I), &I); -} - -void TypeAnalyzer::visitIntToPtrInst(IntToPtrInst &I) { - // Note it is illegal to assume here that either is a pointer or an int - if (direction & DOWN) { - if (isa(I.getOperand(0))) { - updateAnalysis(&I, TypeTree(BaseType::Anything).Only(-1, &I), &I); - } else { - updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I); - } - } - if (direction & UP) - updateAnalysis(I.getOperand(0), getAnalysis(&I), &I); -} - -void TypeAnalyzer::visitFreezeInst(FreezeInst &I) { - if (direction & DOWN) - updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I); - if (direction & UP) - updateAnalysis(I.getOperand(0), getAnalysis(&I), &I); -} - -void TypeAnalyzer::visitBitCastInst(BitCastInst &I) { - if (direction & DOWN) - updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I); - if (direction & UP) - updateAnalysis(I.getOperand(0), getAnalysis(&I), &I); -} - -void TypeAnalyzer::visitSelectInst(SelectInst &I) { - if (direction & UP) { - auto Data = getAnalysis(&I).PurgeAnything(); - if (EnzymeStrictAliasing || (I.getTrueValue() == I.getFalseValue())) { - updateAnalysis(I.getTrueValue(), 
Data, &I); - updateAnalysis(I.getFalseValue(), Data, &I); - } else { - if (EnzymePrintType) { - llvm::errs() << " skipping update into "; - I.getTrueValue()->print(llvm::errs(), *MST); - llvm::errs() << " of " << Data.str() << " from "; - I.print(llvm::errs(), *MST); - llvm::errs() << "\n"; - llvm::errs() << " skipping update into "; - I.getFalseValue()->print(llvm::errs(), *MST); - llvm::errs() << " of " << Data.str() << " from "; - I.print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - } - } - if (direction & DOWN) { - // special case for min/max result is still that operand [even if something - // is 0] - if (auto cmpI = dyn_cast(I.getCondition())) { - // is relational equiv to not is equality - if (!cmpI->isEquality()) - if ((cmpI->getOperand(0) == I.getTrueValue() && - cmpI->getOperand(1) == I.getFalseValue()) || - (cmpI->getOperand(1) == I.getTrueValue() && - cmpI->getOperand(0) == I.getFalseValue())) { - auto vd = getAnalysis(I.getTrueValue()).Inner0(); - vd &= getAnalysis(I.getFalseValue()).Inner0(); - if (vd.isKnown()) { - updateAnalysis(&I, TypeTree(vd).Only(-1, &I), &I); - return; - } - } - } - // If getTrueValue and getFalseValue are the same type (per the and) - // it is safe to assume the result is as well - TypeTree vd = getAnalysis(I.getTrueValue()).PurgeAnything(); - vd &= getAnalysis(I.getFalseValue()).PurgeAnything(); - - // A regular and operation, however is not sufficient. One of the operands - // could be anything whereas the other is concrete, resulting in the - // concrete type (e.g. select true, anything(0), integer(i64)) This is not - // correct as the result of the select could always be anything (e.g. if it - // is a pointer). As a result, explicitly or in any anything values - // TODO this should be propagated elsewhere as well (specifically returns, - // phi) - TypeTree any = getAnalysis(I.getTrueValue()).JustAnything(); - any &= getAnalysis(I.getFalseValue()).JustAnything(); - vd |= any; - updateAnalysis(&I, vd, &I); - } -} - -void TypeAnalyzer::visitExtractElementInst(ExtractElementInst &I) { - updateAnalysis(I.getIndexOperand(), BaseType::Integer, &I); - - auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); - VectorType *vecType = cast(I.getVectorOperand()->getType()); - - size_t bitsize = dl.getTypeSizeInBits(vecType->getElementType()); - size_t size = (bitsize + 7) / 8; - - if (auto CI = dyn_cast(I.getIndexOperand())) { - size_t off = (CI->getZExtValue() * bitsize) / 8; - - if (direction & DOWN) - updateAnalysis(&I, - getAnalysis(I.getVectorOperand()) - .ShiftIndices(dl, off, size, /*addOffset*/ 0), - &I); - - if (direction & UP) - updateAnalysis(I.getVectorOperand(), - getAnalysis(&I).ShiftIndices(dl, 0, size, off), &I); - - } else { - if (direction & DOWN) { - TypeTree vecAnalysis = getAnalysis(I.getVectorOperand()); - // TODO merge of anythings (see selectinst) - TypeTree res = vecAnalysis.Lookup(size, dl); - updateAnalysis(&I, res.Only(-1, &I), &I); - } - if (direction & UP) { - // propagated upward to unknown location, no analysis - // can be updated - } - } -} - -void TypeAnalyzer::visitInsertElementInst(InsertElementInst &I) { - updateAnalysis(I.getOperand(2), TypeTree(BaseType::Integer).Only(-1, &I), &I); - - auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); - VectorType *vecType = cast(I.getOperand(0)->getType()); - if (vecType->getElementType()->isIntegerTy(1)) { - if (direction & UP) { - updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), - &I); - updateAnalysis(I.getOperand(1), 
TypeTree(BaseType::Integer).Only(-1, &I), - &I); - } - if (direction & DOWN) { - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - } - return; - } -#if LLVM_VERSION_MAJOR >= 12 - assert(!vecType->getElementCount().isScalable()); - size_t numElems = vecType->getElementCount().getKnownMinValue(); -#else - size_t numElems = vecType->getNumElements(); -#endif - size_t size = (dl.getTypeSizeInBits(vecType->getElementType()) + 7) / 8; - size_t vecSize = (dl.getTypeSizeInBits(vecType) + 7) / 8; - - if (auto CI = dyn_cast(I.getOperand(2))) { - size_t off = CI->getZExtValue() * size; - - if (direction & UP) - updateAnalysis(I.getOperand(0), - getAnalysis(&I).Clear(off, off + size, vecSize), &I); - - if (direction & UP) - updateAnalysis(I.getOperand(1), - getAnalysis(&I).ShiftIndices(dl, off, size, 0), &I); - - if (direction & DOWN) { - auto new_res = - getAnalysis(I.getOperand(0)).Clear(off, off + size, vecSize); - auto shifted = - getAnalysis(I.getOperand(1)).ShiftIndices(dl, 0, size, off); - new_res |= shifted; - updateAnalysis(&I, new_res, &I); - } - } else { - if (direction & DOWN) { - auto new_res = getAnalysis(I.getOperand(0)); - auto inserted = getAnalysis(I.getOperand(1)); - // TODO merge of anythings (see selectinst) - for (size_t i = 0; i < numElems; ++i) - new_res &= inserted.ShiftIndices(dl, 0, size, size * i); - updateAnalysis(&I, new_res, &I); - } - } -} - -void TypeAnalyzer::visitShuffleVectorInst(ShuffleVectorInst &I) { - // See selectinst type propagation rule for a description - // of the ncessity and correctness of this rule. - VectorType *resType = cast(I.getType()); - - auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); - - const size_t lhs = 0; - const size_t rhs = 1; - -#if LLVM_VERSION_MAJOR >= 12 - assert(!cast(I.getOperand(lhs)->getType()) - ->getElementCount() - .isScalable()); - size_t numFirst = cast(I.getOperand(lhs)->getType()) - ->getElementCount() - .getKnownMinValue(); -#else - size_t numFirst = - cast(I.getOperand(lhs)->getType())->getNumElements(); -#endif - size_t size = (dl.getTypeSizeInBits(resType->getElementType()) + 7) / 8; - - auto mask = I.getShuffleMask(); - - TypeTree result; // = getAnalysis(&I); - for (size_t i = 0; i < mask.size(); ++i) { - int newOff; - { - Value *vec[2] = {ConstantInt::get(Type::getInt64Ty(I.getContext()), 0), - ConstantInt::get(Type::getInt64Ty(I.getContext()), i)}; - auto ud = - UndefValue::get(PointerType::getUnqual(I.getOperand(0)->getType())); - auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); - APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(dl, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - newOff = (int)ai.getLimitedValue(); - // there is a bug in LLVM, this is the correct offset - if (cast(I.getOperand(lhs)->getType()) - ->getElementType() - ->isIntegerTy(1)) { - newOff = i / 8; - } - } -#if LLVM_VERSION_MAJOR > 16 - if (mask[i] == PoisonMaskElem) -#elif LLVM_VERSION_MAJOR >= 12 - if (mask[i] == UndefMaskElem) -#else - if (mask[i] == -1) -#endif - { - if (direction & DOWN) { - result |= TypeTree(BaseType::Anything) - .Only(-1, &I) - .ShiftIndices(dl, 0, size, newOff); - } - } else { - if ((size_t)mask[i] < numFirst) { - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(I.getContext()), 0), - ConstantInt::get(Type::getInt64Ty(I.getContext()), mask[i])}; - auto ud = - UndefValue::get(PointerType::getUnqual(I.getOperand(0)->getType())); - auto g2 = - 
GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); - APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(dl, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - int oldOff = (int)ai.getLimitedValue(); - // there is a bug in LLVM, this is the correct offset - if (cast(I.getOperand(lhs)->getType()) - ->getElementType() - ->isIntegerTy(1)) { - oldOff = mask[i] / 8; - } - delete g2; - if (direction & UP) { - updateAnalysis(I.getOperand(lhs), - getAnalysis(&I).ShiftIndices(dl, newOff, size, oldOff), - &I); - } - if (direction & DOWN) { - result |= getAnalysis(I.getOperand(lhs)) - .ShiftIndices(dl, oldOff, size, newOff); - } - } else { - Value *vec[2] = {ConstantInt::get(Type::getInt64Ty(I.getContext()), 0), - ConstantInt::get(Type::getInt64Ty(I.getContext()), - mask[i] - numFirst)}; - auto ud = - UndefValue::get(PointerType::getUnqual(I.getOperand(0)->getType())); - auto g2 = - GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); - APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(dl, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - int oldOff = (int)ai.getLimitedValue(); - // there is a bug in LLVM, this is the correct offset - if (cast(I.getOperand(lhs)->getType()) - ->getElementType() - ->isIntegerTy(1)) { - oldOff = (mask[i] - numFirst) / 8; - } - delete g2; - if (direction & UP) { - updateAnalysis(I.getOperand(rhs), - getAnalysis(&I).ShiftIndices(dl, newOff, size, oldOff), - &I); - } - if (direction & DOWN) { - result |= getAnalysis(I.getOperand(rhs)) - .ShiftIndices(dl, oldOff, size, newOff); - } - } - } - } - - if (direction & DOWN) { - updateAnalysis(&I, result, &I); - } -} - -void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) { - auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); - SmallVector vec; - vec.push_back(ConstantInt::get(Type::getInt64Ty(I.getContext()), 0)); - for (auto ind : I.indices()) { - vec.push_back(ConstantInt::get(Type::getInt32Ty(I.getContext()), ind)); - } - auto ud = UndefValue::get(PointerType::getUnqual(I.getOperand(0)->getType())); - auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); - APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(dl, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - int off = (int)ai.getLimitedValue(); - int size = dl.getTypeSizeInBits(I.getType()) / 8; - - if (direction & DOWN) - updateAnalysis(&I, - getAnalysis(I.getOperand(0)) - .ShiftIndices(dl, off, size, /*addOffset*/ 0), - &I); - - if (direction & UP) - updateAnalysis(I.getOperand(0), - getAnalysis(&I).ShiftIndices(dl, 0, size, off), &I); -} - -void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) { - auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); - SmallVector vec = { - ConstantInt::get(Type::getInt64Ty(I.getContext()), 0)}; - for (auto ind : I.indices()) { - vec.push_back(ConstantInt::get(Type::getInt32Ty(I.getContext()), ind)); - } - auto ud = UndefValue::get(PointerType::getUnqual(I.getOperand(0)->getType())); - auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); - APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(dl, ai); - delete g2; - // Using destructor rather than eraseFromParent - // as g2 has no parent - - // Compute the offset at the next logical element [e.g. 
adding 1 to the last - // index, carrying the value on overflow] - for (ssize_t i = vec.size() - 1; i >= 0; i--) { - auto CI = cast(vec[i]); - auto val = CI->getZExtValue(); - if (i == 0) { - vec[i] = ConstantInt::get(CI->getType(), val + 1); - break; - } - auto subTy = GetElementPtrInst::getIndexedType( - I.getOperand(0)->getType(), ArrayRef(vec).slice(0, i)); - if (auto ST = dyn_cast(subTy)) { - if (val + 1 == ST->getNumElements()) { - vec.erase(vec.begin() + i, vec.end()); - continue; - } - vec[i] = ConstantInt::get(CI->getType(), val + 1); - break; - } else { - auto AT = cast(subTy); - if (val + 1 == AT->getNumElements()) { - vec.erase(vec.begin() + i, vec.end()); - continue; - } - vec[i] = ConstantInt::get(CI->getType(), val + 1); - break; - } - } - g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); - APInt aiend(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(dl, aiend); - delete g2; - - int off = (int)ai.getLimitedValue(); - - int agg_size = (dl.getTypeSizeInBits(I.getType()) + 7) / 8; - int ins_size = (int)(aiend - ai).getLimitedValue(); - int ins2_size = - (dl.getTypeSizeInBits(I.getInsertedValueOperand()->getType()) + 7) / 8; - - if (direction & UP) - updateAnalysis(I.getAggregateOperand(), - getAnalysis(&I).Clear(off, off + ins_size, agg_size), &I); - if (direction & UP) - updateAnalysis(I.getInsertedValueOperand(), - getAnalysis(&I).ShiftIndices(dl, off, ins2_size, 0), &I); - auto new_res = - getAnalysis(I.getAggregateOperand()).Clear(off, off + ins_size, agg_size); - auto shifted = getAnalysis(I.getInsertedValueOperand()) - .ShiftIndices(dl, 0, ins_size, off); - new_res |= shifted; - if (direction & DOWN) - updateAnalysis(&I, new_res, &I); -} - -void TypeAnalyzer::dump(llvm::raw_ostream &ss) { - ss << "\n"; - // We don't care about correct MD node numbering here. 
- ModuleSlotTracker MST(fntypeinfo.Function->getParent(), - /*ShouldInitializeAllMetadata*/ false); - for (auto &pair : analysis) { - if (auto F = dyn_cast(pair.first)) - ss << "@" << F->getName(); - else - pair.first->print(ss, MST); - ss << ": " << pair.second.str() - << ", intvals: " << to_string(knownIntegralValues(pair.first)) << "\n"; - } - ss << "\n"; -} - -void TypeAnalyzer::visitAtomicRMWInst(llvm::AtomicRMWInst &I) { - Value *Args[2] = {nullptr, I.getOperand(1)}; - TypeTree Ret = getAnalysis(&I); - auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - TypeTree LHS = getAnalysis(I.getOperand(0)).Lookup(LoadSize, DL); - TypeTree RHS = getAnalysis(I.getOperand(1)); - - switch (I.getOperation()) { - case AtomicRMWInst::Xchg: { - auto tmp = LHS; - LHS = RHS; - RHS = tmp; - bool Legal = true; - LHS.checkedOrIn(Ret, /*PointerIntSame*/ false, Legal); - if (!Legal) { - dump(); - llvm::errs() << I << "\n"; - llvm::errs() << "Illegal orIn: " << LHS.str() << " right: " << Ret.str() - << "\n"; - llvm::errs() << *I.getOperand(0) << " " - << getAnalysis(I.getOperand(0)).str() << "\n"; - llvm::errs() << *I.getOperand(1) << " " - << getAnalysis(I.getOperand(1)).str() << "\n"; - assert(0 && "Performed illegal visitAtomicRMWInst::orIn"); - llvm_unreachable("Performed illegal visitAtomicRMWInst::orIn"); - } - Ret = tmp; - break; - } - case AtomicRMWInst::Add: - visitBinaryOperation(DL, I.getType(), BinaryOperator::Add, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::Sub: - visitBinaryOperation(DL, I.getType(), BinaryOperator::Sub, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::And: - visitBinaryOperation(DL, I.getType(), BinaryOperator::And, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::Or: - visitBinaryOperation(DL, I.getType(), BinaryOperator::Or, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::Xor: - visitBinaryOperation(DL, I.getType(), BinaryOperator::Xor, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::FAdd: - visitBinaryOperation(DL, I.getType(), BinaryOperator::FAdd, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::FSub: - visitBinaryOperation(DL, I.getType(), BinaryOperator::FSub, Args, Ret, LHS, - RHS, &I); - break; - case AtomicRMWInst::Max: - case AtomicRMWInst::Min: - case AtomicRMWInst::UMax: - case AtomicRMWInst::UMin: - case AtomicRMWInst::Nand: - default: - break; - } - - if (direction & UP) { - TypeTree ptr = LHS.PurgeAnything() - .ShiftIndices(DL, /*start*/ 0, LoadSize, /*addOffset*/ 0) - .Only(-1, &I); - ptr.insert({-1}, BaseType::Pointer); - updateAnalysis(I.getOperand(0), ptr, &I); - updateAnalysis(I.getOperand(1), RHS, &I); - } - - if (direction & DOWN) { - if (Ret[{-1}] == BaseType::Anything && LHS[{-1}] != BaseType::Anything) - Ret = LHS; - if (I.getType()->isIntOrIntVectorTy() && Ret[{-1}] == BaseType::Anything) { - if (mustRemainInteger(&I)) { - Ret = TypeTree(BaseType::Integer).Only(-1, &I); - } - } - updateAnalysis(&I, Ret, &I); - } -} - -void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, - llvm::Instruction::BinaryOps Opcode, - Value *Args[2], TypeTree &Ret, - TypeTree &LHS, TypeTree &RHS, - Instruction *origin) { - if (Opcode == BinaryOperator::FAdd || Opcode == BinaryOperator::FSub || - Opcode == BinaryOperator::FMul || Opcode == BinaryOperator::FDiv || - Opcode == BinaryOperator::FRem) { - auto ty = T->getScalarType(); - assert(ty->isFloatingPointTy()); - ConcreteType dt(ty); - if (direction & 
UP) { - bool LegalOr = true; - auto Data = TypeTree(dt).Only(-1, nullptr); - LHS.checkedOrIn(Data, /*PointerIntSame*/ false, LegalOr); - if (CustomErrorHandler && !LegalOr) { - std::string str; - raw_string_ostream ss(str); - ss << "Illegal updateAnalysis prev:" << LHS.str() - << " new: " << Data.str() << "\n"; - ss << "val: " << *Args[0]; - ss << "origin: " << *origin; - CustomErrorHandler(str.c_str(), wrap(Args[0]), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(origin), nullptr); - } - RHS.checkedOrIn(Data, /*PointerIntSame*/ false, LegalOr); - if (CustomErrorHandler && !LegalOr) { - std::string str; - raw_string_ostream ss(str); - ss << "Illegal updateAnalysis prev:" << RHS.str() - << " new: " << Data.str() << "\n"; - ss << "val: " << *Args[1]; - ss << "origin: " << *origin; - CustomErrorHandler(str.c_str(), wrap(Args[1]), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(origin), nullptr); - } - } - if (direction & DOWN) - Ret |= TypeTree(dt).Only(-1, nullptr); - } else { - auto size = (dl.getTypeSizeInBits(T) + 7) / 8; - auto AnalysisLHS = LHS.Data0(); - auto AnalysisRHS = RHS.Data0(); - auto AnalysisRet = Ret.Data0(); - - switch (Opcode) { - case BinaryOperator::Sub: - // ptr - ptr => int and int - int => int; thus int = a - b says only that - // these are equal ptr - int => ptr and int - ptr => ptr; thus - // howerver we do not want to propagate underlying ptr types since it's - // legal to subtract unrelated pointer - if (direction & UP) { - if (AnalysisRet[{}] == BaseType::Integer) { - LHS |= TypeTree(AnalysisRHS[{}]).PurgeAnything().Only(-1, nullptr); - RHS |= TypeTree(AnalysisLHS[{}]).PurgeAnything().Only(-1, nullptr); - } - if (AnalysisRet[{}] == BaseType::Pointer) { - if (AnalysisLHS[{}] == BaseType::Pointer) { - RHS |= TypeTree(BaseType::Integer).Only(-1, nullptr); - } - if (AnalysisRHS[{}] == BaseType::Integer) { - LHS |= TypeTree(BaseType::Pointer).Only(-1, nullptr); - } - } - } - break; - - case BinaryOperator::Add: - case BinaryOperator::Mul: - // if a + b or a * b == int, then a and b must be ints - if (direction & UP) { - if (AnalysisRet[{}] == BaseType::Integer) { - LHS.orIn({-1}, BaseType::Integer); - RHS.orIn({-1}, BaseType::Integer); - } - } - break; - - case BinaryOperator::Xor: - if (direction & UP) - for (int i = 0; i < 2; ++i) { - Type *FT = nullptr; - if (!(FT = Ret.IsAllFloat(size, dl))) - continue; - // If ^ against 0b10000000000, the result is a float - bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); - if (validXor) { - bool Legal = true; - ((i == 0) ? RHS : LHS) - .checkedOrIn(TypeTree(FT).Only(-1, nullptr), - /*pointerintsame*/ false, Legal); - - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateBinop (xor up) Analysis " << *origin << "\n"; - ss << " (i=" << i << ") " << (i == 0 ? "RHS" : "LHS") << " " - << ((i == 0) ? 
RHS : LHS).str() << " FT from ret: " << *FT - << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(origin), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(origin), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), - origin, ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - } - } - break; - case BinaryOperator::Or: - for (int i = 0; i < 2; ++i) { - Type *FT = nullptr; - if (!(FT = Ret.IsAllFloat(size, dl))) - continue; - // If | against a number only or'ing the exponent, the result is a float - bool validXor = false; - if (auto CIT = dyn_cast_or_null(Args[i])) { - if (dl.getTypeSizeInBits(FT) != dl.getTypeSizeInBits(CIT->getType())) - continue; - auto CI = CIT->getValue(); -#if LLVM_VERSION_MAJOR > 16 - if (CI.isZero()) -#else - if (CI.isNullValue()) -#endif - { - validXor = true; - } else if ( - !CI.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (CI & ~0b01111111100000000000000000000000ULL).isZero() -#else - && (CI & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - validXor = true; - } - } else if (auto CV = dyn_cast_or_null(Args[i])) { - validXor = true; - if (dl.getTypeSizeInBits(FT) != - dl.getTypeSizeInBits(CV->getOperand(i)->getType())) - continue; - for (size_t i = 0, end = CV->getNumOperands(); i < end; ++i) { - auto CI = dyn_cast(CV->getOperand(i))->getValue(); - -#if LLVM_VERSION_MAJOR > 16 - if (CI.isZero()) -#else - if (CI.isNullValue()) -#endif - { - } else if ( - !CI.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (CI & ~0b01111111100000000000000000000000ULL).isZero() -#else - && (CI & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - } else - validXor = false; - } - } else if (auto CV = dyn_cast_or_null(Args[i])) { - validXor = true; - if (dl.getTypeSizeInBits(FT) != - dl.getTypeSizeInBits(CV->getElementType())) - continue; - for (size_t i = 0, end = CV->getNumElements(); i < end; ++i) { - auto CI = CV->getElementAsAPInt(i); -#if LLVM_VERSION_MAJOR > 16 - if (CI.isZero()) -#else - if (CI.isNullValue()) -#endif - { - } else if ( - !CI.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (CI & ~0b01111111100000000000000000000000ULL).isZero() -#else - && (CI & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - } else - validXor = false; - } - } - if (validXor) { - ((i == 0) ? 
RHS : LHS) |= TypeTree(FT).Only(-1, nullptr); - } - } - break; - default: - break; - } - - if (direction & DOWN) { - TypeTree Result = AnalysisLHS; - bool Legal = true; - Result.binopIn(Legal, AnalysisRHS, Opcode); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateBinop Analysis " << *origin << "\n"; - ss << "Illegal binopIn(down): " << Opcode << " lhs: " << Result.str() - << " rhs: " << AnalysisRHS.str() << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(origin), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(origin), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), origin, - ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - if (Opcode == BinaryOperator::And) { - for (int i = 0; i < 2; ++i) { - if (Args[i]) - for (auto andval : - fntypeinfo.knownIntegralValues(Args[i], DT, intseen, SE)) { - if (andval <= 16 && andval >= 0) { - Result = TypeTree(BaseType::Integer); - } else if (andval < 0 && andval >= -64) { - // If a small negative number, this just masks off the lower - // bits in this case we can say that this is the same as the - // other operand - Result = (i == 0 ? AnalysisRHS : AnalysisLHS); - } - } - // If we and a constant against an integer, the result remains an - // integer - if (Args[i] && isa(Args[i]) && - (i == 0 ? AnalysisRHS : AnalysisLHS).Inner0() == - BaseType::Integer) { - Result = TypeTree(BaseType::Integer); - } - } - } else if (Opcode == BinaryOperator::Add || - Opcode == BinaryOperator::Sub) { - for (int i = 0; i < 2; ++i) { - if (i == 1 || Opcode == BinaryOperator::Add) - if (auto CI = dyn_cast_or_null(Args[i])) { - if (CI->isNegative() || CI->isZero() || - CI->getLimitedValue() <= 4096) { - // If add/sub with zero, small, or negative number, the result - // is equal to the type of the other operand (and we don't need - // to assume this was an "anything") - Result = (i == 0 ? AnalysisRHS : AnalysisLHS); - } - } - } - } else if (Opcode == BinaryOperator::Mul) { - for (int i = 0; i < 2; ++i) { - // If we mul a constant against an integer, the result remains an - // integer - if (Args[i] && isa(Args[i]) && - (i == 0 ? AnalysisRHS : AnalysisLHS)[{}] == BaseType::Integer) { - Result = TypeTree(BaseType::Integer); - } - } - } else if (Opcode == BinaryOperator::URem) { - if (auto CI = dyn_cast_or_null(Args[1])) { - // If rem with a small integer, the result is also a small integer - if (CI->getLimitedValue() <= 4096) { - Result = TypeTree(BaseType::Integer); - } - } - } else if (Opcode == BinaryOperator::Xor) { - for (int i = 0; i < 2; ++i) { - Type *FT; - if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl))) - continue; - // If ^ against 0b10000000000, the result is a float - bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); - if (validXor) { - Result = ConcreteType(FT); - } - } - } else if (Opcode == BinaryOperator::Or) { - for (int i = 0; i < 2; ++i) { - Type *FT; - if (!(FT = (i == 0 ? 
RHS : LHS).IsAllFloat(size, dl))) - continue; - // If & against 0b10000000000, the result is a float - bool validXor = false; - if (auto CIT = dyn_cast_or_null(Args[i])) { - if (dl.getTypeSizeInBits(FT) != - dl.getTypeSizeInBits(CIT->getType())) - continue; - auto CI = CIT->getValue(); -#if LLVM_VERSION_MAJOR > 16 - if (CI.isZero()) -#else - if (CI.isNullValue()) -#endif - { - validXor = true; - } else if ( - !CI.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (CI & ~0b01111111100000000000000000000000ULL).isZero() -#else - && (CI & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - validXor = true; - } - } else if (auto CV = dyn_cast_or_null(Args[i])) { - validXor = true; - if (dl.getTypeSizeInBits(FT) != - dl.getTypeSizeInBits(CV->getOperand(i)->getType())) - continue; - for (size_t i = 0, end = CV->getNumOperands(); i < end; ++i) { - auto CI = dyn_cast(CV->getOperand(i))->getValue(); -#if LLVM_VERSION_MAJOR > 16 - if (CI.isZero()) -#else - if (CI.isNullValue()) -#endif - { - } else if ( - !CI.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (CI & ~0b01111111100000000000000000000000ULL).isZero() -#else - && - (CI & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - } else - validXor = false; - } - } else if (auto CV = dyn_cast_or_null(Args[i])) { - validXor = true; - if (dl.getTypeSizeInBits(FT) != - dl.getTypeSizeInBits(CV->getElementType())) - continue; - for (size_t i = 0, end = CV->getNumElements(); i < end; ++i) { - auto CI = CV->getElementAsAPInt(i); -#if LLVM_VERSION_MAJOR > 16 - if (CI.isZero()) -#else - if (CI.isNullValue()) -#endif - { - } else if ( - !CI.isNegative() && - ((FT->isFloatTy() -#if LLVM_VERSION_MAJOR > 16 - && (CI & ~0b01111111100000000000000000000000ULL).isZero() -#else - && - (CI & ~0b01111111100000000000000000000000ULL).isNullValue() -#endif - ) || - (FT->isDoubleTy() -#if LLVM_VERSION_MAJOR > 16 - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isZero() -#else - && - (CI & - ~0b0111111111110000000000000000000000000000000000000000000000000000ULL) - .isNullValue() -#endif - ))) { - } else - validXor = false; - } - } - if (validXor) { - Result = ConcreteType(FT); - } - } - } - - Ret = Result.Only(-1, nullptr); - } - } -} -void TypeAnalyzer::visitBinaryOperator(BinaryOperator &I) { - Value *Args[2] = {I.getOperand(0), I.getOperand(1)}; - TypeTree Ret = getAnalysis(&I); - TypeTree LHS = getAnalysis(I.getOperand(0)); - TypeTree RHS = getAnalysis(I.getOperand(1)); - auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); - visitBinaryOperation(DL, I.getType(), I.getOpcode(), Args, Ret, LHS, RHS, &I); - - if (direction & UP) { - updateAnalysis(I.getOperand(0), LHS, &I); - updateAnalysis(I.getOperand(1), RHS, &I); - } - - if (direction & DOWN) { - if (I.getType()->isIntOrIntVectorTy() && Ret[{-1}] == BaseType::Anything) { - if (mustRemainInteger(&I)) { - Ret = TypeTree(BaseType::Integer).Only(-1, 
&I); - } - } - updateAnalysis(&I, Ret, &I); - } -} - -void TypeAnalyzer::visitMemTransferInst(llvm::MemTransferInst &MTI) { - visitMemTransferCommon(MTI); -} - -void TypeAnalyzer::visitMemTransferCommon(llvm::CallBase &MTI) { - if (MTI.getType()->isIntegerTy()) { - updateAnalysis(&MTI, TypeTree(BaseType::Integer).Only(-1, &MTI), &MTI); - } - - if (!(direction & UP)) - return; - - // If memcpy / memmove of pointer, we can propagate type information from src - // to dst up to the length and vice versa - size_t sz = 1; - for (auto val : - fntypeinfo.knownIntegralValues(MTI.getArgOperand(2), DT, intseen, SE)) { - if (val >= 0) { - sz = max(sz, (size_t)val); - } - } - - auto &dl = MTI.getParent()->getParent()->getParent()->getDataLayout(); - TypeTree res = getAnalysis(MTI.getArgOperand(0)) - .PurgeAnything() - .Data0() - .ShiftIndices(dl, 0, sz, 0); - TypeTree res2 = getAnalysis(MTI.getArgOperand(1)) - .PurgeAnything() - .Data0() - .ShiftIndices(dl, 0, sz, 0); - - bool Legal = true; - res.checkedOrIn(res2, /*PointerIntSame*/ false, Legal); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal updateMemTransfer Analysis " << MTI << "\n"; - ss << "Illegal orIn: " << res.str() << " right: " << res2.str() << "\n"; - ss << *MTI.getArgOperand(0) << " " - << getAnalysis(MTI.getArgOperand(0)).str() << "\n"; - ss << *MTI.getArgOperand(1) << " " - << getAnalysis(MTI.getArgOperand(1)).str() << "\n"; - - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(&MTI), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(&MTI), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", MTI.getDebugLoc(), &MTI, ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - res.insert({}, BaseType::Pointer); - res = res.Only(-1, &MTI); - updateAnalysis(MTI.getArgOperand(0), res, &MTI); - updateAnalysis(MTI.getArgOperand(1), res, &MTI); -#if LLVM_VERSION_MAJOR >= 14 - for (unsigned i = 2; i < MTI.arg_size(); ++i) -#else - for (unsigned i = 2; i < MTI.getNumArgOperands(); ++i) -#endif - { - updateAnalysis(MTI.getArgOperand(i), - TypeTree(BaseType::Integer).Only(-1, &MTI), &MTI); - } -} - -void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { - switch (I.getIntrinsicID()) { - case Intrinsic::ctpop: - case Intrinsic::ctlz: - case Intrinsic::cttz: - case Intrinsic::nvvm_read_ptx_sreg_tid_x: - case Intrinsic::nvvm_read_ptx_sreg_tid_y: - case Intrinsic::nvvm_read_ptx_sreg_tid_z: - case Intrinsic::nvvm_read_ptx_sreg_ntid_x: - case Intrinsic::nvvm_read_ptx_sreg_ntid_y: - case Intrinsic::nvvm_read_ptx_sreg_ntid_z: - case Intrinsic::nvvm_read_ptx_sreg_ctaid_x: - case Intrinsic::nvvm_read_ptx_sreg_ctaid_y: - case Intrinsic::nvvm_read_ptx_sreg_ctaid_z: - case Intrinsic::nvvm_read_ptx_sreg_nctaid_x: - case Intrinsic::nvvm_read_ptx_sreg_nctaid_y: - case Intrinsic::nvvm_read_ptx_sreg_nctaid_z: - case Intrinsic::nvvm_read_ptx_sreg_warpsize: - case Intrinsic::amdgcn_workitem_id_x: - case Intrinsic::amdgcn_workitem_id_y: - case Intrinsic::amdgcn_workitem_id_z: - // No direction check as always valid - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - return; - - case Intrinsic::nvvm_barrier0_popc: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: - // No direction check as always valid - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - 
updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), - &I); - return; - - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col: - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row: - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: { - TypeTree TT; - TT.insert({-1}, BaseType::Pointer); - TT.insert({-1, 0}, Type::getFloatTy(I.getContext())); - updateAnalysis(I.getOperand(0), TT, &I); - for (int i = 1; i <= 9; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col: - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row: - case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row: - case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: { - TypeTree TT; - TT.insert({-1}, BaseType::Pointer); - TT.insert({-1, 0}, Type::getHalfTy(I.getContext())); - updateAnalysis(I.getOperand(0), TT, &I); - for (int i = 1; i <= 9; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: { - TypeTree TT; - TT.insert({-1}, BaseType::Pointer); - TT.insert({-1, 0}, Type::getFloatTy(I.getContext())); - updateAnalysis(I.getOperand(0), TT, &I); - updateAnalysis( - &I, - TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: - case 
Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: { - TypeTree TT; - TT.insert({-1}, BaseType::Pointer); - TT.insert({-1, 0}, Type::getHalfTy(I.getContext())); - updateAnalysis(I.getOperand(0), TT, &I); - updateAnalysis( - &I, - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row: - case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride: - case 
Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row: - case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: - case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row: - case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride: - case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col: - case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride: - case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col: - case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride: - case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row: - case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row: - case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row: - case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col: - case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: - case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col: - case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: { - // TODO - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f16_f16: - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f16_f16: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f16_f16: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f16_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f16_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f16_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f16_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f16_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f16_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f16_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f16_f16: { - for (int i = 0; i < 16; i++) - updateAnalysis( - 
I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - for (int i = 16; i < 16 + 8; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - updateAnalysis( - &I, - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f16_f32: - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f16_f32: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f16_f32: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f16_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f16_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f16_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f16_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f16_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f16_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f16_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f16_f32: { - for (int i = 0; i < 16; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - for (int i = 16; i < 16 + 8; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I), - &I); - updateAnalysis( - &I, - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f32_f16: - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f32_f16: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f32_f16: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f32_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f32_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f32_f16: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f32_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f32_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f32_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f32_f16: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f32_f16: { - for (int i = 0; i < 16; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - for (int i = 16; i < 16 + 8; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - updateAnalysis( - &I, - TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f32_f32: - case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f32_f32: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f32_f32: - case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f32_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f32_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f32_f32: - case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f32_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f32_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f32_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f32_f32: - case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f32_f32: { - for (int i = 0; i < 16; i++) - updateAnalysis( - I.getOperand(i), - TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I), - &I); - for (int i = 16; i < 16 + 8; i++) - updateAnalysis( - I.getOperand(i), - 
TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I), - &I); - updateAnalysis( - &I, - TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I), - &I); - return; - } - -#if LLVM_VERSION_MAJOR < 20 - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: -#endif - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: { - auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8; - - if (direction & UP) { - TypeTree ptr(BaseType::Pointer); - ptr |= getAnalysis(&I).PurgeAnything().ShiftIndices( - DL, /*start*/ 0, LoadSize, /*addOffset*/ 0); - updateAnalysis(I.getOperand(0), ptr.Only(-1, &I), &I); - } - if (direction & DOWN) - updateAnalysis(&I, getAnalysis(I.getOperand(0)).Lookup(LoadSize, DL), &I); - return; - } - - case Intrinsic::log: - case Intrinsic::log2: - case Intrinsic::log10: - case Intrinsic::exp: - case Intrinsic::exp2: - case Intrinsic::sin: - case Intrinsic::cos: -#if LLVM_VERSION_MAJOR >= 19 - case Intrinsic::sinh: - case Intrinsic::cosh: - case Intrinsic::tanh: -#endif - case Intrinsic::floor: - case Intrinsic::ceil: - case Intrinsic::trunc: - case Intrinsic::rint: - case Intrinsic::nearbyint: - case Intrinsic::round: - case Intrinsic::sqrt: -#if LLVM_VERSION_MAJOR >= 21 - case Intrinsic::nvvm_fabs: - case Intrinsic::nvvm_fabs_ftz: -#else - case Intrinsic::nvvm_fabs_f: - case Intrinsic::nvvm_fabs_d: - case Intrinsic::nvvm_fabs_ftz_f: -#endif - case Intrinsic::fabs: - // No direction check as always valid - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); - return; - - case Intrinsic::fmuladd: - case Intrinsic::fma: - // No direction check as always valid - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(1), - TypeTree(ConcreteType(I.getOperand(1)->getType()->getScalarType())) - .Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(2), - TypeTree(ConcreteType(I.getOperand(2)->getType()->getScalarType())) - .Only(-1, &I), - &I); - return; - - case Intrinsic::powi: - // No direction check as always valid - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis(I.getOperand(1), TypeTree(BaseType::Integer).Only(-1, &I), - &I); - return; - -#if LLVM_VERSION_MAJOR >= 12 - case Intrinsic::vector_reduce_fadd: - case Intrinsic::vector_reduce_fmul: -#else - case Intrinsic::experimental_vector_reduce_v2_fadd: - case Intrinsic::experimental_vector_reduce_v2_fmul: -#endif - case Intrinsic::copysign: - case Intrinsic::maxnum: - case Intrinsic::minnum: -#if LLVM_VERSION_MAJOR >= 15 - case Intrinsic::maximum: - case Intrinsic::minimum: -#endif - case Intrinsic::nvvm_fmax_f: - case 
Intrinsic::nvvm_fmax_d: - case Intrinsic::nvvm_fmax_ftz_f: - case Intrinsic::nvvm_fmin_f: - case Intrinsic::nvvm_fmin_d: - case Intrinsic::nvvm_fmin_ftz_f: - case Intrinsic::pow: - // No direction check as always valid - updateAnalysis( - &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(0), - TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType())) - .Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis( - I.getOperand(1), - TypeTree(ConcreteType(I.getOperand(1)->getType()->getScalarType())) - .Only(-1, &I), - &I); - return; -#if LLVM_VERSION_MAJOR >= 12 - case Intrinsic::smax: - case Intrinsic::smin: - case Intrinsic::umax: - case Intrinsic::umin: - if (direction & UP) { - auto returnType = getAnalysis(&I)[{-1}]; - if (returnType == BaseType::Integer || returnType == BaseType::Pointer) { - updateAnalysis(I.getOperand(0), TypeTree(returnType).Only(-1, &I), &I); - updateAnalysis(I.getOperand(1), TypeTree(returnType).Only(-1, &I), &I); - } - } - if (direction & DOWN) { - auto opType0 = getAnalysis(I.getOperand(0))[{-1}]; - auto opType1 = getAnalysis(I.getOperand(1))[{-1}]; - if (opType0 == opType1 && - (opType0 == BaseType::Integer || opType0 == BaseType::Pointer)) { - updateAnalysis(&I, TypeTree(opType0).Only(-1, &I), &I); - } else if (opType0 == BaseType::Integer && - opType1 == BaseType::Anything) { - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - } else if (opType1 == BaseType::Integer && - opType0 == BaseType::Anything) { - updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); - } - } - return; -#endif - case Intrinsic::umul_with_overflow: - case Intrinsic::smul_with_overflow: - case Intrinsic::ssub_with_overflow: - case Intrinsic::usub_with_overflow: - case Intrinsic::sadd_with_overflow: - case Intrinsic::uadd_with_overflow: { - // val, bool - auto analysis = getAnalysis(&I).Data0(); - - BinaryOperator::BinaryOps opcode; - // TODO update to use better rules in regular binop - switch (I.getIntrinsicID()) { - case Intrinsic::ssub_with_overflow: - case Intrinsic::usub_with_overflow: { - // TODO propagate this info - // ptr - ptr => int and int - int => int; thus int = a - b says only that - // these are equal ptr - int => ptr and int - ptr => ptr; thus - analysis = ConcreteType(BaseType::Unknown); - opcode = BinaryOperator::Sub; - break; - } - - case Intrinsic::smul_with_overflow: - case Intrinsic::umul_with_overflow: { - opcode = BinaryOperator::Mul; - // if a + b or a * b == int, then a and b must be ints - analysis = analysis.JustInt(); - break; - } - case Intrinsic::sadd_with_overflow: - case Intrinsic::uadd_with_overflow: { - opcode = BinaryOperator::Add; - // if a + b or a * b == int, then a and b must be ints - analysis = analysis.JustInt(); - break; - } - default: - llvm_unreachable("unknown binary operator"); - } - - // TODO update with newer binop protocol (see binop) - if (direction & UP) - updateAnalysis(I.getOperand(0), analysis.Only(-1, &I), &I); - if (direction & UP) - updateAnalysis(I.getOperand(1), analysis.Only(-1, &I), &I); - - TypeTree vd = getAnalysis(I.getOperand(0)).Data0(); - bool Legal = true; - vd.binopIn(Legal, getAnalysis(I.getOperand(1)).Data0(), opcode); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - if (!CustomErrorHandler) { - llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; - llvm::errs() << *fntypeinfo.Function << "\n"; - dump(ss); - } - ss << "Illegal 
updateBinopIntr Analysis " << I << "\n"; - ss << "Illegal binopIn(intr): " << I << " lhs: " << vd.str() - << " rhs: " << getAnalysis(I.getOperand(1)).str() << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(&I), - ErrorType::IllegalTypeAnalysis, (void *)this, - wrap(&I), nullptr); - } - EmitFailure("IllegalUpdateAnalysis", I.getDebugLoc(), &I, ss.str()); - report_fatal_error("Performed illegal updateAnalysis"); - } - auto &dl = I.getParent()->getParent()->getParent()->getDataLayout(); - int sz = (dl.getTypeSizeInBits(I.getOperand(0)->getType()) + 7) / 8; - TypeTree overall = vd.Only(-1, &I).ShiftIndices(dl, 0, sz, 0); - - int sz2 = (dl.getTypeSizeInBits(I.getType()) + 7) / 8; - auto btree = TypeTree(BaseType::Integer) - .Only(-1, &I) - .ShiftIndices(dl, 0, sz2 - sz, sz); - overall |= btree; - - if (direction & DOWN) - updateAnalysis(&I, overall, &I); - return; - } - default: - return; - } -} - -/// This template class is defined to take the templated type T -/// update the analysis of the first argument (val) to be type T -/// As such, below we have several template specializations -/// to convert various c/c++ to TypeAnalysis types -template struct TypeHandler {}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TA.updateAnalysis( - val, - TypeTree(ConcreteType(Type::getDoubleTy(call.getContext()))) - .Only(-1, &call), - &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TA.updateAnalysis( - val, - TypeTree(ConcreteType(Type::getFloatTy(call.getContext()))) - .Only(-1, &call), - &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TA.updateAnalysis( - val, - TypeTree(ConcreteType(Type::getX86_FP80Ty(call.getContext()))) - .Only(-1, &call), - &call); - } -}; - -#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__) -template <> struct TypeHandler<__float128> { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TA.updateAnalysis( - val, - TypeTree(ConcreteType(Type::getFP128Ty(call.getContext()))) - .Only(-1, &call), - &call); - } -}; -#endif - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(Type::getDoubleTy(call.getContext())).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(Type::getFloatTy(call.getContext())).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = - TypeTree(Type::getX86_FP80Ty(call.getContext())).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__) -template <> struct TypeHandler<__float128 *> { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(Type::getFP128Ty(call.getContext())).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; -#endif - -template <> struct TypeHandler { - static void analyzeType(Value *val, 
CallBase &call, TypeAnalyzer &TA) {} -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template <> struct TypeHandler { - static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) { - TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call); - vd |= TypeTree(BaseType::Pointer); - TA.updateAnalysis(val, vd.Only(-1, &call), &call); - } -}; - -template struct FunctionArgumentIterator { - static void analyzeFuncTypesHelper(unsigned idx, CallBase &call, - TypeAnalyzer &TA) {} -}; - -template -struct FunctionArgumentIterator { - static void analyzeFuncTypesHelper(unsigned idx, CallBase &call, - TypeAnalyzer &TA) { - TypeHandler::analyzeType(call.getOperand(idx), call, TA); - FunctionArgumentIterator::analyzeFuncTypesHelper(idx + 
1, call, - TA); - } -}; - -template -void analyzeFuncTypesNoFn(CallBase &call, TypeAnalyzer &TA) { - TypeHandler::analyzeType(&call, call, TA); - FunctionArgumentIterator::analyzeFuncTypesHelper(0, call, TA); -} - -template -void analyzeFuncTypes(RT (*fn)(Args...), CallBase &call, TypeAnalyzer &TA) { - analyzeFuncTypesNoFn(call, TA); -} - -void analyzeIntelSubscriptIntrinsic(IntrinsicInst &II, TypeAnalyzer &TA) { - assert(isIntelSubscriptIntrinsic(II)); -#if LLVM_VERSION_MAJOR >= 14 - assert(II.arg_size() == 5); -#else - assert(II.getNumArgOperands() == 5); -#endif - - constexpr size_t idxArgsIndices[4] = {0, 1, 2, 4}; - constexpr size_t ptrArgIndex = 3; - - // Update analysis of index parameters - - if (TA.direction & TypeAnalyzer::UP) { - for (auto i : idxArgsIndices) { - auto idx = II.getOperand(i); - TA.updateAnalysis(idx, TypeTree(BaseType::Integer).Only(-1, &II), &II); - } - } - - // Update analysis of ptr parameter - - auto &DL = TA.fntypeinfo.Function->getParent()->getDataLayout(); - auto pointerAnalysis = TA.getAnalysis(II.getOperand(ptrArgIndex)); - - if (TA.direction & TypeAnalyzer::DOWN) { - bool legal = true; - auto keepMinus = pointerAnalysis.KeepMinusOne(legal); - if (!legal) { - if (CustomErrorHandler) - CustomErrorHandler("Could not keep minus one", wrap(&II), - ErrorType::IllegalTypeAnalysis, &TA, nullptr, - nullptr); - else { - TA.dump(); - llvm::errs() - << " could not perform minus one for llvm.intel.subscript'd: " << II - << "\n"; - } - } - TA.updateAnalysis(&II, keepMinus, &II); - TA.updateAnalysis(&II, TypeTree(pointerAnalysis.Inner0()).Only(-1, &II), - &II); - } - - if (TA.direction & TypeAnalyzer::UP) { - TA.updateAnalysis(II.getOperand(ptrArgIndex), - TypeTree(TA.getAnalysis(&II).Inner0()).Only(-1, &II), - &II); - } - - SmallVector, 4> idnext; - // The first operand is used to denote the axis of a multidimensional array, - // but it is not used for address calculation, and so we skip it here. 
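// Editorial sketch (not part of the patch): the deleted TypeHandler<T>
// specializations plus the recursive FunctionArgumentIterator walk a C
// prototype such as `double frexp(double, int*)` and emit one type fact for
// the return value and one per parameter, which is what the CONSIDER macros
// dispatch to by name. The toy typeTag / prototypeTags / myFrexp names below
// are hypothetical; only the "derive facts from the C signature" idea matches
// the removed code.
#include <cstdio>
#include <string>
#include <vector>

template <typename T> std::string typeTag()   { return "unknown"; }
template <> std::string typeTag<double>()     { return "double"; }
template <> std::string typeTag<float>()      { return "float"; }
template <> std::string typeTag<int>()        { return "int"; }
template <> std::string typeTag<int *>()      { return "ptr-to-int"; }
template <> std::string typeTag<double *>()   { return "ptr-to-double"; }

// Return {returnTag, argTag...} for a function pointer type, analogous to
// analyzeFuncTypes(::frexp, call, TA) updating the call result and operands.
template <typename RT, typename... Args>
std::vector<std::string> prototypeTags(RT (*)(Args...)) {
  return {typeTag<RT>(), typeTag<Args>()...};
}

// Stand-in with the same shape as frexp; only its type is inspected.
double myFrexp(double x, int *e) { *e = 0; return x; }

int main() {
  for (const auto &tag : prototypeTags(&myFrexp))
    std::printf("%s\n", tag.c_str());  // double, double, ptr-to-int
}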
- constexpr size_t offsetCalculationIndices[3] = {1, 2, 4}; - for (auto i : offsetCalculationIndices) { - auto idx = II.getOperand(i); - auto iset = TA.knownIntegralValues(idx); - std::set vset; - for (auto i : iset) { - // Don't consider negative indices of llvm.intel.subscript - if (i < 0) - continue; - vset.insert(i); - } - idnext.push_back(vset); - if (idnext.back().size() == 0) - return; - } - assert(idnext.size() != 0); - - TypeTree upTree; - TypeTree downTree; - - TypeTree intrinsicData0; - TypeTree pointerData0; - if (TA.direction & TypeAnalyzer::UP) - intrinsicData0 = TA.getAnalysis(&II).Data0(); - if (TA.direction & TypeAnalyzer::DOWN) - pointerData0 = pointerAnalysis.Data0(); - - bool firstLoop = true; - - for (auto vec : getSet(idnext, idnext.size() - 1)) { - auto baseIndex = vec[0]; - auto stride = vec[1]; - auto index = vec[2]; - - int offset = static_cast(stride * (index - baseIndex)); - if (offset < 0) { - continue; // The intrinsic doesn't handle negative offsets - } - - if (TA.direction & TypeAnalyzer::DOWN) { - auto shft = pointerData0.ShiftIndices(DL, /*init offset*/ offset, - /*max size*/ -1, /*newoffset*/ 0); - if (firstLoop) - downTree = shft; - else - downTree &= shft; - } - - if (TA.direction & TypeAnalyzer::UP) { - auto shft = - intrinsicData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1, - /*new offset*/ offset); - if (firstLoop) - upTree = shft; - else - upTree |= shft; - } - firstLoop = false; - } - if (TA.direction & TypeAnalyzer::DOWN) - TA.updateAnalysis(&II, downTree.Only(-1, &II), &II); - if (TA.direction & TypeAnalyzer::UP) - TA.updateAnalysis(II.getOperand(ptrArgIndex), upTree.Only(-1, &II), &II); -} - -void TypeAnalyzer::visitCallBase(CallBase &call) { - assert(fntypeinfo.KnownValues.size() == - fntypeinfo.Function->getFunctionType()->getNumParams()); - - if (auto iasm = dyn_cast(call.getCalledOperand())) { - // NO direction check as always valid - if (StringRef(iasm->getAsmString()).contains("cpuid")) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : call.args()) -#else - for (auto &arg : call.arg_operands()) -#endif - { - updateAnalysis(arg, TypeTree(BaseType::Integer).Only(-1, &call), &call); - } - } - } - - if (call.hasFnAttr("enzyme_ta_norecur")) - return; - - Function *ci = getFunctionFromCall(&call); - - if (ci) { - if (ci->getAttributes().hasAttribute(AttributeList::FunctionIndex, - "enzyme_ta_norecur")) - return; - - StringRef funcName = getFuncNameFromCall(&call); - - auto blasMetaData = extractBLAS(funcName); - if (blasMetaData) { - BlasInfo blas = *blasMetaData; -#include "BlasTA.inc" - } - - // clang-format off - const char* NoTARecurStartsWith[] = { - "std::__u::basic_ostream>& std::__u::operator<<", - }; - // clang-format on - { - std::string demangledName = llvm::demangle(funcName.str()); - // replace all '> >' with '>>' - size_t start = 0; - while ((start = demangledName.find("> >", start)) != std::string::npos) { - demangledName.replace(start, 3, ">>"); - } - for (auto Name : NoTARecurStartsWith) - if (startsWith(demangledName, Name)) - return; - } - - // Manual TT specification is non-interprocedural and already handled once - // at the start. - - // When compiling Enzyme against standard LLVM, and not Intel's - // modified version of LLVM, the intrinsic `llvm.intel.subscript` is - // not fully understood by LLVM. 
One of the results of this is that the - // visitor dispatches to visitCallBase, rather than visitIntrinsicInst, when - // presented with the intrinsic - hence why we are handling it here. - if (startsWith(funcName, "llvm.intel.subscript")) { - assert(isa(call)); - analyzeIntelSubscriptIntrinsic(cast(call), *this); - return; - } - -#define CONSIDER(fn) \ - if (funcName == #fn) { \ - analyzeFuncTypes(::fn, call, *this); \ - return; \ - } - -#define CONSIDER2(fn, ...) \ - if (funcName == #fn) { \ - analyzeFuncTypesNoFn<__VA_ARGS__>(call, *this); \ - return; \ - } - - auto customrule = interprocedural.CustomRules.find(funcName); - if (customrule != interprocedural.CustomRules.end()) { - auto returnAnalysis = getAnalysis(&call); - SmallVector args; - SmallVector, 4> knownValues; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : call.args()) -#else - for (auto &arg : call.arg_operands()) -#endif - { - args.push_back(getAnalysis(arg)); - knownValues.push_back( - fntypeinfo.knownIntegralValues((Value *)arg, DT, intseen, SE)); - } - - bool err = customrule->second(direction, returnAnalysis, args, - knownValues, &call, this); - if (err) { - Invalid = true; - return; - } - updateAnalysis(&call, returnAnalysis, &call); - size_t argnum = 0; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : call.args()) -#else - for (auto &arg : call.arg_operands()) -#endif - { - updateAnalysis(arg, args[argnum], &call); - argnum++; - } - return; - } - - // All these are always valid => no direction check - // CONSIDER(malloc) - // TODO consider handling other allocation functions integer inputs - if (startsWith(funcName, "_ZN3std2io5stdio6_print") || - startsWith(funcName, "_ZN4core3fmt")) { - return; - } - /// GEMM - if (funcName == "dgemm_64" || funcName == "dgemm_64_" || - funcName == "dgemm" || funcName == "dgemm_") { - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - ptrint.insert({-1, 0}, BaseType::Integer); - // transa, transb, m, n, k, lda, ldb, ldc - for (int i : {0, 1, 2, 3, 4, 7, 9, 12}) - updateAnalysis(call.getArgOperand(i), ptrint, &call); - - TypeTree ptrdbl; - ptrdbl.insert({-1}, BaseType::Pointer); - ptrdbl.insert({-1, 0}, Type::getDoubleTy(call.getContext())); - - // alpha, a, b, beta, c - for (int i : {5, 6, 8, 10, 11}) - updateAnalysis(call.getArgOperand(i), ptrdbl, &call); - return; - } - - if (funcName == "__kmpc_fork_call") { - Function *fn = dyn_cast(call.getArgOperand(2)); - - if (auto castinst = dyn_cast(call.getArgOperand(2))) - if (castinst->isCast()) - fn = dyn_cast(castinst->getOperand(0)); - - if (fn) { -#if LLVM_VERSION_MAJOR >= 14 - if (call.arg_size() - 3 != fn->getFunctionType()->getNumParams() - 2) - return; -#else - if (call.getNumArgOperands() - 3 != - fn->getFunctionType()->getNumParams() - 2) - return; -#endif - - if (direction & UP) { - FnTypeInfo typeInfo(fn); - - TypeTree IntPtr; - IntPtr.insert({-1, -1}, BaseType::Integer); - IntPtr.insert({-1}, BaseType::Pointer); - - int argnum = 0; - for (auto &arg : fn->args()) { - if (argnum <= 1) { - typeInfo.Arguments.insert( - std::pair(&arg, IntPtr)); - typeInfo.KnownValues.insert( - std::pair>(&arg, {0})); - } else { - typeInfo.Arguments.insert(std::pair( - &arg, getAnalysis(call.getArgOperand(argnum - 2 + 3)))); - std::set bounded; - for (auto v : fntypeinfo.knownIntegralValues( - call.getArgOperand(argnum - 2 + 3), DT, intseen, SE)) { - if (abs(v) > MaxIntOffset) - continue; - bounded.insert(v); - } - typeInfo.KnownValues.insert( - std::pair>(&arg, bounded)); - } - - ++argnum; - } - - if (EnzymePrintType) { - 
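// Editorial sketch (not part of the patch): for __kmpc_fork_call the removed
// code seeds an interprocedural query on the outlined parallel region: the
// first two parameters are the OpenMP thread-id pointers (with known value 0),
// and every later parameter k receives the type facts of call operand k + 1,
// with results queried back in the other direction afterwards. The helper
// below only reproduces that index mapping; its name is hypothetical.
#include <cassert>
#include <cstdio>

// Outlined-function parameter index -> __kmpc_fork_call operand index.
// Parameters 0 and 1 are the global/bound thread ids and have no operand.
int callOperandForOutlinedArg(int argIndex) {
  assert(argIndex >= 2 && "thread-id parameters are not forwarded");
  return argIndex - 2 + 3;  // matches "argnum - 2 + 3" in the removed code
}

int main() {
  for (int arg = 2; arg < 6; ++arg)
    std::printf("outlined arg %d <- call operand %d\n", arg,
                callOperandForOutlinedArg(arg));
}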
llvm::errs() << " starting omp IPO of "; - call.print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - - auto a = fn->arg_begin(); - ++a; - ++a; - TypeResults STR = interprocedural.analyzeFunction(typeInfo); -#if LLVM_VERSION_MAJOR >= 14 - for (unsigned i = 3; i < call.arg_size(); ++i) -#else - for (unsigned i = 3; i < call.getNumArgOperands(); ++i) -#endif - { - auto dt = STR.query(a); - updateAnalysis(call.getArgOperand(i), dt, &call); - ++a; - } - } - } - return; - } - if (funcName == "__kmpc_for_static_init_4" || - funcName == "__kmpc_for_static_init_4u" || - funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") { - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - size_t numBytes = 4; - if (funcName == "__kmpc_for_static_init_8" || - funcName == "__kmpc_for_static_init_8u") - numBytes = 8; - for (size_t i = 0; i < numBytes; i++) - ptrint.insert({-1, (int)i}, BaseType::Integer); - updateAnalysis(call.getArgOperand(3), ptrint, &call); - updateAnalysis(call.getArgOperand(4), ptrint, &call); - updateAnalysis(call.getArgOperand(5), ptrint, &call); - updateAnalysis(call.getArgOperand(6), ptrint, &call); - updateAnalysis(call.getArgOperand(7), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getArgOperand(8), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "omp_get_max_threads" || funcName == "omp_get_thread_num" || - funcName == "omp_get_num_threads" || - funcName == "__kmpc_global_thread_num") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "_ZNSt6localeC1Ev") { - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - ptrint.insert({-1, 0}, BaseType::Integer); - updateAnalysis(call.getOperand(0), ptrint, &call); - return; - } - - if (startsWith(funcName, "_ZNKSt3__14hash")) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - - if (startsWith(funcName, "_ZNKSt3__112basic_string") || - startsWith(funcName, "_ZNSt3__112basic_string") || - startsWith(funcName, "_ZNSt3__112__hash_table") || - startsWith(funcName, "_ZNKSt3__115basic_stringbuf")) { - return; - } - - if (funcName == "__dynamic_cast" || - funcName == "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base" || - funcName == "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base" || - funcName == "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base" || - funcName == "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "memcmp") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - - /// CUDA - if (funcName == "cuDeviceGet") { - // cuResult - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "cuDeviceGetName") { - // cuResult - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, 
&call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "cudaRuntimeGetVersion" || - funcName == "cuDriverGetVersion" || funcName == "cuDeviceGetCount") { - // cuResult - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - ptrint.insert({-1, 0}, BaseType::Integer); - updateAnalysis(call.getOperand(0), ptrint, &call); - return; - } - if (funcName == "cuMemGetInfo_v2") { - // cuResult - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - ptrint.insert({-1, 0}, BaseType::Integer); - updateAnalysis(call.getOperand(0), ptrint, &call); - updateAnalysis(call.getOperand(1), ptrint, &call); - return; - } - if (funcName == "cuDevicePrimaryCtxRetain" || - funcName == "cuCtxGetCurrent") { - // cuResult - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "cuStreamQuery") { - // cuResult - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "cuMemAllocAsync" || funcName == "cuMemAlloc" || - funcName == "cuMemAlloc_v2" || funcName == "cudaMalloc" || - funcName == "cudaMallocAsync" || funcName == "cudaMallocHost" || - funcName == "cudaMallocFromPoolAsync") { - TypeTree ptrptr; - ptrptr.insert({-1}, BaseType::Pointer); - ptrptr.insert({-1, 0}, BaseType::Pointer); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), ptrptr, &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "jl_hrtime" || funcName == "ijl_hrtime") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "jl_get_task_tid" || funcName == "ijl_get_task_tid") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "jl_get_binding_or_error" || - funcName == "ijl_get_binding_or_error") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "julia.gc_loaded") { - if (direction & UP) - updateAnalysis(call.getArgOperand(1), getAnalysis(&call), &call); - if (direction & DOWN) - updateAnalysis(&call, getAnalysis(call.getArgOperand(1)), &call); - return; - } - if (funcName == "julia.pointer_from_objref") { - if (direction & UP) - updateAnalysis(call.getArgOperand(0), getAnalysis(&call), &call); - if (direction & DOWN) - updateAnalysis(&call, getAnalysis(call.getArgOperand(0)), &call); - return; - } - if (funcName == "_ZNSt6chrono3_V212steady_clock3nowEv") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - - /// MPI - if (startsWith(funcName, "PMPI_")) - funcName = funcName.substr(1); - if (funcName == "MPI_Init") { - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - ptrint.insert({-1, 0}, BaseType::Integer); - updateAnalysis(call.getOperand(0), ptrint, &call); - TypeTree ptrptrptr; - ptrptrptr.insert({-1}, BaseType::Pointer); - ptrptrptr.insert({-1, -1}, BaseType::Pointer); - ptrptrptr.insert({-1, -1, 0}, BaseType::Pointer); - updateAnalysis(call.getOperand(1), ptrptrptr, &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, 
&call), &call); - return; - } - if (funcName == "MPI_Comm_size" || funcName == "MPI_Comm_rank" || - funcName == "MPI_Get_processor_name") { - TypeTree ptrint; - ptrint.insert({-1}, BaseType::Pointer); - ptrint.insert({-1, 0}, BaseType::Integer); - updateAnalysis(call.getOperand(1), ptrint, &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Barrier" || funcName == "MPI_Finalize") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Send" || funcName == "MPI_Ssend" || - funcName == "MPI_Bsend" || funcName == "MPI_Recv" || - funcName == "MPI_Brecv" || funcName == "PMPI_Send" || - funcName == "PMPI_Ssend" || funcName == "PMPI_Bsend" || - funcName == "PMPI_Recv" || funcName == "PMPI_Brecv") { - TypeTree buf = TypeTree(BaseType::Pointer); - - if (Constant *C = dyn_cast(call.getOperand(2))) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_cxx_bool") { - buf.insert({0}, BaseType::Integer); - } - } else if (auto CI = dyn_cast(C)) { - // MPICH - if (CI->getValue() == 1275070475) { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (CI->getValue() == 1275069450) { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } - } - } - updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" || - funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") { - TypeTree buf = TypeTree(BaseType::Pointer); - - if (Constant *C = dyn_cast(call.getOperand(2))) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_cxx_bool") { - buf.insert({0}, BaseType::Integer); - } - } else if (auto CI = dyn_cast(C)) { - // MPICH - if (CI->getValue() == 1275070475) { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (CI->getValue() == 1275069450) { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } - } - } - updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(6), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Wait") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - 
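// Editorial sketch (not part of the patch): for MPI_Send/Recv/Isend/Irecv the
// deleted code above inspects the MPI_Datatype operand and, when it is a known
// constant, records the element type of the buffer: OpenMPI exposes datatypes
// as named globals (ompi_mpi_double, ompi_mpi_float, ompi_mpi_cxx_bool), while
// MPICH encodes them as integer handles (1275070475 and 1275069450 in the
// removed code). The datatypeToElementTag helper below is a hypothetical,
// standalone analogue of that lookup.
#include <cstdint>
#include <cstdio>
#include <string>

// Map either an OpenMPI global name or an MPICH integer handle to a tag for
// the pointed-to element type; "unknown" leaves the buffer as a plain pointer.
std::string datatypeToElementTag(const std::string &globalName,
                                 uint32_t mpichHandle) {
  if (globalName == "ompi_mpi_double")   return "double";
  if (globalName == "ompi_mpi_float")    return "float";
  if (globalName == "ompi_mpi_cxx_bool") return "integer";
  if (mpichHandle == 1275070475u) return "double";  // handle used for doubles
  if (mpichHandle == 1275069450u) return "float";   // handle used for floats
  return "unknown";
}

int main() {
  std::printf("%s\n", datatypeToElementTag("ompi_mpi_double", 0).c_str());
  std::printf("%s\n", datatypeToElementTag("", 1275069450u).c_str());
}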
updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Waitany") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Waitall") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Bcast") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Reduce" || funcName == "PMPI_Reduce") { - TypeTree buf = TypeTree(BaseType::Pointer); - - if (Constant *C = dyn_cast(call.getOperand(3))) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_cxx_bool") { - buf.insert({0}, BaseType::Integer); - } - } else if (auto CI = dyn_cast(C)) { - // MPICH - if (CI->getValue() == 1275070475) { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (CI->getValue() == 1275069450) { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } - } - } - // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, - // MPI_Datatype datatype, - // MPI_Op op, int root, MPI_Comm comm) - // sendbuf - updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); - // recvbuf - updateAnalysis(call.getOperand(1), buf.Only(-1, &call), &call); - // count - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - // datatype - // op - // comm - // result - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") { - TypeTree buf = TypeTree(BaseType::Pointer); - - if (Constant *C = dyn_cast(call.getOperand(3))) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_cxx_bool") { - buf.insert({0}, BaseType::Integer); - } - } else if (auto CI = dyn_cast(C)) { - // MPICH - if (CI->getValue() == 1275070475) { - buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if 
(CI->getValue() == 1275069450) { - buf.insert({0}, Type::getFloatTy(C->getContext())); - } - } - } - // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, - // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) - // sendbuf - updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); - // recvbuf - updateAnalysis(call.getOperand(1), buf.Only(-1, &call), &call); - // count - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - // datatype - // op - // comm - // result - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Sendrecv_replace") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(5), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(6), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(8), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Sendrecv") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(5), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(6), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(7), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(8), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(9), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(11), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Gather" || funcName == "MPI_Scatter") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(6), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "MPI_Allgather") { - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - 
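// Editorial sketch (not part of the patch): each of the MPI handlers removed
// above hard-codes, per call signature, which operand positions are pointers
// (buffers, requests, statuses) and which are integers (counts, tags, roots).
// A table-driven equivalent could look like the hypothetical sketch below; the
// operand indices shown are copied from the MPI_Bcast and MPI_Wait cases in
// the diff, with -1 standing for the call result.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct OperandRule { int index; const char *tag; };

// funcName -> list of (operand index, type tag).
static const std::map<std::string, std::vector<OperandRule>> kMPIRules = {
    {"MPI_Bcast",
     {{0, "pointer"}, {1, "integer"}, {3, "integer"}, {-1, "integer"}}},
    {"MPI_Wait", {{0, "pointer"}, {1, "pointer"}, {-1, "integer"}}},
};

int main() {
  for (const auto &rule : kMPIRules.at("MPI_Bcast"))
    std::printf("operand %d -> %s\n", rule.index, rule.tag);
}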
} - /// END MPI - - // Prob Prog - if (ci->hasFnAttribute("enzyme_notypeanalysis")) { - return; - } - - if (funcName == "memcpy" || funcName == "memmove") { - // TODO have this call common mem transfer to copy data - visitMemTransferCommon(call); - return; - } - if (funcName == "posix_memalign") { - TypeTree ptrptr; - ptrptr.insert({-1}, BaseType::Pointer); - ptrptr.insert({-1, 0}, BaseType::Pointer); - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), ptrptr, &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "calloc") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (auto opidx = getAllocationIndexFromCall(&call)) { - auto ptr = TypeTree(BaseType::Pointer); - unsigned index = (size_t)*opidx; - if (auto CI = dyn_cast(call.getOperand(index))) { - auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = CI->getZExtValue(); - // Only propagate mappings in range that aren't "Anything" into the - // pointer - ptr |= getAnalysis(&call).Lookup(LoadSize, DL); - } - updateAnalysis(&call, ptr.Only(-1, &call), &call); - updateAnalysis(call.getOperand(index), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "malloc") { - auto ptr = TypeTree(BaseType::Pointer); - if (auto CI = dyn_cast(call.getOperand(0))) { - auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = CI->getZExtValue(); - // Only propagate mappings in range that aren't "Anything" into the - // pointer - ptr |= getAnalysis(&call).Lookup(LoadSize, DL); - } - updateAnalysis(&call, ptr.Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "__size_returning_new_experiment") { - auto ptr = TypeTree(BaseType::Pointer); - auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); - if (auto CI = dyn_cast(call.getOperand(0))) { - auto LoadSize = CI->getZExtValue(); - // Only propagate mappings in range that aren't "Anything" into the - // pointer - ptr |= getAnalysis(&call).Lookup(LoadSize, DL); - } - ptr = ptr.Only(0, &call); - ptr |= TypeTree(BaseType::Integer).Only(DL.getPointerSize(), &call); - updateAnalysis(&call, ptr.Only(0, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" || - funcName == "ijl_gc_alloc_typed") { - auto ptr = TypeTree(BaseType::Pointer); - if (auto CI = dyn_cast(call.getOperand(1))) { - auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); - auto LoadSize = CI->getZExtValue(); - // Only propagate mappings in range that aren't "Anything" into the - // pointer - ptr |= getAnalysis(&call).Lookup(LoadSize, DL); - } - updateAnalysis(&call, ptr.Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "julia.except_enter" || funcName == "ijl_excstack_state" || - funcName == "jl_excstack_state") { - 
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" || - funcName == "jl_inactive_inout" || - funcName == "jl_genericmemory_copy_slice" || - funcName == "ijl_genericmemory_copy_slice") { - if (direction & DOWN) - updateAnalysis(&call, getAnalysis(call.getOperand(0)), &call); - if (direction & UP) - updateAnalysis(call.getOperand(0), getAnalysis(&call), &call); - return; - } - - if (isAllocationFunction(funcName, TLI)) { - size_t Idx = 0; - for (auto &Arg : ci->args()) { - if (Arg.getType()->isIntegerTy()) { - updateAnalysis(call.getOperand(Idx), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - } - Idx++; - } - assert(ci->getReturnType()->isPointerTy()); - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "malloc_usable_size" || funcName == "malloc_size" || - funcName == "_msize") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "realloc") { - size_t sz = 1; - for (auto val : fntypeinfo.knownIntegralValues(call.getArgOperand(1), DT, - intseen, SE)) { - if (val >= 0) { - sz = max(sz, (size_t)val); - } - } - - auto &dl = call.getParent()->getParent()->getParent()->getDataLayout(); - TypeTree res = getAnalysis(call.getArgOperand(0)) - .PurgeAnything() - .Data0() - .ShiftIndices(dl, 0, sz, 0); - TypeTree res2 = - getAnalysis(&call).PurgeAnything().Data0().ShiftIndices(dl, 0, sz, 0); - - res.orIn(res2, /*PointerIntSame*/ false); - res.insert({}, BaseType::Pointer); - res = res.Only(-1, &call); - if (direction & DOWN) { - updateAnalysis(&call, res, &call); - } - if (direction & UP) { - updateAnalysis(call.getOperand(0), res, &call); - } - return; - } - if (funcName == "sigaction") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "mmap") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(3), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(4), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(5), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "munmap") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "pthread_mutex_lock" || - funcName == "pthread_mutex_trylock" || - funcName == "pthread_rwlock_rdlock" || - funcName == "pthread_rwlock_unlock" || - funcName == "pthread_attr_init" || funcName == "pthread_attr_destroy" || - funcName == "pthread_rwlock_unlock" || - 
funcName == "pthread_mutex_unlock") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (isDeallocationFunction(funcName, TLI)) { - size_t Idx = 0; - for (auto &Arg : ci->args()) { - if (Arg.getType()->isIntegerTy()) { - updateAnalysis(call.getOperand(Idx), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - } - if (Arg.getType()->isPointerTy()) { - updateAnalysis(call.getOperand(Idx), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - } - Idx++; - } - if (!ci->getReturnType()->isVoidTy()) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), - &call); - return; - } - assert(ci->getReturnType()->isVoidTy()); - return; - } - if (funcName == "memchr" || funcName == "memrchr") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "strlen") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "strcmp") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "bcmp") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "getcwd") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "sysconf") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "dladdr") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "__errno_location") { - TypeTree ptrint; - ptrint.insert({-1, -1}, BaseType::Integer); - ptrint.insert({-1}, BaseType::Pointer); - updateAnalysis(&call, ptrint, &call); - return; - } - if (funcName == "getenv") { - TypeTree ptrint; - ptrint.insert({-1, -1}, BaseType::Integer); - ptrint.insert({-1}, BaseType::Pointer); - updateAnalysis(&call, ptrint, &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "getcwd") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - 
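      // getcwd's size argument (operand 1, handled next) is integral.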
updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "mprotect") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "memcmp") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "signal") { - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - return; - } - if (funcName == "write" || funcName == "read" || funcName == "writev" || - funcName == "readv") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - // FD type not going to be defined here - // updateAnalysis(call.getOperand(0), - // TypeTree(BaseType::Pointer).Only(-1), - // &call); - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); - updateAnalysis(call.getOperand(2), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - if (funcName == "gsl_sf_legendre_array_e") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - return; - } - - // CONSIDER(__lgamma_r_finite) - - CONSIDER2(frexp, double, double, int *) - CONSIDER(frexpf) - CONSIDER(frexpl) - CONSIDER2(ldexp, double, double, int) - CONSIDER2(modf, double, double, double *) - CONSIDER(modff) - CONSIDER(modfl) - - CONSIDER2(remquo, double, double, double, int *) - CONSIDER(remquof) - CONSIDER(remquol) - - if (isMemFreeLibMFunction(funcName)) { -#if LLVM_VERSION_MAJOR >= 14 - for (size_t i = 0; i < call.arg_size(); ++i) -#else - for (size_t i = 0; i < call.getNumArgOperands(); ++i) -#endif - { - Type *T = call.getArgOperand(i)->getType(); - if (T->isFloatingPointTy()) { - updateAnalysis( - call.getArgOperand(i), - TypeTree(ConcreteType( - call.getArgOperand(i)->getType()->getScalarType())) - .Only(-1, &call), - &call); - } else if (T->isIntegerTy()) { - updateAnalysis(call.getArgOperand(i), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - } else if (auto ST = dyn_cast(T)) { - assert(ST->getNumElements() >= 1); - for (size_t i = 1; i < ST->getNumElements(); ++i) { - assert(ST->getTypeAtIndex((unsigned)0) == ST->getTypeAtIndex(i)); - } - if (ST->getTypeAtIndex((unsigned)0)->isFloatingPointTy()) - updateAnalysis( - call.getArgOperand(i), - TypeTree(ConcreteType( - ST->getTypeAtIndex((unsigned)0)->getScalarType())) - .Only(-1, &call), - &call); - else if (ST->getTypeAtIndex((unsigned)0)->isIntegerTy()) { - updateAnalysis(call.getArgOperand(i), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); - } - } else if (auto AT = dyn_cast(T)) { - assert(AT->getNumElements() >= 1); - if (AT->getElementType()->isFloatingPointTy()) - 
updateAnalysis( - call.getArgOperand(i), - TypeTree(ConcreteType(AT->getElementType()->getScalarType())) - .Only(-1, &call), - &call); - else if (AT->getElementType()->isIntegerTy()) { - updateAnalysis(call.getArgOperand(i), - TypeTree(BaseType::Integer).Only(-1, &call), &call); - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); - } - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); - } - } - Type *T = call.getType(); - if (T->isFloatingPointTy()) { - updateAnalysis(&call, - TypeTree(ConcreteType(call.getType()->getScalarType())) - .Only(-1, &call), - &call); - } else if (T->isIntegerTy()) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), - &call); - } else if (T->isVoidTy()) { - } else if (auto ST = dyn_cast(T)) { - assert(ST->getNumElements() >= 1); - TypeTree TT; - auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); - for (size_t i = 0; i < ST->getNumElements(); ++i) { - auto T = ST->getTypeAtIndex(i); - ConcreteType CT(BaseType::Unknown); - - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), - ConstantInt::get(Type::getInt32Ty(call.getContext()), i)}; - auto ud = UndefValue::get(PointerType::getUnqual(ST)); - auto g2 = GetElementPtrInst::Create(ST, ud, vec); - APInt ai(DL.getIndexSizeInBits(0), 0); - g2->accumulateConstantOffset(DL, ai); - delete g2; - size_t Offset = ai.getZExtValue(); - - size_t nextOffset; - if (i + 1 == ST->getNumElements()) - nextOffset = (DL.getTypeSizeInBits(ST) + 7) / 8; - else { - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), - ConstantInt::get(Type::getInt32Ty(call.getContext()), i + 1)}; - auto ud = UndefValue::get(PointerType::getUnqual(ST)); - auto g2 = GetElementPtrInst::Create(ST, ud, vec); - APInt ai(DL.getIndexSizeInBits(0), 0); - g2->accumulateConstantOffset(DL, ai); - delete g2; - nextOffset = ai.getZExtValue(); - } - - if (T->isFloatingPointTy()) { - CT = T; - } else if (T->isIntegerTy()) { - CT = BaseType::Integer; - } - if (CT != BaseType::Unknown) { - TypeTree mid = TypeTree(CT).Only(-1, &call); - TT |= mid.ShiftIndices(DL, /*init offset*/ 0, - /*maxSize*/ nextOffset - Offset, - /*addOffset*/ Offset); - } - } - auto Size = (DL.getTypeSizeInBits(ST) + 7) / 8; - TT.CanonicalizeInPlace(Size, DL); - updateAnalysis(&call, TT, &call); - } else if (auto AT = dyn_cast(T)) { - assert(AT->getNumElements() >= 1); - if (AT->getElementType()->isFloatingPointTy()) - updateAnalysis( - &call, - TypeTree(ConcreteType(AT->getElementType()->getScalarType())) - .Only(-1, &call), - &call); - else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); - } - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); - } - return; - } - if (funcName == "__lgamma_r_finite") { - updateAnalysis( - call.getArgOperand(0), - TypeTree(ConcreteType(Type::getDoubleTy(call.getContext()))) - .Only(-1, &call), - &call); - updateAnalysis(call.getArgOperand(1), - TypeTree(BaseType::Integer).Only(0, &call).Only(-1, &call), - &call); - updateAnalysis( - &call, - TypeTree(ConcreteType(Type::getDoubleTy(call.getContext()))) - .Only(-1, &call), - &call); - } - if (funcName == "__fd_sincos_1" || funcName == "__fd_sincos_1f" || - funcName == "__fd_sincos_1l") { - updateAnalysis(call.getArgOperand(0), - TypeTree(ConcreteType(call.getArgOperand(0)->getType())) - .Only(-1, &call), - &call); - 
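      // The call result is given the same concrete type as the operand just
      // analyzed; that update follows.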
updateAnalysis(&call, - TypeTree(ConcreteType(call.getArgOperand(0)->getType())) - .Only(-1, &call), - &call); - } - if (funcName == "frexp" || funcName == "frexpf" || funcName == "frexpl") { - - updateAnalysis( - &call, TypeTree(ConcreteType(call.getType())).Only(-1, &call), &call); - updateAnalysis(call.getOperand(0), - TypeTree(ConcreteType(call.getType())).Only(-1, &call), - &call); - TypeTree ival(BaseType::Pointer); - size_t objSize = 1; - -#if LLVM_VERSION_MAJOR < 17 - auto &DL = fntypeinfo.Function->getParent()->getDataLayout(); - objSize = DL.getTypeSizeInBits( - call.getOperand(1)->getType()->getPointerElementType()) / - 8; -#endif - for (size_t i = 0; i < objSize; ++i) { - ival.insert({(int)i}, BaseType::Integer); - } - updateAnalysis(call.getOperand(1), ival.Only(-1, &call), &call); - return; - } - - if (funcName == "__cxa_guard_acquire" || funcName == "printf" || - funcName == "vprintf" || funcName == "puts" || funcName == "fputc" || - funcName == "fprintf") { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); - } - - if (dontAnalyze(funcName)) - return; - - if (!ci->empty() && !hasMetadata(ci, "enzyme_gradient") && - !hasMetadata(ci, "enzyme_derivative")) { - visitIPOCall(call, *ci); - } - } -} - -TypeTree TypeAnalyzer::getReturnAnalysis() { - bool set = false; - TypeTree vd; - for (BasicBlock &BB : *fntypeinfo.Function) { - for (auto &inst : BB) { - if (auto ri = dyn_cast(&inst)) { - if (auto rv = ri->getReturnValue()) { - if (set == false) { - set = true; - vd = getAnalysis(rv); - continue; - } - vd &= getAnalysis(rv); - // TODO insert the selectinst anything propagation here - // however this needs to be done simultaneously with preventing - // anything from propagating up through the return value (if there - // are multiple possible returns) - } - } - } - } - return vd; -} - -/// Helper function that calculates whether a given value must only be -/// an integer and cannot be cast/stored to be used as a ptr/integer -bool TypeAnalyzer::mustRemainInteger(Value *val, bool *returned) { - std::map> &seen = mriseen; - const DataLayout &DL = fntypeinfo.Function->getParent()->getDataLayout(); - if (seen.find(val) != seen.end()) { - if (returned) - *returned |= seen[val].second; - return seen[val].first; - } - seen[val] = std::make_pair(true, false); - for (auto u : val->users()) { - if (auto SI = dyn_cast(u)) { - if (parseTBAA(*SI, DL, MST).Inner0().isIntegral()) - continue; - seen[val].first = false; - continue; - } - if (isa(u)) { - if (!u->getType()->isIntOrIntVectorTy()) { - seen[val].first = false; - continue; - } else if (!mustRemainInteger(u, returned)) { - seen[val].first = false; - seen[val].second |= seen[u].second; - continue; - } else - continue; - } - if (isa(u) || isa(u) || isa(u) || -#if LLVM_VERSION_MAJOR <= 17 - isa(u) || isa(u) || -#endif - isa(u) || isa(u) || isa(u) || - isa(u) || isa(u)) { - if (!mustRemainInteger(u, returned)) { - seen[val].first = false; - seen[val].second |= seen[u].second; - } - continue; - } - if (auto gep = dyn_cast(u)) { - if (gep->isInBounds() && gep->getPointerOperand() != val) { - continue; - } - } - if (returned && isa(u)) { - *returned = true; - seen[val].second = true; - continue; - } - if (auto CI = dyn_cast(u)) { - if (auto F = CI->getCalledFunction()) { - if (!F->empty()) { - int argnum = 0; - bool subreturned = false; - for (auto &arg : F->args()) { - if (CI->getArgOperand(argnum) == val && - !mustRemainInteger(&arg, &subreturned)) { - seen[val].first = false; - seen[val].second |= seen[&arg].second; - 
continue; - } - ++argnum; - } - if (subreturned && !mustRemainInteger(CI, returned)) { - seen[val].first = false; - seen[val].second |= seen[CI].second; - continue; - } - continue; - } - } - } - if (isa(u)) - continue; - seen[val].first = false; - seen[val].second = true; - } - if (returned && seen[val].second) - *returned = true; - return seen[val].first; -} - -FnTypeInfo TypeAnalyzer::getCallInfo(CallBase &call, Function &fn) { - FnTypeInfo typeInfo(&fn); - - size_t argnum = 0; - for (auto &arg : fn.args()) { - if (argnum >= call.arg_size()) { - typeInfo.Arguments.insert( - std::pair(&arg, TypeTree())); - std::set bounded; - typeInfo.KnownValues.insert( - std::pair>(&arg, bounded)); - ++argnum; - continue; - } - auto dt = getAnalysis(call.getArgOperand(argnum)); - if (arg.getType()->isIntOrIntVectorTy() && - dt.Inner0() == BaseType::Anything) { - if (mustRemainInteger(&arg)) { - dt = TypeTree(BaseType::Integer).Only(-1, &call); - } - } - typeInfo.Arguments.insert(std::pair(&arg, dt)); - std::set bounded; - for (auto v : fntypeinfo.knownIntegralValues(call.getArgOperand(argnum), DT, - intseen, SE)) { - if (abs(v) > MaxIntOffset) - continue; - bounded.insert(v); - } - typeInfo.KnownValues.insert( - std::pair>(&arg, bounded)); - ++argnum; - } - - typeInfo.Return = getAnalysis(&call); - return typeInfo; -} - -void TypeAnalyzer::visitIPOCall(CallBase &call, Function &fn) { -#if LLVM_VERSION_MAJOR >= 14 - if (call.arg_size() != fn.getFunctionType()->getNumParams()) - return; -#else - if (call.getNumArgOperands() != fn.getFunctionType()->getNumParams()) - return; -#endif - - assert(fntypeinfo.KnownValues.size() == - fntypeinfo.Function->getFunctionType()->getNumParams()); - - bool hasDown = direction & DOWN; - bool hasUp = direction & UP; - - if (hasDown) { - if (call.getType()->isVoidTy()) - hasDown = false; - else { - if (getAnalysis(&call).IsFullyDetermined()) - hasDown = false; - } - } - if (hasUp) { - bool unknown = false; -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : call.args()) -#else - for (auto &arg : call.arg_operands()) -#endif - { - if (isa(arg)) - continue; - if (!getAnalysis(arg).IsFullyDetermined()) { - unknown = true; - break; - } - } - if (!unknown) - hasUp = false; - } - - // Fast path where all information has already been derived - if (!hasUp && !hasDown) - return; - - FnTypeInfo typeInfo = getCallInfo(call, fn); - typeInfo = preventTypeAnalysisLoops(typeInfo, call.getParent()->getParent()); - - if (EnzymePrintType) { - llvm::errs() << " starting IPO of "; - call.print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - - TypeResults STR = interprocedural.analyzeFunction(typeInfo); - - if (EnzymePrintType) { - llvm::errs() << " ending IPO of "; - call.print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - - if (hasUp) { - auto a = fn.arg_begin(); -#if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : call.args()) -#else - for (auto &arg : call.arg_operands()) -#endif - { - auto dt = STR.query(a); - if (EnzymePrintType) { - llvm::errs() << " updating "; - arg->print(llvm::errs(), *MST); - llvm::errs() << " = " << dt.str() << " via IPO of "; - call.print(llvm::errs(), *MST); - llvm::errs() << " arg "; - a->print(llvm::errs(), *MST); - llvm::errs() << "\n"; - } - updateAnalysis(arg, dt, &call); - ++a; - } - } - - if (hasDown) { - TypeTree vd = STR.getReturnAnalysis(); - if (call.getType()->isIntOrIntVectorTy() && - vd.Inner0() == BaseType::Anything) { - bool returned = false; - if (mustRemainInteger(&call, &returned) && !returned) { - vd = TypeTree(BaseType::Integer).Only(-1, 
&call); - } - } - updateAnalysis(&call, vd, &call); - } -} - -TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { - assert(fn.KnownValues.size() == - fn.Function->getFunctionType()->getNumParams()); - assert(fn.Function); - auto found = analyzedFunctions.find(fn); - if (found != analyzedFunctions.end()) { - auto &analysis = *found->second; - if (analysis.fntypeinfo.Function != fn.Function) { - llvm::errs() << " queryFunc: " << *fn.Function << "\n"; - llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function - << "\n"; - } - assert(analysis.fntypeinfo.Function == fn.Function); - - return TypeResults(analysis); - } - - if (fn.Function->empty()) - return TypeResults(nullptr); - - auto res = analyzedFunctions.emplace(fn, new TypeAnalyzer(fn, *this)); - auto &analysis = *res.first->second; - - if (EnzymePrintType) { - llvm::errs() << "analyzing function " << fn.Function->getName() << "\n"; - for (auto &pair : fn.Arguments) { - llvm::errs() << " + knowndata: "; - pair.first->print(llvm::errs(), *analysis.MST); - llvm::errs() << " : " << pair.second.str(); - auto found = fn.KnownValues.find(pair.first); - if (found != fn.KnownValues.end()) { - llvm::errs() << " - " << to_string(found->second); - } - llvm::errs() << "\n"; - } - llvm::errs() << " + retdata: " << fn.Return.str() << "\n"; - } - - analysis.prepareArgs(); - if (RustTypeRules) { - analysis.considerRustDebugInfo(); - } - analysis.considerTBAA(); - analysis.run(); - - if (analysis.fntypeinfo.Function != fn.Function) { - llvm::errs() << " queryFunc: " << *fn.Function << "\n"; - llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function << "\n"; - } - assert(analysis.fntypeinfo.Function == fn.Function); - - { - auto &analysis = *analyzedFunctions.find(fn)->second; - if (analysis.fntypeinfo.Function != fn.Function) { - llvm::errs() << " queryFunc: " << *fn.Function << "\n"; - llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function - << "\n"; - } - assert(analysis.fntypeinfo.Function == fn.Function); - } - - // Store the steady state result (if changed) to avoid - // a second analysis later. - analyzedFunctions.emplace(TypeResults(analysis).getAnalyzedTypeInfo(), - res.first->second); - - return TypeResults(analysis); -} - -TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(&analyzer) {} -TypeResults::TypeResults(std::nullptr_t) : analyzer(nullptr) {} - -FnTypeInfo TypeResults::getAnalyzedTypeInfo() const { - FnTypeInfo res(analyzer->fntypeinfo.Function); - for (auto &arg : analyzer->fntypeinfo.Function->args()) { - res.Arguments.insert(std::pair(&arg, query(&arg))); - } - res.Return = getReturnAnalysis(); - res.KnownValues = analyzer->fntypeinfo.KnownValues; - return res; -} - -FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const { - return analyzer->getCallInfo(CI, fn); -} - -TypeTree TypeResults::query(Value *val) const { -#ifndef NDEBUG - if (auto inst = dyn_cast(val)) { - assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function); - } - if (auto arg = dyn_cast(val)) { - assert(arg->getParent() == analyzer->fntypeinfo.Function); - } -#endif - return analyzer->getAnalysis(val); -} - -// Returns last non-padding/alignment location of the corresponding subtype T. 
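// Alignment padding discovered along the way is recorded in `offs` so that
// callers (anyFloat/anyPointer below) can skip those bytes when scanning a
// value byte by byte. Illustrative example (exact layout depends on the
// DataLayout): for a struct { i8, i32 } where the i32 is 4-byte aligned,
// bytes 1-3 are inserted into `offs` and the function returns 8, the offset
// just past the i32.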
-size_t skippedBytes(SmallSet &offs, Type *T, const DataLayout &DL, - size_t offset = 0) { - auto ST = dyn_cast(T); - if (!ST) - return (DL.getTypeSizeInBits(T) + 7) / 8; - - auto SL = DL.getStructLayout(ST); - size_t prevOff = 0; - for (size_t idx = 0; idx < ST->getNumElements(); idx++) { - auto off = SL->getElementOffset(idx); - if (off > prevOff) - for (size_t i = prevOff; i < off; i++) - offs.insert(offset + i); - size_t subSize = skippedBytes(offs, ST->getElementType(idx), DL, prevOff); - prevOff = off + subSize; - } - return prevOff; -} - -bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const { - assert(val); - assert(val->getType()); - auto q = query(val); - auto dt = q[{-1}]; - if (!anythingIsFloat && dt == BaseType::Anything) - return false; - if (dt != BaseType::Anything && dt != BaseType::Unknown) - return dt.isFloat(); - - if (val->getType()->isTokenTy() || val->getType()->isVoidTy()) - return false; - auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); - SmallSet offs; - size_t ObjSize = skippedBytes(offs, val->getType(), dl); - - for (size_t i = 0; i < ObjSize;) { - dt = q[{(int)i}]; - if (dt == BaseType::Integer) { - i++; - continue; - } - if (!anythingIsFloat && dt == BaseType::Integer) { - i++; - continue; - } - if (dt == BaseType::Pointer) { - i += dl.getPointerSize(0); - continue; - } - if (offs.count(i)) { - i++; - continue; - } - return true; - } - return false; -} - -bool TypeResults::anyPointer(Value *val) const { - assert(val); - assert(val->getType()); - auto q = query(val); - auto dt = q[{-1}]; - if (dt != BaseType::Anything && dt != BaseType::Unknown) - return dt == BaseType::Pointer; - if (val->getType()->isTokenTy() || val->getType()->isVoidTy()) - return false; - - auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); - SmallSet offs; - size_t ObjSize = skippedBytes(offs, val->getType(), dl); - - for (size_t i = 0; i < ObjSize;) { - dt = q[{(int)i}]; - if (dt == BaseType::Integer) { - i++; - continue; - } - if (auto FT = dt.isFloat()) { - i += (dl.getTypeSizeInBits(FT) + 7) / 8; - continue; - } - if (offs.count(i)) { - i++; - continue; - } - return true; - } - return false; -} - -void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer->dump(ss); } - -ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, - bool pointerIntSame) const { - assert(val); - assert(val->getType()); - auto q = query(val); - auto dt = q[{0}]; - /* - size_t ObjSize = 1; - if (val->getType()->isSized()) - ObjSize = (fn.Function->getParent()->getDataLayout().getTypeSizeInBits( - val->getType()) +7) / 8; - */ - dt.orIn(q[{-1}], pointerIntSame); - for (size_t i = 1; i < num; ++i) { - dt.orIn(q[{(int)i}], pointerIntSame); - } - - if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { - if (auto inst = dyn_cast(val)) { - llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; - llvm::errs() << *inst->getParent()->getParent() << "\n"; - for (auto &pair : analyzer->analysis) { - llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() - << "\n"; - } - } - llvm::errs() << "could not deduce type of integer " << *val << "\n"; - assert(0 && "could not deduce type of integer"); - } - return dt; -} - -Type *TypeResults::addingType(size_t num, Value *val, size_t start) const { - assert(val); - assert(val->getType()); - auto q = query(val); - Type *ty = q[{-1}].isFloat(); - for (size_t i = start; i < num; ++i) { - auto ty2 = q[{(int)i}].isFloat(); - if (ty) { - if (ty2) - assert(ty == 
ty2); - } else { - ty = ty2; - } - } - return ty; -} - -ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, - bool errIfNotFound, - bool pointerIntSame) const { - assert(val); - assert(val->getType()); - auto q = query(val).Data0(); - if (!(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer)) { - llvm::errs() << *analyzer->fntypeinfo.Function << "\n"; - dump(); - llvm::errs() << "val: " << *val << "\n"; - } - assert(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer); - - auto dt = q[{-1}]; - for (size_t i = 0; i < num; ++i) { - bool Legal = true; - dt.checkedOrIn(q[{(int)i}], pointerIntSame, Legal); - if (!Legal) { - std::string str; - raw_string_ostream ss(str); - ss << "Illegal firstPointer, num: " << num << " q: " << q.str() << "\n"; - ss << " at " << *val << " from " << *I << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(I), ErrorType::IllegalFirstPointer, - &analyzer, nullptr, nullptr); - } - llvm::errs() << ss.str() << "\n"; - llvm_unreachable("Illegal firstPointer"); - } - } - - if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { - auto &res = *analyzer; - if (auto inst = dyn_cast(val)) { - llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; - llvm::errs() << *inst->getParent()->getParent() << "\n"; - for (auto &pair : res.analysis) { - if (auto in = dyn_cast(pair.first)) { - if (in->getParent()->getParent() != inst->getParent()->getParent()) { - llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n"; - llvm::errs() << "instf: " << *inst->getParent()->getParent() - << "\n"; - llvm::errs() << "in: " << *in << "\n"; - llvm::errs() << "inst: " << *inst << "\n"; - } - assert(in->getParent()->getParent() == - inst->getParent()->getParent()); - } - llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() - << " int: " + - to_string(res.knownIntegralValues(pair.first)) - << "\n"; - } - } - if (auto arg = dyn_cast(val)) { - llvm::errs() << *arg->getParent() << "\n"; - for (auto &pair : res.analysis) { -#ifndef NDEBUG - if (auto in = dyn_cast(pair.first)) - assert(in->getParent()->getParent() == arg->getParent()); -#endif - llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() - << " int: " + - to_string(res.knownIntegralValues(pair.first)) - << "\n"; - } - } - llvm::errs() << "fn: " << *analyzer->fntypeinfo.Function << "\n"; - dump(); - llvm::errs() << "could not deduce type of integer " << *val - << " num:" << num << " q:" << q.str() << " \n"; - - llvm::DiagnosticLocation loc = - analyzer->fntypeinfo.Function->getSubprogram(); - Instruction *codeLoc = - &*analyzer->fntypeinfo.Function->getEntryBlock().begin(); - if (auto inst = dyn_cast(val)) { - loc = inst->getDebugLoc(); - codeLoc = inst; - } - EmitFailure("CannotDeduceType", loc, codeLoc, - "failed to deduce type of value ", *val); - - assert(0 && "could not deduce type of integer"); - } - return dt; -} - -/// Parse the debug info generated by rustc and retrieve useful type info if -/// possible -void TypeAnalyzer::considerRustDebugInfo() { - DataLayout DL = fntypeinfo.Function->getParent()->getDataLayout(); - for (BasicBlock &BB : *fntypeinfo.Function) { - for (Instruction &I : BB) { - if (DbgDeclareInst *DDI = dyn_cast(&I)) { - TypeTree TT = parseDIType(*DDI, DL); - if (!TT.isKnown()) { - continue; - } - TT |= TypeTree(BaseType::Pointer); - updateAnalysis(DDI->getAddress(), TT.Only(-1, &I), DDI); - } - } - } -} - -TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, - bool 
intIsPointer) { - if (ET->isIntOrIntVectorTy()) { - if (intIsPointer) - return TypeTree(BaseType::Pointer).Only(-1, I); - else - return TypeTree(BaseType::Integer).Only(-1, I); - } - if (ET->isFPOrFPVectorTy()) { - return TypeTree(ConcreteType(ET->getScalarType())).Only(-1, I); - } - if (ET->isPointerTy()) { - return TypeTree(BaseType::Pointer).Only(-1, I); - } - if (auto ST = dyn_cast(ET)) { - auto &DL = I->getParent()->getParent()->getParent()->getDataLayout(); - - TypeTree Out; - - for (size_t i = 0; i < ST->getNumElements(); i++) { - auto SubT = - defaultTypeTreeForLLVM(ST->getElementType(i), I, intIsPointer); - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(I->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(I->getContext()), i), - }; - auto g2 = GetElementPtrInst::Create( - ST, UndefValue::get(PointerType::getUnqual(ST)), vec); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - auto size = (DL.getTypeSizeInBits(ST->getElementType(i)) + 7) / 8; - int Off = (int)ai.getLimitedValue(); - Out |= SubT.ShiftIndices(DL, 0, size, Off); - } - return Out; - } - if (auto AT = dyn_cast(ET)) { - auto SubT = defaultTypeTreeForLLVM(AT->getElementType(), I, intIsPointer); - auto &DL = I->getParent()->getParent()->getParent()->getDataLayout(); - - TypeTree Out; - for (size_t i = 0; i < AT->getNumElements(); i++) { - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(I->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(I->getContext()), i), - }; - auto g2 = GetElementPtrInst::Create( - AT, UndefValue::get(PointerType::getUnqual(AT)), vec); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - int Off = (int)ai.getLimitedValue(); - auto size = (DL.getTypeSizeInBits(AT->getElementType()) + 7) / 8; - Out |= SubT.ShiftIndices(DL, 0, size, Off); - } - return Out; - } - if (auto AT = dyn_cast(ET)) { -#if LLVM_VERSION_MAJOR >= 12 - assert(!AT->getElementCount().isScalable()); - size_t numElems = AT->getElementCount().getKnownMinValue(); -#else - size_t numElems = AT->getNumElements(); -#endif - auto SubT = defaultTypeTreeForLLVM(AT->getElementType(), I, intIsPointer); - auto &DL = I->getParent()->getParent()->getParent()->getDataLayout(); - - TypeTree Out; - for (size_t i = 0; i < numElems; i++) { - Value *vec[2] = { - ConstantInt::get(Type::getInt64Ty(I->getContext()), 0), - ConstantInt::get(Type::getInt32Ty(I->getContext()), i), - }; - auto g2 = GetElementPtrInst::Create( - AT, UndefValue::get(PointerType::getUnqual(AT)), vec); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - int Off = (int)ai.getLimitedValue(); - auto size = (DL.getTypeSizeInBits(AT->getElementType()) + 7) / 8; - Out |= SubT.ShiftIndices(DL, 0, size, Off); - } - return Out; - } - // Unhandled/unknown Type - llvm::errs() << "Error Unknown Type: " << *ET << "\n"; - assert(0 && "Error Unknown Type: "); - llvm_unreachable("Error Unknown Type: "); - // return TypeTree(); -} - -Function *TypeResults::getFunction() const { - return analyzer->fntypeinfo.Function; -} - -TypeTree TypeResults::getReturnAnalysis() const { - return analyzer->getReturnAnalysis(); -} - -std::set 
TypeResults::knownIntegralValues(Value *val) const { - return analyzer->knownIntegralValues(val); -} - -std::set TypeAnalyzer::knownIntegralValues(Value *val) { - return fntypeinfo.knownIntegralValues(val, DT, intseen, SE); -} - -void TypeAnalysis::clear() { analyzedFunctions.clear(); } - -FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, - llvm::Function *todiff) { - FnTypeInfo oldTypeInfo = oldTypeInfo_; - for (auto &pair : oldTypeInfo.KnownValues) { - if (pair.second.size() != 0) { - bool recursiveUse = false; - std::set> seen; - SetVector> todo; - for (auto user : pair.first->users()) - todo.insert(std::make_pair(user, pair.first)); - while (todo.size()) { - auto spair = todo.pop_back_val(); - if (seen.count(spair)) - continue; - seen.insert(spair); - auto [v, prev] = spair; - if (isa(v) || isa(v) || isa(v)) { - for (auto user : v->users()) - todo.insert(std::make_pair(user, v)); - continue; - } - if (auto ci = dyn_cast(v)) { - if (ci->getCalledFunction() == todiff && - ci->getArgOperand(pair.first->getArgNo()) == prev) { - if (prev == pair.first) - continue; - recursiveUse = true; - break; - } - } - } - if (recursiveUse) { - pair.second.clear(); - } - } - } - return oldTypeInfo; -} diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h deleted file mode 100644 index 83a35ec2b31f..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ /dev/null @@ -1,415 +0,0 @@ -//===- TypeAnalysis.h - Declaration of Type Analysis ------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the declaration of Type Analysis, a utility for -// computing the underlying data type of LLVM values. 
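// Type information is tracked per byte offset via TypeTree (see TypeTree.h);
// an index of -1 acts as a wildcard covering every offset at that level.
// Roughly, {[-1]:Pointer, [-1,0]:Float@double} describes a pointer to memory
// that begins with a double (illustrative; see TypeTree.h for the exact
// printed form).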
-// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_TYPE_ANALYSIS_H -#define ENZYME_TYPE_ANALYSIS_H 1 - -#include -#include - -#include - -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/StringMap.h" - -#include "llvm/Analysis/TargetLibraryInfo.h" - -#if LLVM_VERSION_MAJOR >= 16 -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/IR/Constants.h" -#include "llvm/IR/InstVisitor.h" -#include "llvm/IR/ModuleSlotTracker.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" - -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/IR/Dominators.h" - -#include "../Utils.h" -#include "TypeTree.h" - -extern const llvm::StringMap LIBM_FUNCTIONS; - -static inline bool isMemFreeLibMFunction(llvm::StringRef str, - llvm::Intrinsic::ID *ID = nullptr) { - llvm::StringRef ogstr = str; - if (startsWith(str, "__") && endsWith(str, "_finite")) { - str = str.substr(2, str.size() - 2 - 7); - } else if (startsWith(str, "__fd_") && endsWith(str, "_1")) { - str = str.substr(5, str.size() - 5 - 2); - } else if (startsWith(str, "__nv_")) { - str = str.substr(5, str.size() - 5); - } - if (LIBM_FUNCTIONS.find(str.str()) != LIBM_FUNCTIONS.end()) { - if (ID) - *ID = LIBM_FUNCTIONS.find(str.str())->second; - return true; - } - if (endsWith(str, "f") || endsWith(str, "l") || - (startsWith(ogstr, "__nv_") && endsWith(str, "d"))) { - if (LIBM_FUNCTIONS.find(str.substr(0, str.size() - 1).str()) != - LIBM_FUNCTIONS.end()) { - if (ID) - *ID = LIBM_FUNCTIONS.find(str.substr(0, str.size() - 1).str())->second; - return true; - } - } - return false; -} - -/// Struct containing all contextual type information for a -/// particular function call -struct FnTypeInfo { - /// Function being analyzed - llvm::Function *Function; - - FnTypeInfo(llvm::Function *fn) : Function(fn) {} - FnTypeInfo(const FnTypeInfo &) = default; - FnTypeInfo &operator=(FnTypeInfo &) = default; - FnTypeInfo &operator=(FnTypeInfo &&) = default; - - /// Types of arguments - std::map Arguments; - - /// Type of return - TypeTree Return; - - /// The specific constant(s) known to represented by an argument, if constant - std::map> KnownValues; - - /// The set of known values val will take - std::set - knownIntegralValues(llvm::Value *val, const llvm::DominatorTree &DT, - std::map> &intseen, - llvm::ScalarEvolution &SE) const; -}; - -static inline bool operator<(const FnTypeInfo &lhs, const FnTypeInfo &rhs) { - - if (lhs.Function < rhs.Function) - return true; - if (rhs.Function < lhs.Function) - return false; - - if (lhs.Return < rhs.Return) - return true; - if (rhs.Return < lhs.Return) - return false; - - for (auto &arg : lhs.Function->args()) { - { - auto foundLHS = lhs.Arguments.find(&arg); - assert(foundLHS != lhs.Arguments.end()); - auto foundRHS = rhs.Arguments.find(&arg); - assert(foundRHS != rhs.Arguments.end()); - if (foundLHS->second < foundRHS->second) - return true; - if (foundRHS->second < foundLHS->second) - return false; - } - - { - auto foundLHS = lhs.KnownValues.find(&arg); - assert(foundLHS != lhs.KnownValues.end()); - auto foundRHS = rhs.KnownValues.find(&arg); - assert(foundRHS != rhs.KnownValues.end()); - if (foundLHS->second < foundRHS->second) - return true; - if (foundRHS->second < foundLHS->second) - return false; - } - } - // equal; - return false; -} - -class TypeAnalyzer; -class 
TypeAnalysis; - -/// A holder class representing the results of running TypeAnalysis -/// on a given function -class TypeResults { -public: - TypeAnalyzer *analyzer; - -public: - TypeResults(std::nullptr_t); - TypeResults(TypeAnalyzer &analyzer); - ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound = true, - bool pointerIntSame = false) const; - llvm::Type *addingType(size_t num, llvm::Value *val, size_t start = 0) const; - - /// Returns whether in the first num bytes there is pointer, int, float, or - /// none If pointerIntSame is set to true, then consider either as the same - /// (and thus mergable) - ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I, - bool errIfNotFound = true, - bool pointerIntSame = false) const; - - /// The TypeTree of a particular Value - TypeTree query(llvm::Value *val) const; - - /// Whether any part of the top level register can contain a float - /// e.g. { i64, float } can contain a float, but { i64, i8* } would not. - // Of course, here we compute with type analysis rather than llvm type - // The flag `anythingIsFloat` specifies whether an anything should - // be considered a float. - bool anyFloat(llvm::Value *val, bool anythingIsFloat = true) const; - - /// Whether any part of the top level register can contain a pointer - /// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not. - // Of course, here we compute with type analysis rather than llvm type - bool anyPointer(llvm::Value *val) const; - - /// The TypeInfo calling convention - FnTypeInfo getAnalyzedTypeInfo() const; - - /// The Type of the return - TypeTree getReturnAnalysis() const; - - /// Prints all known information - void dump(llvm::raw_ostream &ss = llvm::errs()) const; - - /// The set of values val will take on during this program - std::set knownIntegralValues(llvm::Value *val) const; - - FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn) const; - - llvm::Function *getFunction() const; -}; - -/// Helper class that computes the fixed-point type results of a given function -class TypeAnalyzer : public llvm::InstVisitor { -public: - /// Cache of metadata indices, for faster printing. 
- /// Only initialized if EnzymePrintType is true - std::shared_ptr MST; - - /// List of value's which should be re-analyzed now with new information - llvm::SetVector> workList; - - const llvm::SmallPtrSet notForAnalysis; - -private: - /// Tell TypeAnalyzer to reanalyze this value - void addToWorkList(llvm::Value *val); - - /// Map of Value to known integer constants that it will take on - std::map> intseen; - - std::map> mriseen; - bool mustRemainInteger(llvm::Value *val, bool *returned = nullptr); - -public: - /// Calling context - const FnTypeInfo fntypeinfo; - - /// Calling TypeAnalysis to be used in the case of calls to other - /// functions - TypeAnalysis &interprocedural; - - /// Directionality of checks - uint8_t direction; - - /// Whether an inconsistent update has been found - /// This will only be set when direction != Both, erring otherwise - bool Invalid; - - bool PHIRecur; - - // propagate from instruction to operand - static constexpr uint8_t UP = 1; - // propagate from operand to instruction - static constexpr uint8_t DOWN = 2; - static constexpr uint8_t BOTH = UP | DOWN; - - /// Intermediate conservative, but correct Type analysis results - std::map analysis; - - llvm::TargetLibraryInfo &TLI; - llvm::DominatorTree &DT; - llvm::PostDominatorTree &PDT; - - llvm::LoopInfo &LI; - llvm::ScalarEvolution &SE; - - FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn); - - TypeAnalyzer(TypeAnalysis &TA); - - TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA, - uint8_t direction = BOTH); - - TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA, - const llvm::SmallPtrSetImpl ¬ForAnalysis, - const TypeAnalyzer &Prev, uint8_t direction = BOTH, - bool PHIRecur = false); - - /// Get the current results for a given value - TypeTree getAnalysis(llvm::Value *Val); - - /// Add additional information to the Type info of val, readding it to the - /// work queue as necessary - void updateAnalysis(llvm::Value *val, BaseType data, llvm::Value *origin); - void updateAnalysis(llvm::Value *val, ConcreteType data, llvm::Value *origin); - void updateAnalysis(llvm::Value *val, TypeTree data, llvm::Value *origin); - - /// Analyze type info given by the arguments, possibly adding to work queue - void prepareArgs(); - - /// Analyze type info given by the TBAA, possibly adding to work queue - void considerTBAA(); - - /// Parse the debug info generated by rustc and retrieve useful type info if - /// possible - void considerRustDebugInfo(); - - /// Run the interprocedural type analysis starting from this function - void run(); - - /// Hypothesize that undefined phi's are integers and try to prove - /// that they are really integral - void runPHIHypotheses(); - - void visitValue(llvm::Value &val); - - void visitConstantExpr(llvm::ConstantExpr &CE); - - void visitCmpInst(llvm::CmpInst &I); - - void visitAllocaInst(llvm::AllocaInst &I); - - void visitLoadInst(llvm::LoadInst &I); - - void visitStoreInst(llvm::StoreInst &I); - - void visitGetElementPtrInst(llvm::GetElementPtrInst &gep); - - void visitGEPOperator(llvm::GEPOperator &gep); - - void visitPHINode(llvm::PHINode &phi); - - void visitTruncInst(llvm::TruncInst &I); - - void visitZExtInst(llvm::ZExtInst &I); - - void visitSExtInst(llvm::SExtInst &I); - - void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst &I); - - void visitFPExtInst(llvm::FPExtInst &I); - - void visitFPTruncInst(llvm::FPTruncInst &I); - - void visitFPToUIInst(llvm::FPToUIInst &I); - - void visitFPToSIInst(llvm::FPToSIInst &I); - - void visitUIToFPInst(llvm::UIToFPInst &I); - - 
void visitSIToFPInst(llvm::SIToFPInst &I); - - void visitPtrToIntInst(llvm::PtrToIntInst &I); - - void visitIntToPtrInst(llvm::IntToPtrInst &I); - - void visitBitCastInst(llvm::BitCastInst &I); - -#if LLVM_VERSION_MAJOR >= 10 - void visitFreezeInst(llvm::FreezeInst &I); -#endif - - void visitSelectInst(llvm::SelectInst &I); - - void visitExtractElementInst(llvm::ExtractElementInst &I); - - void visitInsertElementInst(llvm::InsertElementInst &I); - - void visitShuffleVectorInst(llvm::ShuffleVectorInst &I); - - void visitExtractValueInst(llvm::ExtractValueInst &I); - - void visitInsertValueInst(llvm::InsertValueInst &I); - - void visitAtomicRMWInst(llvm::AtomicRMWInst &I); - - void visitBinaryOperator(llvm::BinaryOperator &I); - void visitBinaryOperation(const llvm::DataLayout &DL, llvm::Type *T, - llvm::Instruction::BinaryOps, llvm::Value *Args[2], - TypeTree &Ret, TypeTree &LHS, TypeTree &RHS, - llvm::Instruction *I); - - void visitIPOCall(llvm::CallBase &call, llvm::Function &fn); - - void visitCallBase(llvm::CallBase &call); - - void visitMemTransferInst(llvm::MemTransferInst &MTI); - void visitMemTransferCommon(llvm::CallBase &MTI); - - void visitIntrinsicInst(llvm::IntrinsicInst &II); - - TypeTree getReturnAnalysis(); - - void dump(llvm::raw_ostream &ss = llvm::errs()); - - std::set knownIntegralValues(llvm::Value *val); - - // TODO handle fneg on LLVM 10+ -}; - -/// Full interprocedural TypeAnalysis -class TypeAnalysis { -public: - llvm::FunctionAnalysisManager &FAM; - TypeAnalysis(llvm::FunctionAnalysisManager &FAM) : FAM(FAM) {} - /// Map of custom function call handlers - llvm::StringMap< - std::function /*argTrees*/, - llvm::ArrayRef> /*knownValues*/, - llvm::CallBase * /*call*/, TypeAnalyzer *)>> - CustomRules; - - /// Map of possible query states to TypeAnalyzer intermediate results - std::map> analyzedFunctions; - - /// Analyze a particular function, returning the results - TypeResults analyzeFunction(const FnTypeInfo &fn); - - /// Clear existing analyses - void clear(); -}; - -TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, - bool intIsPointer = true); -FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, - llvm::Function *todiff); -#endif diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp deleted file mode 100644 index 38cde29e38d5..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp +++ /dev/null @@ -1,192 +0,0 @@ -//===- TypeAnalysisPrinter.cpp - Printer utility pass for Type Analysis ---===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains a utility LLVM pass for printing derived Type Analysis -// results of a given function. 
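// The function whose results are printed is selected with the
// -type-analysis-func command-line option; the legacy pass is registered
// under the name "print-type-analysis".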
-// -//===----------------------------------------------------------------------===// -#include - -#if LLVM_VERSION_MAJOR >= 16 -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "llvm/ADT/SmallVector.h" - -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Metadata.h" - -#include "llvm/Support/Debug.h" -#include "llvm/Transforms/Scalar.h" - -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/ScalarEvolution.h" - -#include "llvm/Support/CommandLine.h" - -#include "../FunctionUtils.h" -#include "../Utils.h" -#include "TypeAnalysis.h" -#include "TypeAnalysisPrinter.h" - -using namespace llvm; -#ifdef DEBUG_TYPE -#undef DEBUG_TYPE -#endif -#define DEBUG_TYPE "type-analysis-results" - -/// Function ActivityAnalysis will be starting its run from -llvm::cl::opt - FunctionToAnalyze("type-analysis-func", cl::init(""), cl::Hidden, - cl::desc("Which function to analyze/print")); - -namespace { -bool printTypeAnalyses(llvm::Function &F) { - - if (F.getName() != FunctionToAnalyze) - return /*changed*/ false; - - FnTypeInfo type_args(&F); - for (auto &a : type_args.Function->args()) { - TypeTree dt; - if (a.getType()->isFPOrFPVectorTy()) { - dt = ConcreteType(a.getType()->getScalarType()); - } else if (a.getType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (F.getContext().supportsTypedPointers()) { -#endif - auto et = cast(a.getType())->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); - } -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - dt.insert({}, BaseType::Pointer); - } else if (a.getType()->isIntOrIntVectorTy()) { - dt = ConcreteType(BaseType::Integer); - } - type_args.Arguments.insert( - std::pair(&a, dt.Only(-1, nullptr))); - // TODO note that here we do NOT propagate constants in type info (and - // should consider whether we should) - type_args.KnownValues.insert( - std::pair>(&a, {})); - } - - TypeTree dt; - if (F.getReturnType()->isFPOrFPVectorTy()) { - dt = ConcreteType(F.getReturnType()->getScalarType()); - } else if (F.getReturnType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (F.getContext().supportsTypedPointers()) { -#endif - auto et = cast(F.getReturnType())->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); - } -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - dt.insert({}, BaseType::Pointer); - } else if (F.getReturnType()->isIntOrIntVectorTy()) { - dt = ConcreteType(BaseType::Integer); - } - type_args.Return = dt.Only(-1, nullptr); - PreProcessCache PPC; - TypeAnalysis TA(PPC.FAM); - TA.analyzeFunction(type_args); - for (Function &f : *F.getParent()) { - - for (auto &analysis : TA.analyzedFunctions) { - if (analysis.first.Function != &f) - continue; - auto &ta = *analysis.second; - llvm::outs() << 
f.getName() << " - " << analysis.first.Return.str() - << " |"; - - for (auto &a : f.args()) { - llvm::outs() << analysis.first.Arguments.find(&a)->second.str() << ":" - << to_string(analysis.first.KnownValues.find(&a)->second) - << " "; - } - llvm::outs() << "\n"; - - for (auto &a : f.args()) { - llvm::outs() << a << ": " << ta.getAnalysis(&a).str() << "\n"; - } - for (auto &BB : f) { - llvm::outs() << BB.getName() << "\n"; - for (auto &I : BB) { - llvm::outs() << I << ": " << ta.getAnalysis(&I).str() << "\n"; - } - } - } - } - return /*changed*/ false; -} -class TypeAnalysisPrinter final : public FunctionPass { -public: - static char ID; - TypeAnalysisPrinter() : FunctionPass(ID) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); - } - - bool runOnFunction(Function &F) override { return printTypeAnalyses(F); } -}; - -} // namespace - -char TypeAnalysisPrinter::ID = 0; - -static RegisterPass X("print-type-analysis", - "Print Type Analysis Results"); - -TypeAnalysisPrinterNewPM::Result -TypeAnalysisPrinterNewPM::run(llvm::Module &M, - llvm::ModuleAnalysisManager &MAM) { - bool changed = false; - for (auto &F : M) - changed |= printTypeAnalyses(F); - return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); -} -llvm::AnalysisKey TypeAnalysisPrinterNewPM::Key; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.h deleted file mode 100644 index d7f7ff3cf2f8..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.h +++ /dev/null @@ -1,52 +0,0 @@ -//===- TypeAnalysisPrinter.h - Printer utility pass for Type Analysis -----===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains a utility LLVM pass for printing derived Type Analysis -// results of a given function. 
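// Declares TypeAnalysisPrinterNewPM, the new-pass-manager wrapper whose
// implementation lives in TypeAnalysisPrinter.cpp.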
-// -//===----------------------------------------------------------------------===// - -#ifndef ENZYME_TYPE_ANALYSIS_TYPE_ANALYSIS_PRINTER_H -#define ENZYME_TYPE_ANALYSIS_TYPE_ANALYSIS_PRINTER_H - -#include "llvm/IR/PassManager.h" -#include "llvm/Passes/PassPlugin.h" - -namespace llvm { -class FunctionPass; -} - -class TypeAnalysisPrinterNewPM final - : public llvm::AnalysisInfoMixin { - friend struct llvm::AnalysisInfoMixin; - -private: - static llvm::AnalysisKey Key; - -public: - using Result = llvm::PreservedAnalyses; - TypeAnalysisPrinterNewPM() {} - - Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); - - static bool isRequired() { return true; } -}; - -#endif // ENZYME_TYPE_ANALYSIS_TYPE_ANALYSIS_PRINTER_H diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.cpp b/enzyme/Enzyme/TypeAnalysis/TypeTree.cpp deleted file mode 100644 index c1618c6e49d3..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.cpp +++ /dev/null @@ -1,48 +0,0 @@ -//===- TypeTree.cpp - Implementation of Type Analysis Type Trees-----------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation TypeTrees -- a class -// representing all of the underlying types of a particular LLVM value. This -// consists of a map of memory offsets to an underlying ConcreteType. This -// permits TypeTrees to represent distinct underlying types at different -// locations. Presently, TypeTree's have both a fixed depth of memory lookups -// and a maximum offset to ensure that Type Analysis eventually terminates. -// In the future this should be modified to better represent recursive types -// rather than limiting the depth. -// -//===----------------------------------------------------------------------===// -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DerivedTypes.h" - -#include "llvm/Support/CommandLine.h" - -#include "TypeTree.h" - -using namespace llvm; - -extern "C" { -/// Maximum offset for type trees to keep -llvm::cl::opt MaxTypeOffset("enzyme-max-type-offset", cl::init(500), - cl::Hidden, - cl::desc("Maximum type tree offset")); -llvm::cl::opt EnzymeTypeWarning("enzyme-type-warning", cl::init(true), - cl::Hidden, - cl::desc("Print Type Depth Warning")); -} diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h deleted file mode 100644 index 6c1a6baa8a17..000000000000 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ /dev/null @@ -1,1450 +0,0 @@ -//===- TypeTree.cpp - Declaration of Type Analysis Type Trees -----------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains the declaration of TypeTrees -- a class -// representing all of the underlying types of a particular LLVM value. This -// consists of a map of memory offsets to an underlying ConcreteType. This -// permits TypeTrees to represent distinct underlying types at different -// locations. Presently, TypeTree's have both a fixed depth of memory lookups -// and a maximum offset to ensure that Type Analysis eventually terminates. -// In the future this should be modified to better represent recursive types -// rather than limiting the depth. -// -//===----------------------------------------------------------------------===// -#ifndef ENZYME_TYPE_ANALYSIS_TYPE_TREE_H -#define ENZYME_TYPE_ANALYSIS_TYPE_TREE_H 1 - -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/raw_ostream.h" -#include -#include -#include -#include - -#include "../Utils.h" -#include "BaseType.h" -#include "ConcreteType.h" - -/// Maximum offset for type trees to keep -extern "C" { -extern llvm::cl::opt MaxTypeOffset; -extern llvm::cl::opt EnzymeTypeWarning; -extern llvm::cl::opt EnzymeMaxTypeDepth; -} - -/// Helper function to print a vector of ints to a string -static inline std::string to_string(const std::vector x) { - std::string out = "["; - for (unsigned i = 0; i < x.size(); ++i) { - if (i != 0) - out += ","; - out += std::to_string(x[i]); - } - out += "]"; - return out; -} - -class TypeTree; - -typedef std::shared_ptr TypeResult; -typedef std::map, ConcreteType> ConcreteTypeMapType; -typedef std::map, const TypeResult> TypeTreeMapType; - -/// Class representing the underlying types of values as -/// sequences of offsets to a ConcreteType -class TypeTree : public std::enable_shared_from_this { -private: - // mapping of known indices to type if one exists - ConcreteTypeMapType mapping; - std::vector minIndices; - -public: - TypeTree() {} - TypeTree(ConcreteType dat) { - if (dat != ConcreteType(BaseType::Unknown)) { - mapping.insert(std::pair, ConcreteType>({}, dat)); - } - } - - static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx) { - using namespace llvm; - assert(str[0] == '{'); - str = str.substr(1); - - TypeTree Result; - while (true) { - while (str[0] == ' ') - str = str.substr(1); - if (str[0] == '}') - break; - - assert(str[0] == '['); - str = str.substr(1); - - std::vector idxs; - while (true) { - while (str[0] == ' ') - str = str.substr(1); - if (str[0] == ']') { - str = str.substr(1); - break; - } - - int idx; - bool failed = str.consumeInteger(10, idx); - (void)failed; - assert(!failed); - idxs.push_back(idx); - - while (str[0] == ' ') - str = str.substr(1); - - if (str[0] == ',') { - str = str.substr(1); - } - } - - while (str[0] == ' ') - str = str.substr(1); - - assert(str[0] == ':'); - str = str.substr(1); - - while (str[0] == ' ') - str = str.substr(1); - - auto endval = str.find(','); - auto endval2 = str.find('}'); - auto endval3 = str.find(' '); - - if (endval2 != StringRef::npos && - (endval == StringRef::npos || 
endval2 < endval)) - endval = endval2; - if (endval3 != StringRef::npos && - (endval == StringRef::npos || endval3 < endval)) - endval = endval3; - assert(endval != StringRef::npos); - - auto tystr = str.substr(0, endval); - str = str.substr(endval); - - ConcreteType CT(tystr, ctx); - Result.mapping.emplace(idxs, CT); - if (Result.minIndices.size() < idxs.size()) { - for (size_t i = Result.minIndices.size(), end = idxs.size(); i < end; - ++i) { - Result.minIndices.push_back(idxs[i]); - } - } - for (size_t i = 0, end = idxs.size(); i < end; ++i) { - if (idxs[i] < Result.minIndices[i]) - Result.minIndices[i] = idxs[i]; - } - - while (str[0] == ' ') - str = str.substr(1); - - if (str[0] == ',') { - str = str.substr(1); - } - } - - return Result; - } - - /// Utility helper to lookup the mapping - const ConcreteTypeMapType &getMapping() const { return mapping; } - - /// Lookup the underlying ConcreteType at a given offset sequence - /// or Unknown if none exists - ConcreteType operator[](const std::vector Seq) const { - auto Found0 = mapping.find(Seq); - if (Found0 != mapping.end()) - return Found0->second; - size_t Len = Seq.size(); - if (Len == 0) - return BaseType::Unknown; - - std::vector> todo[2]; - todo[0].push_back({}); - int parity = 0; - for (size_t i = 0, Len = Seq.size(); i < Len - 1; ++i) { - for (auto prev : todo[parity]) { - prev.push_back(-1); - if (mapping.find(prev) != mapping.end()) - todo[1 - parity].push_back(prev); - if (Seq[i] != -1) { - prev.back() = Seq[i]; - if (mapping.find(prev) != mapping.end()) - todo[1 - parity].push_back(prev); - } - } - todo[parity].clear(); - parity = 1 - parity; - } - - size_t i = Len - 1; - for (auto prev : todo[parity]) { - prev.push_back(-1); - auto Found = mapping.find(prev); - if (Found != mapping.end()) - return Found->second; - if (Seq[i] != -1) { - prev.back() = Seq[i]; - Found = mapping.find(prev); - if (Found != mapping.end()) - return Found->second; - } - } - return BaseType::Unknown; - } - - // Return true if this type tree is fully known (i.e. there - // is no more information which could be added). - bool IsFullyDetermined() const { - std::vector offsets = {-1}; - while (1) { - auto found = mapping.find(offsets); - if (found == mapping.end()) - return false; - if (found->second != BaseType::Pointer) - return true; - offsets.push_back(-1); - } - } - - /// Return if changed - bool insert(const std::vector Seq, ConcreteType CT, - bool PointerIntSame = false) { - size_t SeqSize = Seq.size(); - if (SeqSize > EnzymeMaxTypeDepth) { - if (EnzymeTypeWarning) { - if (CustomErrorHandler) { - CustomErrorHandler("TypeAnalysisDepthLimit", nullptr, - ErrorType::TypeDepthExceeded, this, nullptr, - nullptr); - } else - llvm::errs() << "not handling more than " << EnzymeMaxTypeDepth - << " pointer lookups deep dt:" << str() - << " adding v: " << to_string(Seq) << ": " << CT.str() - << "\n"; - } - return false; - } - if (SeqSize == 0) { - mapping.insert(std::pair, ConcreteType>(Seq, CT)); - return true; - } - - // check types at lower pointer offsets are either pointer or - // anything. 
Don't insert into an anything - { - std::vector tmp(Seq); - while (tmp.size() > 0) { - tmp.erase(tmp.end() - 1); - auto found = mapping.find(tmp); - if (found != mapping.end()) { - if (found->second == BaseType::Anything) - return false; - if (found->second != BaseType::Pointer) { - llvm::errs() << "FAILED CT: " << str() - << " adding Seq: " << to_string(Seq) << ": " - << CT.str() << "\n"; - } - assert(found->second == BaseType::Pointer); - } - } - } - - bool changed = false; - // Check if there is an existing match, e.g. [-1, -1, -1] and inserting - // [-1, 8, -1] - { - for (const auto &pair : llvm::make_early_inc_range(mapping)) { - if (pair.first.size() == SeqSize) { - // Whether the the inserted val (e.g. [-1, 0] or [0, 0]) is at least - // as general as the existing map val (e.g. [0, 0]). - bool newMoreGeneralThanOld = true; - // Whether the the existing val (e.g. [-1, 0] or [0, 0]) is at least - // as general as the inserted map val (e.g. [0, 0]). - bool oldMoreGeneralThanNew = true; - for (unsigned i = 0; i < SeqSize; i++) { - if (pair.first[i] == Seq[i]) - continue; - if (Seq[i] == -1) { - oldMoreGeneralThanNew = false; - } else if (pair.first[i] == -1) { - newMoreGeneralThanOld = false; - } else { - oldMoreGeneralThanNew = false; - newMoreGeneralThanOld = false; - break; - } - } - - if (oldMoreGeneralThanNew) { - // Inserting an existing or less general version - if (CT == pair.second) - return false; - - // Inserting an existing or less general version (with pointer-int - // equivalence) - if (PointerIntSame) - if ((CT == BaseType::Pointer && - pair.second == BaseType::Integer) || - (CT == BaseType::Integer && pair.second == BaseType::Pointer)) - return false; - - // Inserting into an anything. Since from above we know this is not - // an anything, the inserted value contains no new information - if (pair.second == BaseType::Anything) - return false; - - // Inserting say a [0]:anything into a [-1]:Float - if (CT == BaseType::Anything) - continue; - - // Otherwise, inserting a non-equivalent pair into a more general - // slot. This is invalid. - llvm::errs() << "inserting into : " << str() << " with " - << to_string(Seq) << " of " << CT.str() << "\n"; - llvm_unreachable("illegal insertion"); - } else if (newMoreGeneralThanOld) { - // This new is strictly more general than the old. If they were - // equivalent, the case above would have been hit. - - if (CT == BaseType::Anything || CT == pair.second) { - // previous equivalent values or values overwritten by - // an anything are removed - changed = true; - mapping.erase(pair.first); - continue; - } - - // Inserting an existing or less general version (with pointer-int - // equivalence) - if (PointerIntSame) - if ((CT == BaseType::Pointer && - pair.second == BaseType::Integer) || - (CT == BaseType::Integer && - pair.second == BaseType::Pointer)) { - changed = true; - mapping.erase(pair.first); - continue; - } - - // Keep lingering anythings if not being overwritten, even if this - // (e.g. Float) applies to more locations. Therefore it is legal to - // have [-1]:Float, [8]:Anything - if (CT != BaseType::Anything && pair.second == BaseType::Anything) - continue; - - // Otherwise, inserting a more general non-equivalent pair. This is - // invalid. - llvm::errs() << "inserting into : " << str() << " with " - << to_string(Seq) << " of " << CT.str() << "\n"; - llvm_unreachable("illegal insertion"); - } - } - } - } - - bool possibleDeletion = false; - size_t minLen = - (minIndices.size() <= SeqSize) ? 
minIndices.size() : SeqSize; - for (size_t i = 0; i < minLen; i++) { - if (minIndices[i] > Seq[i]) { - if (minIndices[i] > MaxTypeOffset) - possibleDeletion = true; - minIndices[i] = Seq[i]; - } - } - - if (minIndices.size() < SeqSize) { - for (size_t i = minIndices.size(), end = SeqSize; i < end; ++i) { - minIndices.push_back(Seq[i]); - } - } - - if (possibleDeletion) { - for (const auto &pair : llvm::make_early_inc_range(mapping)) { - size_t i = 0; - bool mustKeep = false; - bool considerErase = false; - for (int val : pair.first) { - if (val > MaxTypeOffset) { - if (val == minIndices[i]) { - mustKeep = true; - break; - } - considerErase = true; - } - ++i; - } - if (!mustKeep && considerErase) { - mapping.erase(pair.first); - changed = true; - } - } - } - - size_t i = 0; - bool keep = false; - bool considerErase = false; - for (auto val : Seq) { - if (val > MaxTypeOffset) { - if (val == minIndices[i]) { - keep = true; - break; - } - considerErase = true; - } - i++; - } - if (considerErase && !keep) - return changed; - mapping.insert(std::pair, ConcreteType>(Seq, CT)); - return true; - } - - /// How this TypeTree compares with another - bool operator<(const TypeTree &vd) const { return mapping < vd.mapping; } - - /// Whether this TypeTree contains any information - bool isKnown() const { -#ifndef NDEBUG - for (const auto &pair : mapping) { - // we should assert here as we shouldn't keep any unknown maps for - // efficiency - assert(pair.second.isKnown()); - } -#endif - return mapping.size() != 0; - } - - /// Whether this TypeTree knows any non-pointer information - bool isKnownPastPointer() const { - for (auto &pair : mapping) { - // we should assert here as we shouldn't keep any unknown maps for - // efficiency - assert(pair.second.isKnown()); - if (pair.first.size() == 0) { - assert(pair.second == BaseType::Pointer || - pair.second == BaseType::Anything); - continue; - } - return true; - } - return false; - } - - /// Select only the Integer ConcreteTypes - TypeTree JustInt() const { - TypeTree vd; - for (auto &pair : mapping) { - if (pair.second == BaseType::Integer) { - vd.insert(pair.first, pair.second); - } - } - - return vd; - } - - /// Prepend an offset to all mappings - TypeTree Only(int Off, llvm::Instruction *orig) const { - TypeTree Result; - Result.minIndices.reserve(1 + minIndices.size()); - Result.minIndices.push_back(Off); - for (auto midx : minIndices) - Result.minIndices.push_back(midx); - - if (Result.minIndices.size() > EnzymeMaxTypeDepth) { - Result.minIndices.pop_back(); - if (EnzymeTypeWarning) { - if (CustomErrorHandler) { - CustomErrorHandler("TypeAnalysisDepthLimit", wrap(orig), - ErrorType::TypeDepthExceeded, this, nullptr, - nullptr); - } else if (orig) { - EmitWarning("TypeAnalysisDepthLimit", *orig, *orig, - " not handling more than ", EnzymeMaxTypeDepth, - " pointer lookups deep dt: ", str(), " only(", Off, ")"); - } else { - llvm::errs() << "not handling more than " << EnzymeMaxTypeDepth - << " pointer lookups deep dt:" << str() << " only(" - << Off << "): " - << "\n"; - } - } - } - - for (const auto &pair : mapping) { - if (pair.first.size() == EnzymeMaxTypeDepth) - continue; - std::vector Vec; - Vec.reserve(pair.first.size() + 1); - Vec.push_back(Off); - for (auto Val : pair.first) - Vec.push_back(Val); - Result.mapping.insert( - std::pair, ConcreteType>(Vec, pair.second)); - } - return Result; - } - - /// Peel off the outermost index at offset 0 - TypeTree Data0() const { - TypeTree Result; - - for (const auto &pair : mapping) { - if (pair.first.size() == 
0) { - llvm::errs() << str() << "\n"; - } - assert(pair.first.size() != 0); - - if (pair.first[0] == -1) { - std::vector next(pair.first.begin() + 1, pair.first.end()); - Result.mapping.insert( - std::pair, ConcreteType>(next, pair.second)); - for (size_t i = 0, Len = next.size(); i < Len; ++i) { - if (i == Result.minIndices.size()) - Result.minIndices.push_back(next[i]); - else if (next[i] < Result.minIndices[i]) - Result.minIndices[i] = next[i]; - } - } - } - for (const auto &pair : mapping) { - if (pair.first[0] == 0) { - std::vector next(pair.first.begin() + 1, pair.first.end()); - // We do insertion like this to force an error - // on the orIn operation if there is an incompatible - // merge. The insert operation does not error. - Result.orIn(next, pair.second); - } - } - - return Result; - } - - /// Optimized version of Data0()[{}] - ConcreteType Inner0() const { - ConcreteType CT = operator[]({-1}); - CT |= operator[]({0}); - return CT; - } - - /// Remove any mappings in the range [start, end) or [len, inf) - /// This function has special handling for -1's - TypeTree Clear(size_t start, size_t end, size_t len) const { - TypeTree Result; - - // Note that below do insertion with the orIn operator - // to force an error if there is an incompatible - // merge. The insert operation does not error. - - for (const auto &pair : mapping) { - assert(pair.first.size() != 0); - - if (pair.first[0] == -1) { - // For "all index" calculations, explicitly - // add mappings for regions in range - auto next = pair.first; - for (size_t i = 0; i < start; ++i) { - next[0] = i; - Result.orIn(next, pair.second); - } - for (size_t i = end; i < len; ++i) { - next[0] = i; - Result.orIn(next, pair.second); - } - } else if ((size_t)pair.first[0] < start || - ((size_t)pair.first[0] >= end && - (size_t)pair.first[0] < len)) { - // Otherwise simply check that the given offset is in range - - Result.insert(pair.first, pair.second); - } - } - - // TODO canonicalize this - return Result; - } - - /// Select all submappings whose first index is in range [0, len) and remove - /// the first index. 
This is the inverse of the `Only` operation - TypeTree Lookup(size_t len, const llvm::DataLayout &dl) const { - - // Map of indices[1:] => ( End => possible Index[0] ) - std::map, std::map>> staging; - - for (const auto &pair : mapping) { - assert(pair.first.size() != 0); - - // Pointer is at offset 0 from this object - if (pair.first[0] != 0 && pair.first[0] != -1) - continue; - - if (pair.first.size() == 1) { - assert(pair.second == ConcreteType(BaseType::Pointer) || - pair.second == ConcreteType(BaseType::Anything)); - continue; - } - - if (pair.first[1] == -1) { - } else { - if ((size_t)pair.first[1] >= len) - continue; - } - - std::vector next(pair.first.begin() + 2, pair.first.end()); - - staging[next][pair.second].insert(pair.first[1]); - } - - TypeTree Result; - for (auto &pair : staging) { - auto &pnext = pair.first; - for (auto &pair2 : pair.second) { - auto dt = pair2.first; - const auto &set = pair2.second; - - bool legalCombine = set.count(-1); - - // See if we can canonicalize the outermost index into a -1 - if (!legalCombine) { - size_t chunk = 1; - // Implicit pointer - if (pnext.size() > 0) { - chunk = dl.getPointerSizeInBits() / 8; - } else { - if (auto flt = dt.isFloat()) { - chunk = dl.getTypeSizeInBits(flt) / 8; - } else if (dt == BaseType::Pointer) { - chunk = dl.getPointerSizeInBits() / 8; - } - } - - legalCombine = true; - for (size_t i = 0; i < len; i += chunk) { - if (!set.count(i)) { - legalCombine = false; - break; - } - } - } - - std::vector next; - next.reserve(pnext.size() + 1); - next.push_back(-1); - for (auto v : pnext) - next.push_back(v); - - if (legalCombine) { - Result.insert(next, dt, /*intsAreLegalPointerSub*/ true); - } else { - for (auto e : set) { - next[0] = e; - Result.insert(next, dt); - } - } - } - } - - return Result; - } - - /// Given that this tree represents something of at most size len, - /// canonicalize this, creating -1's where possible - void CanonicalizeInPlace(size_t len, const llvm::DataLayout &dl) { - bool canonicalized = true; - for (const auto &pair : mapping) { - assert(pair.first.size() != 0); - if (pair.first[0] != -1) { - canonicalized = false; - break; - } - } - if (canonicalized) - return; - - // Map of indices[1:] => ( End => possible Index[0] ) - std::map, std::map>> - staging; - - for (const auto &pair : mapping) { - - std::vector next(pair.first.begin() + 1, pair.first.end()); - if (pair.first[0] != -1) { - if ((size_t)pair.first[0] >= len) { - llvm::errs() << str() << "\n"; - llvm::errs() << " canonicalizing " << len << "\n"; - llvm::report_fatal_error("Canonicalization failed"); - } - } - staging[next][pair.second].insert(pair.first[0]); - } - - // TypeTree mappings which did not get combined - std::map, ConcreteType> unCombinedToAdd; - - // TypeTree mappings which did get combined into an outer -1 - std::map, ConcreteType> combinedToAdd; - - for (const auto &pair : staging) { - auto &pnext = pair.first; - for (const auto &pair2 : pair.second) { - auto dt = pair2.first; - const auto &set = pair2.second; - - bool legalCombine = false; - - // See if we can canonicalize the outermost index into a -1 - if (!set.count(-1)) { - size_t chunk = 1; - if (pnext.size() > 0) { - chunk = dl.getPointerSizeInBits() / 8; - } else { - if (auto flt = dt.isFloat()) { - chunk = dl.getTypeSizeInBits(flt) / 8; - } else if (dt == BaseType::Pointer) { - chunk = dl.getPointerSizeInBits() / 8; - } - } - - legalCombine = true; - for (size_t i = 0; i < len; i += chunk) { - if (!set.count(i)) { - legalCombine = false; - break; - } - } - } - - 
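// [Editor's note: illustrative comment added during editing; not part of the
// original patch.] The check above decides whether a group of fixed leading
// offsets can be fused into a single -1 wildcard. For example, with len == 16
// and dt a double (chunk == 8), the offset set {0, 8} covers every
// chunk-aligned slot in [0, len), so the [0] and [8] entries may be combined
// into one [-1] entry; if offset 8 were absent, legalCombine stays false and
// the entries keep their concrete offsets.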
std::vector next; - next.reserve(pnext.size() + 1); - next.push_back(-1); - for (auto v : pnext) - next.push_back(v); - - if (legalCombine) { - combinedToAdd.emplace(next, dt); - } else { - for (auto e : set) { - next[0] = e; - unCombinedToAdd.emplace(next, dt); - } - } - } - } - - // If we combined nothing, just return since there are no - // changes. - if (combinedToAdd.size() == 0) { - return; - } - - // Non-combined ones do not conflict, since they were already in - // a TT which we can assume contained no conflicts. - mapping = std::move(unCombinedToAdd); - if (minIndices.size() > 0) { - minIndices[0] = -1; - } - - // Fusing several terms into a minus one can create a conflict - // if the prior minus one was already in the map - // time, or also generated by fusion. - // E.g. {-1:Anything, [0]:Pointer} on 8 -> create a [-1]:Pointer - // which conflicts - // Alternatively [-1,-1,-1]:Pointer, and generated a [-1,0,-1] fusion - for (const auto &pair : combinedToAdd) { - insert(pair.first, pair.second); - } - - return; - } - - /// Keep only pointers (or anything's) to a repeated value (represented by -1) - TypeTree KeepMinusOne(bool &legal) const { - TypeTree dat; - - for (const auto &pair : mapping) { - - assert(pair.first.size() != 0); - - // Pointer is at offset 0 from this object - if (pair.first[0] != 0 && pair.first[0] != -1) - continue; - - if (pair.first.size() == 1) { - if (pair.second == BaseType::Pointer || - pair.second == BaseType::Anything) { - dat.insert(pair.first, pair.second); - continue; - } - legal = false; - break; - } - - if (pair.first[1] == -1) { - dat.insert(pair.first, pair.second); - } - } - - return dat; - } - - llvm::Type *IsAllFloat(const size_t size, const llvm::DataLayout &dl) const { - auto m1 = TypeTree::operator[]({-1}); - if (auto FT = m1.isFloat()) - return FT; - - auto m0 = TypeTree::operator[]({0}); - - if (auto flt = m0.isFloat()) { - size_t chunk = dl.getTypeSizeInBits(flt) / 8; - for (size_t i = chunk; i < size; i += chunk) { - auto mx = TypeTree::operator[]({(int)i}); - if (auto f2 = mx.isFloat()) { - if (f2 != flt) - return nullptr; - } else - return nullptr; - } - return flt; - } else { - return nullptr; - } - } - - /// Replace mappings in the range in [offset, offset+maxSize] with those in - // [addOffset, addOffset + maxSize]. In other words, select all mappings in - // [offset, offset+maxSize] then add `addOffset` - TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, - const int maxSize, size_t addOffset = 0) const { - - // If we have no terms 1+ layer deep return the current result as a shift - // won't change anything. This also makes the latercode simpler as it - // can assume at least a first index exists. - if (minIndices.size() == 0) - return *this; - - // If we have no size in return, simply return an empty type tree. Again - // this simplifies later code which can assume that a minus one expantion - // will always result in an added variable (which would not be the case - // on a size == 0). - if (maxSize == 0) - return TypeTree(); - - TypeTree Result; - - // The normal orIn / insert methods do collision checking, which is slow - // (and presently O(n)). This is because an expansion of a -1 which could - // conflict with a fixed value. Consider calling this - // ShiftIndicies(offset=0, maxSize=2, addOffset=0, tt={[-1]:Integer, - // [1]:Anything}) the -1 would expand to [0]:Int, [1]:Int, which would need - // to be merged with [1]:Anything - // - // The only possible values which can cause a conflict are minus -1's. 
- // As a result, we start with a fast insertion (aka without check) of - // non-expanded values, since they just do a literal shift which needs no - // extra checking, besides bounds checks. - // - // Since we're doing things manually, we also need to manually preserve TT - // invariants. Specifically, TT limits all values to have offsets < - // MAX_OFFSET, unless it is the smallest offset at that depth. (e.g. so we - // can still hava typetree {[123456]:Int}, even if limit is 100). - // - // First compute the minimum 0th index to be kept. - Result.minIndices.resize(minIndices.size(), INT_MAX); - - for (const auto &pair : mapping) { - if (pair.first.size() == 0) { - if (pair.second == BaseType::Pointer || - pair.second == BaseType::Anything) { - Result.mapping.emplace(pair.first, pair.second); - continue; - } - - llvm::errs() << "could not unmerge " << str() << "\n"; - assert(0 && "ShiftIndices called on a nonpointer/anything"); - llvm_unreachable("ShiftIndices called on a nonpointer/anything"); - } - - int next0 = pair.first[0]; - - if (next0 == -1) { - if (maxSize == -1) { - // Max size does not clip the next index - - // If we have a follow up offset add, we lose the -1 since we only - // represent [0, inf) with -1 not the [addOffset, inf) required here - if (addOffset != 0) { - next0 = addOffset; - } - - } else { - // We're going to insert addOffset + 0...maxSize so the new minIndex - // is addOffset - Result.minIndices[0] = addOffset; - for (size_t i = 1, sz = pair.first.size(); i < sz; i++) - if (pair.first[i] < Result.minIndices[i]) - Result.minIndices[i] = pair.first[i]; - continue; - } - } else { - // Too small for range - if (next0 < offset) { - continue; - } - next0 -= offset; - - if (maxSize != -1) { - if (next0 >= maxSize) - continue; - } - - next0 += addOffset; - } - if (next0 < Result.minIndices[0]) - Result.minIndices[0] = next0; - for (size_t i = 1, sz = pair.first.size(); i < sz; i++) - if (pair.first[i] < Result.minIndices[i]) - Result.minIndices[i] = pair.first[i]; - } - - // Max depth of actual inserted values - size_t maxInsertedDepth = 0; - - // Insert all - for (const auto &pair : mapping) { - if (pair.first.size() == 0) - continue; - - int next0 = pair.first[0]; - - if (next0 == -1) { - if (maxSize == -1) { - // Max size does not clip the next index - - // If we have a follow up offset add, we lose the -1 since we only - // represent [0, inf) with -1 not the [addOffset, inf) required here - if (addOffset != 0) { - next0 = addOffset; - } - - } else { - // This needs to become 0...maxSize handled separately as it is the - // only insertion that could have collisions - continue; - } - } else { - // Too small for range - if (next0 < offset) { - continue; - } - next0 -= offset; - - if (maxSize != -1) { - if (next0 >= maxSize) - continue; - } - - next0 += addOffset; - } - - // If after moving this would not merit being kept for being a min index - // or being within the max type offset, skip it. 
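// [Editor's note: illustrative comment added during editing; not part of the
// original patch.] This mirrors the invariant stated in the first pass above:
// offsets beyond MaxTypeOffset are dropped unless they are the minimum offset
// recorded at some depth. For example, with MaxTypeOffset == 100 a lone entry
// at offset 123456 is still kept (it is the minimum at depth 0), while a
// second entry at a larger offset such as 123464 fails the check below and is
// skipped.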
- if (next0 > MaxTypeOffset) { - bool minIndex = next0 == Result.minIndices[0]; - if (!minIndex) - for (size_t i = 1; i < pair.first.size(); i++) { - if (pair.first[i] == Result.minIndices[i]) { - minIndex = true; - break; - } - } - if (!minIndex) - continue; - } - - std::vector next(pair.first); - next[0] = next0; - Result.mapping.emplace(next, pair.second); - if (next.size() > maxInsertedDepth) - maxInsertedDepth = next.size(); - } - - // Insert and expand the minus one, if needed - if (maxSize != -1) - for (const auto &pair : mapping) { - if (pair.first.size() == 0) - continue; - if (pair.first[0] != -1) - continue; - - size_t chunk = 1; - std::vector next(pair.first); - auto op = operator[]({next[0]}); - if (auto flt = op.isFloat()) { - chunk = dl.getTypeSizeInBits(flt) / 8; - } else if (op == BaseType::Pointer) { - chunk = dl.getPointerSizeInBits() / 8; - } - auto offincr = (chunk - offset % chunk) % chunk; - bool inserted = false; - for (int i = offincr; i < maxSize; i += chunk) { - next[0] = i + addOffset; - ConcreteType prev(pair.second); - // We can use faster checks here, since we know there can be no - // -1's that we would conflict with, only conflicts from previous - // fixed value insertions. - auto found = Result.mapping.find(next); - if (found != Result.mapping.end()) { - // orIn returns if changed, update the value in the map if so - // with the new value. - if (prev.orIn(found->second, /*pointerIntSame*/ false)) - found->second = prev; - } else { - Result.mapping.emplace(next, pair.second); - } - inserted = true; - } - if (inserted && next.size() > maxInsertedDepth) - maxInsertedDepth = next.size(); - } - - // Resize minIndices down if we dropped any higher-depth indices for being - // out of scope. - Result.minIndices.resize(maxInsertedDepth); - return Result; - } - - /// Keep only mappings where the type is not an `Anything` - TypeTree PurgeAnything() const { - TypeTree Result; - Result.minIndices.reserve(minIndices.size()); - for (const auto &pair : mapping) { - if (pair.second == ConcreteType(BaseType::Anything)) - continue; - Result.mapping.insert(pair); - for (size_t i = 0, Len = pair.first.size(); i < Len; ++i) { - if (i == Result.minIndices.size()) - Result.minIndices.push_back(pair.first[i]); - else if (pair.first[i] < Result.minIndices[i]) - Result.minIndices[i] = pair.first[i]; - } - } - return Result; - } - - /// Replace -1 with 0 - TypeTree ReplaceMinus() const { - TypeTree dat; - for (const auto &pair : mapping) { - if (pair.second == ConcreteType(BaseType::Anything)) - continue; - std::vector nex = pair.first; - for (auto &v : nex) - if (v == -1) - v = 0; - dat.insert(nex, pair.second); - } - return dat; - } - - /// Replace all integer subtypes with anything - void ReplaceIntWithAnything() { - for (auto &pair : mapping) { - if (pair.second == BaseType::Integer) { - pair.second = BaseType::Anything; - } - } - } - - /// Keep only mappings where the type is an `Anything` - TypeTree JustAnything() const { - TypeTree dat; - for (const auto &pair : mapping) { - if (pair.second != ConcreteType(BaseType::Anything)) - continue; - dat.insert(pair.first, pair.second); - } - return dat; - } - - /// Chceck equality of two TypeTrees - bool operator==(const TypeTree &RHS) const { return mapping == RHS.mapping; } - - /// Set this to another TypeTree, returning if this was changed - bool operator=(const TypeTree &RHS) { - if (*this == RHS) - return false; - minIndices = RHS.minIndices; - mapping.clear(); - for (const auto &elems : RHS.mapping) { - mapping.emplace(elems); - 
} - return true; - } - - bool checkedOrIn(const std::vector &Seq, ConcreteType RHS, - bool PointerIntSame, bool &LegalOr) { - assert(RHS != BaseType::Unknown); - ConcreteType CT = operator[](Seq); - - bool subchanged = CT.checkedOrIn(RHS, PointerIntSame, LegalOr); - if (!subchanged) - return false; - if (!LegalOr) - return subchanged; - - auto SeqSize = Seq.size(); - - if (SeqSize > 0) { - // check pointer abilities from before - for (size_t i = 0; i < SeqSize; ++i) { - std::vector tmp(Seq.begin(), Seq.end() - 1 - i); - auto found = mapping.find(tmp); - if (found != mapping.end()) { - if (!(found->second == BaseType::Pointer || - found->second == BaseType::Anything)) { - LegalOr = false; - return false; - } - } - } - - // Check if there is an existing match, e.g. [-1, -1, -1] and inserting - // [-1, 8, -1] - { - for (const auto &pair : llvm::make_early_inc_range(mapping)) { - if (pair.first.size() == SeqSize) { - // Whether the the inserted val (e.g. [-1, 0] or [0, 0]) is at least - // as general as the existing map val (e.g. [0, 0]). - bool newMoreGeneralThanOld = true; - // Whether the the existing val (e.g. [-1, 0] or [0, 0]) is at least - // as general as the inserted map val (e.g. [0, 0]). - bool oldMoreGeneralThanNew = true; - for (unsigned i = 0; i < SeqSize; i++) { - if (pair.first[i] == Seq[i]) - continue; - if (Seq[i] == -1) { - oldMoreGeneralThanNew = false; - } else if (pair.first[i] == -1) { - newMoreGeneralThanOld = false; - } else { - oldMoreGeneralThanNew = false; - newMoreGeneralThanOld = false; - break; - } - } - - if (oldMoreGeneralThanNew) { - // Inserting an existing or less general version - if (CT == pair.second) - return false; - - // Inserting an existing or less general version (with pointer-int - // equivalence) - if (PointerIntSame) - if ((CT == BaseType::Pointer && - pair.second == BaseType::Integer) || - (CT == BaseType::Integer && - pair.second == BaseType::Pointer)) - return false; - - // Inserting into an anything. Since from above we know this is - // not an anything, the inserted value contains no new information - if (pair.second == BaseType::Anything) - return false; - - // Inserting say a [0]:anything into a [-1]:Float - if (CT == BaseType::Anything) { - // If both at same index, remove old index - if (newMoreGeneralThanOld) - mapping.erase(pair.first); - continue; - } - - // Otherwise, inserting a non-equivalent pair into a more general - // slot. This is invalid. - LegalOr = false; - return false; - } else if (newMoreGeneralThanOld) { - // This new is strictly more general than the old. If they were - // equivalent, the case above would have been hit. - - if (CT == BaseType::Anything || CT == pair.second) { - // previous equivalent values or values overwritten by - // an anything are removed - mapping.erase(pair.first); - continue; - } - - // Inserting an existing or less general version (with pointer-int - // equivalence) - if (PointerIntSame) - if ((CT == BaseType::Pointer && - pair.second == BaseType::Integer) || - (CT == BaseType::Integer && - pair.second == BaseType::Pointer)) { - mapping.erase(pair.first); - continue; - } - - // Keep lingering anythings if not being overwritten, even if this - // (e.g. Float) applies to more locations. Therefore it is legal - // to have [-1]:Float, [8]:Anything - if (CT != BaseType::Anything && pair.second == BaseType::Anything) - continue; - - // Otherwise, inserting a more general non-equivalent pair. This - // is invalid. 
- LegalOr = false; - return false; - } - } - } - } - } - - return insert(Seq, CT); - } - - bool orIn(const std::vector &Seq, ConcreteType RHS, - bool PointerIntSame = false) { - bool LegalOr = true; - bool Result = checkedOrIn(Seq, RHS, PointerIntSame, LegalOr); - assert(LegalOr); - return Result; - } - - /// Set this to the logical or of itself and RHS, returning whether this value - /// changed Setting `PointerIntSame` considers pointers and integers as - /// equivalent If this is an illegal operation, `LegalOr` will be set to false - bool checkedOrIn(const TypeTree &RHS, bool PointerIntSame, bool &LegalOr) { - // TODO detect recursive merge and simplify - - bool changed = false; - for (auto &pair : RHS.mapping) { - changed |= checkedOrIn(pair.first, pair.second, PointerIntSame, LegalOr); - } - return changed; - } - - /// Set this to the logical or of itself and RHS, returning whether this value - /// changed Setting `PointerIntSame` considers pointers and integers as - /// equivalent This function will error if doing an illegal Operation - bool orIn(const TypeTree &RHS, bool PointerIntSame) { - bool Legal = true; - bool Result = checkedOrIn(RHS, PointerIntSame, Legal); - if (!Legal) { - llvm::errs() << "Illegal orIn: " << str() << " right: " << RHS.str() - << " PointerIntSame=" << PointerIntSame << "\n"; - assert(0 && "Performed illegal ConcreteType::orIn"); - llvm_unreachable("Performed illegal ConcreteType::orIn"); - } - return Result; - } - - /// Set this to the logical or of itself and RHS, returning whether this value - /// changed Setting `PointerIntSame` considers pointers and integers as - /// equivalent This function will error if doing an illegal Operation - bool orIn(const std::vector Seq, ConcreteType CT, bool PointerIntSame) { - bool Legal = true; - bool Result = checkedOrIn(Seq, CT, PointerIntSame, Legal); - if (!Legal) { - llvm::errs() << "Illegal orIn: " << str() << " right: " << to_string(Seq) - << " CT: " << CT.str() - << " PointerIntSame=" << PointerIntSame << "\n"; - assert(0 && "Performed illegal ConcreteType::orIn"); - llvm_unreachable("Performed illegal ConcreteType::orIn"); - } - return Result; - } - - /// Set this to the logical or of itself and RHS, returning whether this value - /// changed This assumes that pointers and integers are distinct This function - /// will error if doing an illegal Operation - bool operator|=(const TypeTree &RHS) { - return orIn(RHS, /*PointerIntSame*/ false); - } - - /// Set this to the logical and of itself and RHS, returning whether this - /// value changed If this and RHS are incompatible at an index, the result - /// will be BaseType::Unknown - bool andIn(const TypeTree &RHS) { - bool changed = false; - - for (auto &pair : llvm::make_early_inc_range(mapping)) { - ConcreteType other = BaseType::Unknown; - auto fd = RHS.mapping.find(pair.first); - if (fd != RHS.mapping.end()) { - other = fd->second; - } - changed = (pair.second &= other); - if (pair.second == BaseType::Unknown) { - mapping.erase(pair.first); - } - } - - return changed; - } - - /// Set this to the logical and of itself and RHS, returning whether this - /// value changed If this and RHS are incompatible at an index, the result - /// will be BaseType::Unknown - bool operator&=(const TypeTree &RHS) { return andIn(RHS); } - - /// Set this to the logical `binop` of itself and RHS, using the Binop Op, - /// returning true if this was changed. 
- /// This function will error on an invalid type combination - bool binopIn(bool &Legal, const TypeTree &RHS, - llvm::BinaryOperator::BinaryOps Op) { - bool changed = false; - - for (auto &pair : llvm::make_early_inc_range(mapping)) { - // TODO propagate non-first level operands: - // Special handling is necessary here because a pointer to an int - // binop with something should not apply the binop rules to the - // underlying data but instead a different rule - if (pair.first.size() > 0) { - mapping.erase(pair.first); - continue; - } - - ConcreteType CT(pair.second); - ConcreteType RightCT(BaseType::Unknown); - - // Mutual mappings - auto found = RHS.mapping.find(pair.first); - if (found != RHS.mapping.end()) { - RightCT = found->second; - } - bool SubLegal = true; - changed |= CT.binopIn(SubLegal, RightCT, Op); - if (!SubLegal) { - Legal = false; - return changed; - } - if (CT == BaseType::Unknown) { - mapping.erase(pair.first); - } else { - pair.second = CT; - } - } - - // mapings just on the right - for (auto &pair : RHS.mapping) { - // TODO propagate non-first level operands: - // Special handling is necessary here because a pointer to an int - // binop with something should not apply the binop rules to the - // underlying data but instead a different rule - if (pair.first.size() > 0) { - continue; - } - - if (mapping.find(pair.first) == RHS.mapping.end()) { - ConcreteType CT = BaseType::Unknown; - bool SubLegal = true; - changed |= CT.binopIn(SubLegal, pair.second, Op); - if (!SubLegal) { - Legal = false; - return changed; - } - if (CT != BaseType::Unknown) { - mapping.insert(std::make_pair(pair.first, CT)); - } - } - } - - return changed; - } - - /// Returns a string representation of this TypeTree - std::string str() const { - std::string out = "{"; - bool first = true; - for (auto &pair : mapping) { - if (!first) { - out += ", "; - } - out += "["; - for (unsigned i = 0; i < pair.first.size(); ++i) { - if (i != 0) - out += ","; - out += std::to_string(pair.first[i]); - } - out += "]:" + pair.second.str(); - first = false; - } - out += "}"; - return out; - } - - llvm::MDNode *toMD(llvm::LLVMContext &ctx) { - llvm::SmallVector subMD; - std::map todo; - ConcreteType base(BaseType::Unknown); - for (auto &pair : mapping) { - if (pair.first.size() == 0) { - base = pair.second; - continue; - } - auto next(pair.first); - next.erase(next.begin()); - todo[pair.first[0]].mapping.insert(std::make_pair(next, pair.second)); - } - subMD.push_back(llvm::MDString::get(ctx, base.str())); - for (auto pair : todo) { - subMD.push_back(llvm::ConstantAsMetadata::get( - llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32), pair.first))); - subMD.push_back(pair.second.toMD(ctx)); - } - return llvm::MDNode::get(ctx, subMD); - }; - - void insertFromMD(llvm::MDNode *md, const std::vector &prev = {}) { - ConcreteType base( - llvm::cast(md->getOperand(0))->getString(), - md->getContext()); - if (base != BaseType::Unknown) - mapping.insert(std::make_pair(prev, base)); - for (size_t i = 1; i < md->getNumOperands(); i += 2) { - auto off = llvm::cast( - llvm::cast(md->getOperand(i)) - ->getValue()) - ->getSExtValue(); - auto next(prev); - next.push_back((int)off); - insertFromMD(llvm::cast(md->getOperand(i + 1)), next); - } - } - - static TypeTree fromMD(llvm::MDNode *md) { - TypeTree ret; - std::vector off; - ret.insertFromMD(md, off); - return ret; - } -}; - -#endif diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp deleted file mode 100644 index 46cc44ca58db..000000000000 --- 
a/enzyme/Enzyme/Utils.cpp +++ /dev/null @@ -1,4044 +0,0 @@ -//===- Utils.cpp - Definition of miscellaneous utilities ------------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file defines miscellaneous utilities that are used as part of the -// AD process. -// -//===----------------------------------------------------------------------===// -#include "Utils.h" -#include "GradientUtils.h" -#include "TypeAnalysis/TypeAnalysis.h" - -#if LLVM_VERSION_MAJOR >= 16 -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#else -#include "SCEV/ScalarEvolution.h" -#include "SCEV/ScalarEvolutionExpander.h" -#endif - -#include "TypeAnalysis/TBAA.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InlineAsm.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Verifier.h" - -#if LLVM_VERSION_MAJOR >= 16 -#include "llvm/TargetParser/Triple.h" -#else -#include "llvm/ADT/Triple.h" -#endif - -#include "llvm-c/Core.h" - -#include "LibraryFuncs.h" - -using namespace llvm; - -extern "C" { -LLVMValueRef (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, - const void *, LLVMValueRef, - LLVMBuilderRef) = nullptr; -LLVMValueRef (*CustomAllocator)(LLVMBuilderRef, LLVMTypeRef, - /*Count*/ LLVMValueRef, - /*Align*/ LLVMValueRef, uint8_t, - LLVMValueRef *) = nullptr; -void (*CustomZero)(LLVMBuilderRef, LLVMTypeRef, - /*Ptr*/ LLVMValueRef, uint8_t) = nullptr; -LLVMValueRef (*CustomDeallocator)(LLVMBuilderRef, LLVMValueRef) = nullptr; -void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef, - LLVMValueRef) = nullptr; -LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef, - uint64_t *size) = nullptr; -LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr; -LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMModuleRef, LLVMTypeRef, - uint8_t) = nullptr; - -LLVMValueRef (*EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset, - LLVMBuilderRef, - LLVMValueRef) = nullptr; - -extern llvm::cl::opt EnzymeZeroCache; - -// default to false because lacpy is slow -llvm::cl::opt - EnzymeLapackCopy("enzyme-lapack-copy", cl::init(false), cl::Hidden, - cl::desc("Use blas copy calls to cache matrices")); -llvm::cl::opt - EnzymeBlasCopy("enzyme-blas-copy", cl::init(true), cl::Hidden, - cl::desc("Use blas copy calls to cache vectors")); -llvm::cl::opt - EnzymeFastMath("enzyme-fast-math", cl::init(true), cl::Hidden, - cl::desc("Use fast math on derivative compuation")); -llvm::cl::opt - EnzymeStrongZero("enzyme-strong-zero", cl::init(false), cl::Hidden, - cl::desc("Use additional checks to ensure correct " - "behavior when handling functions with inf")); -llvm::cl::opt 
EnzymeMemmoveWarning( - "enzyme-memmove-warning", cl::init(true), cl::Hidden, - cl::desc("Warn if using memmove implementation as a fallback for memmove")); -llvm::cl::opt EnzymeRuntimeError( - "enzyme-runtime-error", cl::init(false), cl::Hidden, - cl::desc("Emit Runtime errors instead of compile time ones")); - -llvm::cl::opt EnzymeNonPower2Cache( - "enzyme-non-power2-cache", cl::init(false), cl::Hidden, - cl::desc("Disable caching of integers which are not a power of 2")); -} - -void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, - bool isTape) { - if (CustomZero) { - CustomZero(wrap(&Builder), wrap(T), wrap(obj), isTape); - } else { - Builder.CreateStore(Constant::getNullValue(T), obj); - } -} - -llvm::SmallVector PostCacheStore(llvm::StoreInst *SI, - llvm::IRBuilder<> &B) { - SmallVector res; - if (EnzymePostCacheStore) { - uint64_t size = 0; - auto ptr = EnzymePostCacheStore(wrap(SI), wrap(&B), &size); - for (size_t i = 0; i < size; i++) { - res.push_back(cast(unwrap(ptr[i]))); - } - free(ptr); - } - return res; -} - -llvm::PointerType *getDefaultAnonymousTapeType(llvm::LLVMContext &C) { - if (EnzymeDefaultTapeType) - return cast(unwrap(EnzymeDefaultTapeType(wrap(&C)))); - return getInt8PtrTy(C); -} - -Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc, - bool ZeroInit, llvm::Type *RT) { - bool custom = true; - llvm::PointerType *allocType; - { - auto i64 = Type::getInt64Ty(newFunc->getContext()); - BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", newFunc); - IRBuilder<> B(BB); - auto P = B.CreatePHI(i64, 1); - CallInst *malloccall; - Instruction *SubZero = nullptr; - CreateAllocation(B, RT, P, "tapemem", &malloccall, &SubZero)->getType(); - if (auto F = getFunctionFromCall(malloccall)) { - custom = F->getName() != "malloc"; - } - allocType = cast(malloccall->getType()); - BB->eraseFromParent(); - } - - Type *types[] = {allocType, Type::getInt64Ty(M.getContext()), - Type::getInt64Ty(M.getContext())}; - std::string name = "__enzyme_exponentialallocation"; - if (ZeroInit) - name += "zero"; - if (custom) - name += ".custom@" + std::to_string((size_t)RT); - - FunctionType *FT = FunctionType::get(allocType, types, false); - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); - F->addFnAttr(Attribute::AlwaysInline); - F->addFnAttr(Attribute::NoUnwind); - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *grow = BasicBlock::Create(M.getContext(), "grow", F); - BasicBlock *ok = BasicBlock::Create(M.getContext(), "ok", F); - - IRBuilder<> B(entry); - - Argument *ptr = F->arg_begin(); - ptr->setName("ptr"); - Argument *size = ptr + 1; - size->setName("size"); - Argument *tsize = size + 1; - tsize->setName("tsize"); - - Value *hasOne = B.CreateICmpNE( - B.CreateAnd(size, ConstantInt::get(size->getType(), 1, false)), - ConstantInt::get(size->getType(), 0, false)); - auto popCnt = getIntrinsicDeclaration(&M, Intrinsic::ctpop, {types[1]}); - - B.CreateCondBr( - B.CreateAnd(B.CreateICmpULT(B.CreateCall(popCnt, {size}), - ConstantInt::get(types[1], 3, false)), - hasOne), - grow, ok); - - B.SetInsertPoint(grow); - - auto lz = - B.CreateCall(getIntrinsicDeclaration(&M, Intrinsic::ctlz, {types[1]}), - {size, ConstantInt::getTrue(M.getContext())}); - Value *next = - B.CreateShl(tsize, B.CreateSub(ConstantInt::get(types[1], 64, false), lz, - "", true, true)); - - Value *gVal; - - Value *prevSize = - 
B.CreateSelect(B.CreateICmpEQ(size, ConstantInt::get(size->getType(), 1)), - ConstantInt::get(next->getType(), 0), - B.CreateLShr(next, ConstantInt::get(next->getType(), 1))); - - auto Arch = llvm::Triple(M.getTargetTriple()).getArch(); - bool forceMalloc = Arch == Triple::nvptx || Arch == Triple::nvptx64; - - if (!custom && !forceMalloc) { - auto reallocF = M.getOrInsertFunction("realloc", allocType, allocType, - Type::getInt64Ty(M.getContext())); - - Value *args[] = {B.CreatePointerCast(ptr, allocType), next}; - gVal = B.CreateCall(reallocF, args); - } else { - Value *tsize = ConstantInt::get( - next->getType(), - newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(RT) / 8); - auto elSize = B.CreateUDiv(next, tsize, "", /*isExact*/ true); - Instruction *SubZero = nullptr; - gVal = CreateAllocation(B, RT, elSize, "", nullptr, &SubZero); - - Type *bTy = - PointerType::get(Type::getInt8Ty(gVal->getContext()), - cast(gVal->getType())->getAddressSpace()); - gVal = B.CreatePointerCast(gVal, bTy); - auto pVal = B.CreatePointerCast(ptr, gVal->getType()); - - Value *margs[] = {gVal, pVal, prevSize, - ConstantInt::getFalse(M.getContext())}; - Type *tys[] = {margs[0]->getType(), margs[1]->getType(), - margs[2]->getType()}; - auto memsetF = getIntrinsicDeclaration(&M, Intrinsic::memcpy, tys); - B.CreateCall(memsetF, margs); - if (SubZero) { - ZeroInit = false; - IRBuilder<> BB(SubZero); - Value *zeroSize = BB.CreateSub(next, prevSize); - Value *tmp = SubZero->getOperand(0); - Type *tmpT = tmp->getType(); - tmp = BB.CreatePointerCast(tmp, bTy); - tmp = BB.CreateInBoundsGEP(Type::getInt8Ty(tmp->getContext()), tmp, - prevSize); - tmp = BB.CreatePointerCast(tmp, tmpT); - SubZero->setOperand(0, tmp); - SubZero->setOperand(2, zeroSize); - } - } - - if (ZeroInit) { - Value *zeroSize = B.CreateSub(next, prevSize); - - Value *margs[] = {B.CreateInBoundsGEP(B.getInt8Ty(), gVal, prevSize), - B.getInt8(0), zeroSize, B.getFalse()}; - Type *tys[] = {margs[0]->getType(), margs[2]->getType()}; - auto memsetF = getIntrinsicDeclaration(&M, Intrinsic::memset, tys); - B.CreateCall(memsetF, margs); - } - gVal = B.CreatePointerCast(gVal, ptr->getType()); - - B.CreateBr(ok); - B.SetInsertPoint(ok); - auto phi = B.CreatePHI(ptr->getType(), 2); - phi->addIncoming(gVal, grow); - phi->addIncoming(ptr, entry); - B.CreateRet(phi); - return F; -} - -llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev, - llvm::Type *T, llvm::Value *OuterCount, - llvm::Value *InnerCount, - const llvm::Twine &Name, - llvm::CallInst **caller, bool ZeroMem) { - auto newFunc = B.GetInsertBlock()->getParent(); - - Value *tsize = ConstantInt::get( - InnerCount->getType(), - newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T) / 8); - - Value *idxs[] = { - /*ptr*/ - prev, - /*incrementing value to increase when it goes past a power of two*/ - OuterCount, - /*buffer size (element x subloops)*/ - B.CreateMul(tsize, InnerCount, "", /*NUW*/ true, - /*NSW*/ true)}; - - auto realloccall = - B.CreateCall(getOrInsertExponentialAllocator(*newFunc->getParent(), - newFunc, ZeroMem, T), - idxs, Name); - if (caller) - *caller = realloccall; - return realloccall; -} - -Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count, - const Twine &Name, CallInst **caller, - Instruction **ZeroMem, bool isDefault) { - Value *res; - auto &M = *Builder.GetInsertBlock()->getParent()->getParent(); - auto AlignI = M.getDataLayout().getTypeAllocSizeInBits(T) / 8; - auto Align = ConstantInt::get(Count->getType(), AlignI); - 
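// [Editor's note: illustrative comment added during editing; not part of the
// original patch.] Despite its name, `Align` holds the allocation size of one
// element of T in bytes (getTypeAllocSizeInBits(T) / 8), so the buffer being
// requested is AlignI * Count bytes. For example, Count == 10 doubles gives
// AlignI == 8 and an 80-byte allocation; the same product sizes the
// dereferenceable return attributes (when Count is a constant) and the
// optional zero-initializing memset further below.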
CallInst *malloccall = nullptr; - if (CustomAllocator) { - LLVMValueRef wzeromem = nullptr; - res = unwrap(CustomAllocator(wrap(&Builder), wrap(T), wrap(Count), - wrap(Align), isDefault, - ZeroMem ? &wzeromem : nullptr)); - if (isa(res)) - return res; - if (isa(res)) - return res; - if (auto I = dyn_cast(res)) - I->setName(Name); - - malloccall = dyn_cast(res); - if (malloccall == nullptr) { - malloccall = cast(cast(res)->getOperand(0)); - } - if (ZeroMem) { - *ZeroMem = cast_or_null(unwrap(wzeromem)); - ZeroMem = nullptr; - } - } else { -#if LLVM_VERSION_MAJOR > 17 - res = - Builder.CreateMalloc(Count->getType(), T, Align, Count, nullptr, Name); -#else - if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) { - res = CallInst::CreateMalloc(Builder.GetInsertBlock(), Count->getType(), - T, Align, Count, nullptr, Name); - Builder.SetInsertPoint(Builder.GetInsertBlock()); - } else { - res = CallInst::CreateMalloc(&*Builder.GetInsertPoint(), Count->getType(), - T, Align, Count, nullptr, Name); - } - if (!cast(res)->getParent()) - Builder.Insert(cast(res)); -#endif - - malloccall = dyn_cast(res); - if (malloccall == nullptr) { - malloccall = cast(cast(res)->getOperand(0)); - } - - // Assert computation of size of array doesn't wrap - if (auto BI = dyn_cast(malloccall->getArgOperand(0))) { - if (BI->getOpcode() == BinaryOperator::Mul) { - if ((BI->getOperand(0) == Align && BI->getOperand(1) == Count) || - (BI->getOperand(1) == Align && BI->getOperand(0) == Count)) - BI->setHasNoSignedWrap(true); - BI->setHasNoUnsignedWrap(true); - } - } - - if (auto ci = dyn_cast(Count)) { -#if LLVM_VERSION_MAJOR >= 14 - malloccall->addDereferenceableRetAttr(ci->getLimitedValue() * AlignI); -#if !defined(FLANG) && !defined(ROCM) - AttrBuilder B(ci->getContext()); -#else - AttrBuilder B; -#endif - B.addDereferenceableOrNullAttr(ci->getLimitedValue() * AlignI); - malloccall->setAttributes(malloccall->getAttributes().addRetAttributes( - malloccall->getContext(), B)); -#else - malloccall->addDereferenceableAttr(llvm::AttributeList::ReturnIndex, - ci->getLimitedValue() * AlignI); - malloccall->addDereferenceableOrNullAttr(llvm::AttributeList::ReturnIndex, - ci->getLimitedValue() * AlignI); -#endif - // malloccall->removeAttribute(llvm::AttributeList::ReturnIndex, - // Attribute::DereferenceableOrNull); - } -#if LLVM_VERSION_MAJOR >= 14 - malloccall->addAttributeAtIndex(AttributeList::ReturnIndex, - Attribute::NoAlias); - malloccall->addAttributeAtIndex(AttributeList::ReturnIndex, - Attribute::NonNull); -#else - malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); - malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); -#endif - } - if (caller) { - *caller = malloccall; - } - if (ZeroMem) { - auto PT = cast(malloccall->getType()); - Value *tozero = malloccall; - - bool needsCast = false; -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (PT->getContext().supportsTypedPointers()) { -#endif - needsCast = !PT->getPointerElementType()->isIntegerTy(8); -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - if (needsCast) - tozero = Builder.CreatePointerCast( - tozero, PointerType::get(Type::getInt8Ty(PT->getContext()), - PT->getAddressSpace())); - Value *args[] = { - tozero, ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0), - Builder.CreateMul(Align, Count, "", true, true), - ConstantInt::getFalse(malloccall->getContext())}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - - *ZeroMem = Builder.CreateCall( - getIntrinsicDeclaration(&M, 
Intrinsic::memset, tys), args); - } - return res; -} - -CallInst *CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree) { - CallInst *res = nullptr; - - if (CustomDeallocator) { - res = dyn_cast_or_null( - unwrap(CustomDeallocator(wrap(&Builder), wrap(ToFree)))); - } else { - - ToFree = - Builder.CreatePointerCast(ToFree, getInt8PtrTy(ToFree->getContext())); -#if LLVM_VERSION_MAJOR > 17 - res = cast(Builder.CreateFree(ToFree)); -#else - if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) { - res = cast( - CallInst::CreateFree(ToFree, Builder.GetInsertBlock())); - Builder.SetInsertPoint(Builder.GetInsertBlock()); - } else { - res = cast( - CallInst::CreateFree(ToFree, &*Builder.GetInsertPoint())); - } - if (!cast(res)->getParent()) - Builder.Insert(cast(res)); -#endif -#if LLVM_VERSION_MAJOR >= 14 - res->addAttributeAtIndex(AttributeList::FirstArgIndex, Attribute::NonNull); -#else - res->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); -#endif - } - return res; -} - -EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName, - const llvm::DiagnosticLocation &Loc, - const llvm::Instruction *CodeRegion) - : EnzymeFailure(RemarkName, Loc, CodeRegion->getParent()->getParent()) {} - -EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName, - const llvm::DiagnosticLocation &Loc, - const llvm::Function *CodeRegion) - : DiagnosticInfoUnsupported(*CodeRegion, RemarkName, Loc) {} - -/// Convert a floating type to a string -static inline std::string tofltstr(Type *T) { - if (auto VT = dyn_cast(T)) { -#if LLVM_VERSION_MAJOR >= 12 - auto len = VT->getElementCount().getFixedValue(); -#else - auto len = VT->getNumElements(); -#endif - return "vec" + std::to_string(len) + tofltstr(VT->getElementType()); - } - switch (T->getTypeID()) { - case Type::HalfTyID: - return "half"; - case Type::FloatTyID: - return "float"; - case Type::DoubleTyID: - return "double"; - case Type::X86_FP80TyID: - return "x87d"; - case Type::BFloatTyID: - return "bf16"; - case Type::FP128TyID: - return "quad"; - case Type::PPC_FP128TyID: - return "ppcddouble"; - default: - llvm_unreachable("Invalid floating type"); - } -} - -Constant *getString(Module &M, StringRef Str) { - llvm::Constant *s = llvm::ConstantDataArray::getString(M.getContext(), Str); - auto *gv = new llvm::GlobalVariable( - M, s->getType(), true, llvm::GlobalValue::PrivateLinkage, s, ".str"); - gv->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); - Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(M.getContext()), 0), - ConstantInt::get(Type::getInt32Ty(M.getContext()), 0)}; - return ConstantExpr::getInBoundsGetElementPtr(s->getType(), gv, Idxs); -} - -void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, - llvm::Value *shadow, const char *Message, - llvm::DebugLoc &&loc, llvm::Instruction *orig) { - Module &M = *B.GetInsertBlock()->getParent()->getParent(); - std::string name = "__enzyme_runtimeinactiveerr"; - if (CustomRuntimeInactiveError) { - static int count = 0; - name += std::to_string(count); - count++; - } - FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()), - {getInt8PtrTy(M.getContext()), - getInt8PtrTy(M.getContext()), - getInt8PtrTy(M.getContext())}, - false); - - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (F->empty()) { - F->setLinkage(Function::LinkageTypes::InternalLinkage); - F->addFnAttr(Attribute::AlwaysInline); - addFunctionNoCapture(F, 0); - addFunctionNoCapture(F, 1); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), 
"entry", F); - BasicBlock *error = BasicBlock::Create(M.getContext(), "error", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F); - - auto prim = F->arg_begin(); - prim->setName("primal"); - auto shadow = prim + 1; - shadow->setName("shadow"); - auto msg = prim + 2; - msg->setName("msg"); - - IRBuilder<> EB(entry); - EB.CreateCondBr(EB.CreateICmpEQ(prim, shadow), error, end); - - EB.SetInsertPoint(error); - - if (CustomRuntimeInactiveError) { - CustomRuntimeInactiveError(wrap(&EB), wrap(msg), wrap(orig)); - } else { - FunctionType *FT = - FunctionType::get(Type::getInt32Ty(M.getContext()), - {getInt8PtrTy(M.getContext())}, false); - - auto PutsF = M.getOrInsertFunction("puts", FT); - EB.CreateCall(PutsF, msg); - - FunctionType *FT2 = - FunctionType::get(Type::getVoidTy(M.getContext()), - {Type::getInt32Ty(M.getContext())}, false); - - auto ExitF = M.getOrInsertFunction("exit", FT2); - EB.CreateCall(ExitF, - ConstantInt::get(Type::getInt32Ty(M.getContext()), 1)); - } - EB.CreateUnreachable(); - - EB.SetInsertPoint(end); - EB.CreateRetVoid(); - } - - Value *args[] = {B.CreatePointerCast(primal, getInt8PtrTy(M.getContext())), - B.CreatePointerCast(shadow, getInt8PtrTy(M.getContext())), - getString(M, Message)}; - auto call = B.CreateCall(F, args); - call->setDebugLoc(loc); -} - -Type *BlasInfo::fpType(LLVMContext &ctx, bool to_scalar) const { - if (floatType == "d" || floatType == "D") { - return Type::getDoubleTy(ctx); - } else if (floatType == "s" || floatType == "S") { - return Type::getFloatTy(ctx); - } else if (floatType == "c" || floatType == "C") { - if (to_scalar) - return Type::getFloatTy(ctx); - return VectorType::get(Type::getFloatTy(ctx), 2, false); - } else if (floatType == "z" || floatType == "Z") { - if (to_scalar) - return Type::getDoubleTy(ctx); - return VectorType::get(Type::getDoubleTy(ctx), 2, false); - } else { - assert(false && "Unreachable"); - return nullptr; - } -} - -IntegerType *BlasInfo::intType(LLVMContext &ctx) const { - if (is64) - return IntegerType::get(ctx, 64); - else - return IntegerType::get(ctx, 32); -} - -/// Create function for type that is equivalent to memcpy but adds to -/// destination rather than a direct copy; dst, src, numelems -Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType, - unsigned dstalign, - unsigned srcalign, - unsigned dstaddr, unsigned srcaddr, - unsigned bitwidth) { - assert(elementType->isFloatingPointTy()); - std::string name = "__enzyme_memcpy"; - if (bitwidth != 64) - name += std::to_string(bitwidth); - name += "add_" + tofltstr(elementType) + "da" + std::to_string(dstalign) + - "sa" + std::to_string(srcalign); - if (dstaddr) - name += "dadd" + std::to_string(dstaddr); - if (srcaddr) - name += "sadd" + std::to_string(srcaddr); - FunctionType *FT = - FunctionType::get(Type::getVoidTy(M.getContext()), - {PointerType::get(elementType, dstaddr), - PointerType::get(elementType, srcaddr), - IntegerType::get(M.getContext(), bitwidth)}, - false); - - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - addFunctionNoCapture(F, 0); - addFunctionNoCapture(F, 1); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", 
F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); - - auto dst = F->arg_begin(); - dst->setName("dst"); - auto src = dst + 1; - src->setName("src"); - auto num = src + 1; - num->setName("num"); - - { - IRBuilder<> B(entry); - B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)), - end, body); - } - - auto elSize = (M.getDataLayout().getTypeSizeInBits(elementType) + 7) / 8; - { - IRBuilder<> B(body); - B.setFastMathFlags(getFast()); - PHINode *idx = B.CreatePHI(num->getType(), 2, "idx"); - idx->addIncoming(ConstantInt::get(num->getType(), 0), entry); - - Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx, "dst.i"); - LoadInst *dstl = B.CreateLoad(elementType, dsti, "dst.i.l"); - StoreInst *dsts = B.CreateStore(Constant::getNullValue(elementType), dsti); - - if (dstalign) { - // If the element size is already aligned to current alignment, do nothing - // e.g. elsize = double = 8, dstalign = 2 - if (elSize % dstalign == 0) { - - } else if (dstalign % elSize == 0) { - // Otherwise if the dst alignment is a multiple of the element size, - // use the element size as the new alignment. e.g. elsize = double = 8 - // and alignment = 16 - dstalign = elSize; - } else { - // else alignment only applies for first element, and we lose after all - // other iterattions, assume nothing - dstalign = 1; - } - } - - if (srcalign) { - // If the element size is already aligned to current alignment, do nothing - // e.g. elsize = double = 8, dstalign = 2 - if (elSize % srcalign == 0) { - - } else if (srcalign % elSize == 0) { - // Otherwise if the dst alignment is a multiple of the element size, - // use the element size as the new alignment. e.g. elsize = double = 8 - // and alignment = 16 - srcalign = elSize; - } else { - // else alignment only applies for first element, and we lose after all - // other iterattions, assume nothing - srcalign = 1; - } - } - - if (dstalign) { - dstl->setAlignment(Align(dstalign)); - dsts->setAlignment(Align(dstalign)); - } - - Value *srci = B.CreateInBoundsGEP(elementType, src, idx, "src.i"); - LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l"); - StoreInst *srcs = B.CreateStore(B.CreateFAdd(srcl, dstl), srci); - if (srcalign) { - srcl->setAlignment(Align(srcalign)); - srcs->setAlignment(Align(srcalign)); - } - - Value *next = - B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next"); - idx->addIncoming(next, body); - B.CreateCondBr(B.CreateICmpEQ(num, next), end, body); - } - - { - IRBuilder<> B(end); - B.CreateRetVoid(); - } - return F; -} - -Value *lookup_with_layout(IRBuilder<> &B, Type *fpType, Value *layout, - Value *const base, Value *lda, Value *row, - Value *col) { - Type *intType = row->getType(); - Value *is_row_maj = - layout ? 
B.CreateICmpEQ(layout, ConstantInt::get(layout->getType(), 101)) - : B.getFalse(); - Value *offset = nullptr; - if (col) { - offset = B.CreateMul( - row, CreateSelect(B, is_row_maj, lda, ConstantInt::get(intType, 1))); - offset = B.CreateAdd( - offset, - B.CreateMul(col, CreateSelect(B, is_row_maj, - ConstantInt::get(intType, 1), lda))); - } else { - offset = B.CreateMul(row, lda); - } - if (!base) - return offset; - - Value *ptr = base; - if (base->getType()->isIntegerTy()) - ptr = B.CreateIntToPtr(ptr, PointerType::getUnqual(fpType)); - -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (ptr->getContext().supportsTypedPointers()) { -#endif - if (fpType != ptr->getType()->getPointerElementType()) { - ptr = B.CreatePointerCast( - ptr, - PointerType::get( - fpType, cast(ptr->getType())->getAddressSpace())); - } -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - ptr = B.CreateGEP(fpType, ptr, offset); - - if (base->getType()->isIntegerTy()) { - ptr = B.CreatePtrToInt(ptr, base->getType()); - } else if (ptr->getType() != base->getType()) { - ptr = B.CreatePointerCast(ptr, base->getType()); - } - return ptr; -} - -void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType, - BlasInfo blas, bool byRef, llvm::Value *layout, - llvm::Value *islower, llvm::Value *A, llvm::Value *lda, - llvm::Value *N) { - - const bool cublasv2 = - blas.prefix == "cublas" && StringRef(blas.suffix).contains("v2"); - - const bool cublas = blas.prefix == "cublas"; - auto &M = *B.GetInsertBlock()->getParent()->getParent(); - - llvm::Type *intType = N->getType(); - // add spmv diag update call if not already present - auto fnc_name = "__enzyme_copy_lower_to_upper" + blas.floatType + - blas.prefix + blas.suffix; - - SmallVector tys = {islower->getType(), A->getType(), - lda->getType(), N->getType()}; - if (layout) - tys.insert(tys.begin(), layout->getType()); - auto ltuFT = FunctionType::get(B.getVoidTy(), tys, false); - - auto F0 = M.getOrInsertFunction(fnc_name, ltuFT); - - SmallVector args = {islower, A, lda, N}; - if (layout) - args.insert(args.begin(), layout); - auto C = B.CreateCall(F0, args); - auto F = getFunctionFromCall(C); - assert(F); - if (!F->empty()) { - return; - } - - // now add the implementation for the call - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - if (A->getType()->isPointerTy()) - addFunctionNoCapture(F, 1 + ((bool)layout)); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *loop = BasicBlock::Create(M.getContext(), "loop", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); - - auto arg = F->arg_begin(); - Argument *layoutarg = nullptr; - if (layout) { - layoutarg = arg; - layoutarg->setName("layout"); - arg++; - } - auto islowerarg = arg; - islowerarg->setName("islower"); - arg++; - auto Aarg = arg; - Aarg->setName("A"); - arg++; - auto ldaarg = arg; - ldaarg->setName("lda"); - arg++; - auto Narg = arg; - Narg->setName("N"); - - IRBuilder<> EB(entry); - - auto one = ConstantInt::get(intType, 1); - auto zero = ConstantInt::get(intType, 0); - - Value *N_minus_1 = EB.CreateSub(Narg, one); - - IRBuilder<> LB(loop); - - auto i = LB.CreatePHI(intType, 2); - i->addIncoming(zero, entry); - auto i_plus_one = LB.CreateAdd(i, one, "", true, true); - i->addIncoming(i_plus_one, loop); - - Value *copyArgs[] = { - 
to_blas_callconv(LB, LB.CreateSub(N_minus_1, i), byRef, cublas, nullptr, - EB), - lookup_with_layout(LB, fpType, layoutarg, Aarg, ldaarg, - CreateSelect(LB, islowerarg, i_plus_one, i), - CreateSelect(LB, islowerarg, i, i_plus_one)), - to_blas_callconv( - LB, - lookup_with_layout(LB, fpType, layoutarg, nullptr, ldaarg, - CreateSelect(LB, islowerarg, one, zero), - CreateSelect(LB, islowerarg, zero, one)), - byRef, cublas, nullptr, EB), - lookup_with_layout(LB, fpType, layoutarg, Aarg, ldaarg, - CreateSelect(LB, islowerarg, i, i_plus_one), - CreateSelect(LB, islowerarg, i_plus_one, i)), - to_blas_callconv( - LB, - lookup_with_layout(LB, fpType, layoutarg, nullptr, ldaarg, - CreateSelect(LB, islowerarg, zero, one), - CreateSelect(LB, islowerarg, one, zero)), - byRef, cublas, nullptr, EB)}; - - Type *copyTys[] = {copyArgs[0]->getType(), copyArgs[1]->getType(), - copyArgs[2]->getType(), copyArgs[3]->getType(), - copyArgs[4]->getType()}; - - FunctionType *FT = FunctionType::get(B.getVoidTy(), copyTys, false); - - auto copy_name = std::string(blas.prefix) + blas.floatType + "copy" + - (cublasv2 ? "" : blas.suffix); - - auto copyfn = M.getOrInsertFunction(copy_name, FT); - if (Function *copyF = dyn_cast(copyfn.getCallee())) - attributeKnownFunctions(*copyF); - LB.CreateCall(copyfn, copyArgs); - LB.CreateCondBr(LB.CreateICmpEQ(i_plus_one, N_minus_1), end, loop); - - EB.CreateCondBr(EB.CreateICmpSLE(N_minus_1, zero), end, loop); - { - IRBuilder<> B(end); - B.CreateRetVoid(); - } - - if (llvm::verifyFunction(*F, &llvm::errs())) { - llvm::errs() << *F << "\n"; - report_fatal_error("helper function failed verification"); - } -} - -void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, - llvm::ArrayRef args, - llvm::Type *copy_retty, - llvm::ArrayRef bundles) { - const bool cublasv2 = - blas.prefix == "cublas" && StringRef(blas.suffix).contains("v2"); - auto copy_name = std::string(blas.prefix) + blas.floatType + "copy" + - (cublasv2 ? 
"" : blas.suffix); - - SmallVector tys; - for (auto arg : args) - tys.push_back(arg->getType()); - - FunctionType *FT = FunctionType::get(copy_retty, tys, false); - auto fn = M.getOrInsertFunction(copy_name, FT); - Value *callVal = fn.getCallee(); - Function *called = nullptr; - while (!called) { - if (auto castinst = dyn_cast(callVal)) - if (castinst->isCast()) { - callVal = castinst->getOperand(0); - continue; - } - if (auto fn = dyn_cast(callVal)) { - called = fn; - break; - } - if (auto alias = dyn_cast(callVal)) { - callVal = alias->getAliasee(); - continue; - } - break; - } - attributeKnownFunctions(*called); - - B.CreateCall(fn, args, bundles); -} - -void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, - BlasInfo blas, llvm::ArrayRef args, - llvm::ArrayRef bundles) { - auto copy_name = - std::string(blas.prefix) + blas.floatType + "lacpy" + blas.suffix; - - SmallVector tys; - for (auto arg : args) - tys.push_back(arg->getType()); - - auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false); - auto fn = M.getOrInsertFunction(copy_name, FT); - if (auto F = GetFunctionFromValue(fn.getCallee())) - attributeKnownFunctions(*F); - - B.CreateCall(fn, args, bundles); -} - -void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas, - IntegerType *IT, Type *BlasCT, Type *BlasFPT, - Type *BlasPT, Type *BlasIT, Type *fpTy, - ArrayRef args, - ArrayRef bundles, bool byRef, - bool julia_decl) { - // add spmv diag update call if not already present - auto fnc_name = "__enzyme_spmv_diag" + blas.floatType + blas.suffix; - - // spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa) - auto FDiagUpdateT = FunctionType::get( - B.getVoidTy(), - {BlasCT, BlasIT, BlasFPT, BlasPT, BlasIT, BlasPT, BlasIT, BlasPT}, false); - Function *F = - cast(M.getOrInsertFunction(fnc_name, FDiagUpdateT).getCallee()); - - if (!F->empty()) { - B.CreateCall(F, args, bundles); - return; - } - - // now add the implementation for the call - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - if (!julia_decl) { - addFunctionNoCapture(F, 3); - addFunctionNoCapture(F, 5); - addFunctionNoCapture(F, 7); - F->addParamAttr(3, Attribute::NoAlias); - F->addParamAttr(5, Attribute::NoAlias); - F->addParamAttr(7, Attribute::NoAlias); - F->addParamAttr(3, Attribute::ReadOnly); - F->addParamAttr(5, Attribute::ReadOnly); - if (byRef) { - addFunctionNoCapture(F, 2); - F->addParamAttr(2, Attribute::NoAlias); - F->addParamAttr(2, Attribute::ReadOnly); - } - } - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *init = BasicBlock::Create(M.getContext(), "init", F); - BasicBlock *uper_code = BasicBlock::Create(M.getContext(), "uper", F); - BasicBlock *lower_code = BasicBlock::Create(M.getContext(), "lower", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); - - // spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa) - auto blasuplo = F->arg_begin(); - blasuplo->setName("blasuplo"); - auto blasn = blasuplo + 1; - blasn->setName("blasn"); - auto blasalpha = blasn + 1; - blasalpha->setName("blasalpha"); - auto blasx = blasalpha + 1; - blasx->setName("blasx"); - auto blasincx = blasx + 1; - blasincx->setName("blasincx"); - auto blasdy = blasx + 1; - blasdy->setName("blasdy"); - auto blasincy = blasdy + 1; - blasincy->setName("blasincy"); - auto blasdAP = 
blasincy + 1; - blasdAP->setName("blasdAP"); - - // TODO: consider cblas_layout - - // https://dl.acm.org/doi/pdf/10.1145/3382191 - // Following example is Fortran based, thus 1 indexed - // if(uplo == 'u' .or. uplo == 'U') then - // k = 0 - // do i = 1,n - // k = k+i - // APa(k) = APa(k) - alpha*x(1 + (i-1)*incx)*ya(1 + (i-1)*incy) - // end do - // else - // k = 1 - // do i = 1,n - // APa(k) = APa(k) - alpha*x(1 + (i-1)*incx)*ya(1 + (i-1)*incy) - // k = k+n-i+1 - // end do - // end if - { - IRBuilder<> B1(entry); - Value *n = load_if_ref(B1, IT, blasn, byRef); - Value *incx = load_if_ref(B1, IT, blasincx, byRef); - Value *incy = load_if_ref(B1, IT, blasincy, byRef); - Value *alpha = blasalpha; - if (byRef) { - auto VP = B1.CreatePointerCast( - blasalpha, - PointerType::get( - fpTy, - cast(blasalpha->getType())->getAddressSpace())); - alpha = B1.CreateLoad(fpTy, VP); - } - Value *is_l = is_lower(B1, blasuplo, byRef, /*cublas*/ false); - B1.CreateCondBr(B1.CreateICmpEQ(n, ConstantInt::get(IT, 0)), end, init); - - IRBuilder<> B2(init); - Value *xfloat = B2.CreatePointerCast( - blasx, - PointerType::get( - fpTy, cast(blasx->getType())->getAddressSpace())); - Value *dyfloat = B2.CreatePointerCast( - blasdy, - PointerType::get( - fpTy, cast(blasdy->getType())->getAddressSpace())); - Value *dAPfloat = B2.CreatePointerCast( - blasdAP, - PointerType::get( - fpTy, cast(blasdAP->getType())->getAddressSpace())); - B2.CreateCondBr(is_l, lower_code, uper_code); - - IRBuilder<> B3(uper_code); - B3.setFastMathFlags(getFast()); - { - PHINode *iter = B3.CreatePHI(IT, 2, "iteration"); - PHINode *kval = B3.CreatePHI(IT, 2, "k"); - iter->addIncoming(ConstantInt::get(IT, 0), init); - kval->addIncoming(ConstantInt::get(IT, 0), init); - Value *iternext = - B3.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next"); - // 0, 2, 5, 9, 14, 20, 27, 35, 44, 54, ... 
are diag elements - Value *kvalnext = B3.CreateAdd(kval, iternext, "k.next"); - iter->addIncoming(iternext, uper_code); - kval->addIncoming(kvalnext, uper_code); - - Value *xidx = B3.CreateNUWMul(iter, incx, "x.idx"); - Value *yidx = B3.CreateNUWMul(iter, incy, "y.idx"); - Value *x = B3.CreateInBoundsGEP(fpTy, xfloat, xidx, "x.ptr"); - Value *y = B3.CreateInBoundsGEP(fpTy, dyfloat, yidx, "y.ptr"); - Value *xval = B3.CreateLoad(fpTy, x, "x.val"); - Value *yval = B3.CreateLoad(fpTy, y, "y.val"); - Value *xy = B3.CreateFMul(xval, yval, "xy"); - Value *xyalpha = B3.CreateFMul(xy, alpha, "xy.alpha"); - Value *kptr = B3.CreateInBoundsGEP(fpTy, dAPfloat, kval, "k.ptr"); - Value *kvalloaded = B3.CreateLoad(fpTy, kptr, "k.val"); - Value *kvalnew = B3.CreateFSub(kvalloaded, xyalpha, "k.val.new"); - B3.CreateStore(kvalnew, kptr); - - B3.CreateCondBr(B3.CreateICmpEQ(iternext, n), end, uper_code); - } - - IRBuilder<> B4(lower_code); - B4.setFastMathFlags(getFast()); - { - PHINode *iter = B4.CreatePHI(IT, 2, "iteration"); - PHINode *kval = B4.CreatePHI(IT, 2, "k"); - iter->addIncoming(ConstantInt::get(IT, 0), init); - kval->addIncoming(ConstantInt::get(IT, 0), init); - Value *iternext = - B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next"); - Value *ktmp = B4.CreateAdd(n, ConstantInt::get(IT, 1), "tmp.val"); - Value *ktmp2 = B4.CreateSub(ktmp, iternext, "tmp.val.other"); - Value *kvalnext = B4.CreateAdd(kval, ktmp2, "k.next"); - iter->addIncoming(iternext, lower_code); - kval->addIncoming(kvalnext, lower_code); - - Value *xidx = B4.CreateNUWMul(iter, incx, "x.idx"); - Value *yidx = B4.CreateNUWMul(iter, incy, "y.idx"); - Value *x = B4.CreateInBoundsGEP(fpTy, xfloat, xidx, "x.ptr"); - Value *y = B4.CreateInBoundsGEP(fpTy, dyfloat, yidx, "y.ptr"); - Value *xval = B4.CreateLoad(fpTy, x, "x.val"); - Value *yval = B4.CreateLoad(fpTy, y, "y.val"); - Value *xy = B4.CreateFMul(xval, yval, "xy"); - Value *xyalpha = B4.CreateFMul(xy, alpha, "xy.alpha"); - Value *kptr = B4.CreateInBoundsGEP(fpTy, dAPfloat, kval, "k.ptr"); - Value *kvalloaded = B4.CreateLoad(fpTy, kptr, "k.val"); - Value *kvalnew = B4.CreateFSub(kvalloaded, xyalpha, "k.val.new"); - B4.CreateStore(kvalnew, kptr); - - B4.CreateCondBr(B4.CreateICmpEQ(iternext, n), end, lower_code); - } - - IRBuilder<> B5(end); - B5.CreateRetVoid(); - } - B.CreateCall(F, args, bundles); - return; -} - -llvm::CallInst * -getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, - IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy, - llvm::ArrayRef args, - const llvm::ArrayRef bundles, - bool byRef, bool cublas, bool julia_decl) { - assert(fpTy->isFloatingPointTy()); - - // add inner_prod call if not already present - std::string prod_name = "__enzyme_inner_prod" + blas.floatType + blas.suffix; - auto FInnerProdT = - FunctionType::get(fpTy, {BlasIT, BlasIT, BlasPT, BlasIT, BlasPT}, false); - Function *F = - cast(M.getOrInsertFunction(prod_name, FInnerProdT).getCallee()); - - if (!F->empty()) - return B.CreateCall(F, args, bundles); - - // add dot call if not already present - std::string dot_name = blas.prefix + blas.floatType + "dot" + blas.suffix; - auto FDotT = - FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false); - auto FDot = M.getOrInsertFunction(dot_name, FDotT); - if (auto F = GetFunctionFromValue(FDot.getCallee())) - attributeKnownFunctions(*F); - - // now add the implementation for the inner_prod call - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - 
F->setOnlyAccessesArgMemory(); - F->setOnlyReadsMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); - F->addFnAttr(Attribute::ReadOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - if (!julia_decl) { - addFunctionNoCapture(F, 2); - addFunctionNoCapture(F, 2); - F->addParamAttr(2, Attribute::NoAlias); - F->addParamAttr(4, Attribute::NoAlias); - F->addParamAttr(2, Attribute::ReadOnly); - F->addParamAttr(4, Attribute::ReadOnly); - } - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *init = BasicBlock::Create(M.getContext(), "init.idx", F); - BasicBlock *fastPath = BasicBlock::Create(M.getContext(), "fast.path", F); - BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); - - // This is the .td declaration which we need to match - // No need to support ld for the second matrix, as it will - // always be based on a matrix which we allocated (contiguous) - //(FrobInnerProd<> $m, $n, adj<"C">, $ldc, use<"AB">) - - auto blasm = F->arg_begin(); - blasm->setName("blasm"); - auto blasn = blasm + 1; - blasn->setName("blasn"); - auto matA = blasn + 1; - matA->setName("A"); - auto blaslda = matA + 1; - blaslda->setName("lda"); - auto matB = blaslda + 1; - matB->setName("B"); - - { - IRBuilder<> B1(entry); - Value *blasOne = to_blas_callconv(B1, ConstantInt::get(IT, 1), byRef, - cublas, nullptr, B1, "constant.one"); - - if (blasOne->getType() != BlasIT) - blasOne = B1.CreatePointerCast(blasOne, BlasIT, "intcast.constant.one"); - - Value *m = load_if_ref(B1, IT, blasm, byRef); - Value *n = load_if_ref(B1, IT, blasn, byRef); - Value *size = B1.CreateNUWMul(m, n, "mat.size"); - Value *blasSize = to_blas_callconv( - B1, size, byRef, cublas, julia_decl ? IT : nullptr, B1, "mat.size"); - - if (blasSize->getType() != BlasIT) - blasSize = B1.CreatePointerCast(blasSize, BlasIT, "intcast.mat.size"); - B1.CreateCondBr(B1.CreateICmpEQ(size, ConstantInt::get(IT, 0)), end, init); - - IRBuilder<> B2(init); - B2.setFastMathFlags(getFast()); - Value *lda = load_if_ref(B2, IT, blaslda, byRef); - Value *Afloat = B2.CreatePointerCast( - matA, PointerType::get( - fpTy, cast(matA->getType())->getAddressSpace())); - Value *Bfloat = B2.CreatePointerCast( - matB, PointerType::get( - fpTy, cast(matB->getType())->getAddressSpace())); - B2.CreateCondBr(B2.CreateICmpEQ(m, lda), fastPath, body); - - // our second matrix is always continuos, by construction. - // If our first matrix is continuous too (lda == m), then we can - // use a single dot call. 
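// (A plain-C++ sketch of the reduction this helper is intended to emit, for
//  clarity only; `inner_prod` and `dot` are hypothetical names standing in
//  for the generated function and the BLAS dot routine it calls:
//
//    double inner_prod(int m, int n, const double *A, int lda,
//                      const double *B) {
//      if (lda == m)                      // both operands contiguous:
//        return dot(m * n, A, 1, B, 1);   //   one dot over all m*n entries
//      double sum = 0.0;                  // otherwise one dot per column,
//      for (int j = 0; j < n; j++)        //   A strided by lda, B by m
//        sum += dot(m, A + (size_t)j * lda, 1, B + (size_t)j * m, 1);
//      return sum;
//    }
//  )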
- IRBuilder<> B3(fastPath); - B3.setFastMathFlags(getFast()); - Value *blasA = B3.CreatePointerCast(matA, BlasPT); - Value *blasB = B3.CreatePointerCast(matB, BlasPT); - Value *fastSum = - B3.CreateCall(FDot, {blasSize, blasA, blasOne, blasB, blasOne}); - B3.CreateBr(end); - - IRBuilder<> B4(body); - B4.setFastMathFlags(getFast()); - PHINode *Aidx = B4.CreatePHI(IT, 2, "Aidx"); - PHINode *Bidx = B4.CreatePHI(IT, 2, "Bidx"); - PHINode *iter = B4.CreatePHI(IT, 2, "iteration"); - PHINode *sum = B4.CreatePHI(fpTy, 2, "sum"); - Aidx->addIncoming(ConstantInt::get(IT, 0), init); - Bidx->addIncoming(ConstantInt::get(IT, 0), init); - iter->addIncoming(ConstantInt::get(IT, 0), init); - sum->addIncoming(ConstantFP::get(fpTy, 0.0), init); - - Value *Ai = B4.CreateInBoundsGEP(fpTy, Afloat, Aidx, "A.i"); - Value *Bi = B4.CreateInBoundsGEP(fpTy, Bfloat, Bidx, "B.i"); - Value *AiDot = B4.CreatePointerCast(Ai, BlasPT); - Value *BiDot = B4.CreatePointerCast(Bi, BlasPT); - Value *newDot = - B4.CreateCall(FDot, {blasm, AiDot, blasOne, BiDot, blasOne}); - - Value *Anext = B4.CreateNUWAdd(Aidx, lda, "Aidx.next"); - Value *Bnext = B4.CreateNUWAdd(Aidx, m, "Bidx.next"); - Value *iternext = B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next"); - Value *sumnext = B4.CreateFAdd(sum, newDot); - - iter->addIncoming(iternext, body); - Aidx->addIncoming(Anext, body); - Bidx->addIncoming(Bnext, body); - sum->addIncoming(sumnext, body); - - B4.CreateCondBr(B4.CreateICmpEQ(iter, n), end, body); - - IRBuilder<> B5(end); - PHINode *res = B5.CreatePHI(fpTy, 3, "res"); - res->addIncoming(ConstantFP::get(fpTy, 0.0), entry); - res->addIncoming(sum, body); - res->addIncoming(fastSum, fastPath); - B5.CreateRet(res); - } - - return B.CreateCall(F, args, bundles); -} - -Function *getOrInsertMemcpyStrided(Module &M, Type *elementType, PointerType *T, - Type *IT, unsigned dstalign, - unsigned srcalign) { - assert(elementType->isFloatingPointTy()); - std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_" + - std::to_string(cast(IT)->getBitWidth()) + - "_da" + std::to_string(dstalign) + "sa" + - std::to_string(srcalign) + "stride"; - FunctionType *FT = - FunctionType::get(Type::getVoidTy(M.getContext()), {T, T, IT, IT}, false); - - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - addFunctionNoCapture(F, 0); - F->addParamAttr(0, Attribute::NoAlias); - addFunctionNoCapture(F, 1); - F->addParamAttr(1, Attribute::NoAlias); - F->addParamAttr(0, Attribute::WriteOnly); - F->addParamAttr(1, Attribute::ReadOnly); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *init = BasicBlock::Create(M.getContext(), "init.idx", F); - BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); - - auto dst = F->arg_begin(); - dst->setName("dst"); - auto src = dst + 1; - src->setName("src"); - auto num = src + 1; - num->setName("num"); - auto stride = num + 1; - stride->setName("stride"); - - { - IRBuilder<> B(entry); - B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)), - end, init); - } - - { - IRBuilder<> B2(init); - B2.setFastMathFlags(getFast()); - Value *a = 
B2.CreateNSWSub(ConstantInt::get(num->getType(), 1), num, "a"); - Value *negidx = B2.CreateNSWMul(a, stride, "negidx"); - // Value *negidx = - // B2.CreateNSWAdd(b, ConstantInt::get(num->getType(), 1), - // "negidx"); - Value *isneg = - B2.CreateICmpSLT(stride, ConstantInt::get(num->getType(), 0), "is.neg"); - Value *startidx = B2.CreateSelect( - isneg, negidx, ConstantInt::get(num->getType(), 0), "startidx"); - B2.CreateBr(body); - //} - - //{ - IRBuilder<> B(body); - B.setFastMathFlags(getFast()); - PHINode *idx = B.CreatePHI(num->getType(), 2, "idx"); - PHINode *sidx = B.CreatePHI(num->getType(), 2, "sidx"); - idx->addIncoming(ConstantInt::get(num->getType(), 0), init); - sidx->addIncoming(startidx, init); - - Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx, "dst.i"); - Value *srci = B.CreateInBoundsGEP(elementType, src, sidx, "src.i"); - LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l"); - StoreInst *dsts = B.CreateStore(srcl, dsti); - - if (dstalign) { - dsts->setAlignment(Align(dstalign)); - } - if (srcalign) { - srcl->setAlignment(Align(srcalign)); - } - - Value *next = - B.CreateNSWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next"); - Value *snext = B.CreateNSWAdd(sidx, stride, "sidx.next"); - idx->addIncoming(next, body); - sidx->addIncoming(snext, body); - B.CreateCondBr(B.CreateICmpEQ(num, next), end, body); - } - - { - IRBuilder<> B(end); - B.CreateRetVoid(); - } - - return F; -} - -Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, - IntegerType *IT, unsigned dstalign, - unsigned srcalign) { - assert(elementType->isFPOrFPVectorTy()); -#if LLVM_VERSION_MAJOR < 17 -#if LLVM_VERSION_MAJOR >= 15 - if (Mod.getContext().supportsTypedPointers()) { -#endif -#if LLVM_VERSION_MAJOR >= 13 - if (!PT->isOpaquePointerTy()) -#endif - assert(PT->getPointerElementType() == elementType); -#if LLVM_VERSION_MAJOR >= 15 - } -#endif -#endif - std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_mat_" + - std::to_string(cast(IT)->getBitWidth()); - FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()), - {PT, PT, IT, IT, IT}, false); - - Function *F = cast(Mod.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - addFunctionNoCapture(F, 0); - F->addParamAttr(0, Attribute::NoAlias); - addFunctionNoCapture(F, 1); - F->addParamAttr(1, Attribute::NoAlias); - F->addParamAttr(0, Attribute::WriteOnly); - F->addParamAttr(1, Attribute::ReadOnly); - - BasicBlock *entry = BasicBlock::Create(F->getContext(), "entry", F); - BasicBlock *init = BasicBlock::Create(F->getContext(), "init.idx", F); - BasicBlock *body = BasicBlock::Create(F->getContext(), "for.body", F); - BasicBlock *initend = BasicBlock::Create(F->getContext(), "init.end", F); - BasicBlock *end = BasicBlock::Create(F->getContext(), "for.end", F); - - auto dst = F->arg_begin(); - dst->setName("dst"); - auto src = dst + 1; - src->setName("src"); - auto M = src + 1; - M->setName("M"); - auto N = M + 1; - N->setName("N"); - auto LDA = N + 1; - LDA->setName("LDA"); - - { - IRBuilder<> B(entry); - Value *l = B.CreateAdd(M, N, "mul", true, true); - // Don't copy a 0*0 matrix - B.CreateCondBr(B.CreateICmpEQ(l, ConstantInt::get(IT, 0)), end, init); - } - - PHINode *j; - { - IRBuilder<> B(init); - j = 
B.CreatePHI(IT, 2, "j"); - j->addIncoming(ConstantInt::get(IT, 0), entry); - B.CreateBr(body); - } - - { - IRBuilder<> B(body); - PHINode *i = B.CreatePHI(IT, 2, "i"); - i->addIncoming(ConstantInt::get(IT, 0), init); - - Value *dsti = B.CreateInBoundsGEP( - elementType, dst, - B.CreateAdd(i, B.CreateMul(j, M, "", true, true), "", true, true), - "dst.i"); - Value *srci = B.CreateInBoundsGEP( - elementType, src, - B.CreateAdd(i, B.CreateMul(j, LDA, "", true, true), "", true, true), - "dst.i"); - LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l"); - - StoreInst *dsts = B.CreateStore(srcl, dsti); - - if (dstalign) { - dsts->setAlignment(Align(dstalign)); - } - if (srcalign) { - srcl->setAlignment(Align(srcalign)); - } - - Value *nexti = - B.CreateAdd(i, ConstantInt::get(IT, 1), "i.next", true, true); - i->addIncoming(nexti, body); - B.CreateCondBr(B.CreateICmpEQ(nexti, M), initend, body); - } - - { - IRBuilder<> B(initend); - Value *nextj = - B.CreateAdd(j, ConstantInt::get(IT, 1), "j.next", true, true); - j->addIncoming(nextj, initend); - B.CreateCondBr(B.CreateICmpEQ(nextj, N), end, init); - } - - { - IRBuilder<> B(end); - B.CreateRetVoid(); - } - - return F; -} - -// TODO implement differential memmove -Function * -getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign, - unsigned srcalign, unsigned dstaddr, - unsigned srcaddr, unsigned bitwidth) { - if (EnzymeMemmoveWarning) - llvm::errs() - << "warning: didn't implement memmove, using memcpy as fallback " - "which can result in errors\n"; - return getOrInsertDifferentialFloatMemcpy(M, T, dstalign, srcalign, dstaddr, - srcaddr, bitwidth); -} - -Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty, - unsigned width) { - FunctionType *FreeTy = call->getFunctionType(); - Value *Free = call->getCalledOperand(); - AttributeList FreeAttributes = call->getAttributes(); - CallingConv::ID CallingConvention = call->getCallingConv(); - - std::string name = "__enzyme_checked_free_" + std::to_string(width); - - auto callname = getFuncNameFromCall(call); - if (callname != "free") - name += "_" + callname.str(); - - SmallVector types; - types.push_back(Ty); - for (unsigned i = 0; i < width; i++) { - types.push_back(Ty); - } -#if LLVM_VERSION_MAJOR >= 14 - for (size_t i = 1; i < call->arg_size(); i++) -#else - for (size_t i = 1; i < call->getNumArgOperands(); i++) -#endif - { - types.push_back(call->getArgOperand(i)->getType()); - } - - FunctionType *FT = - FunctionType::get(Type::getVoidTy(M.getContext()), types, false); - - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *free0 = BasicBlock::Create(M.getContext(), "free0", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F); - - IRBuilder<> EntryBuilder(entry); - IRBuilder<> Free0Builder(free0); - IRBuilder<> EndBuilder(end); - - auto primal = F->arg_begin(); - Argument *first_shadow = F->arg_begin() + 1; - addFunctionNoCapture(F, 0); - addFunctionNoCapture(F, 1); - - Value *isNotEqual = EntryBuilder.CreateICmpNE(primal, first_shadow); - EntryBuilder.CreateCondBr(isNotEqual, free0, end); - - SmallVector args = {first_shadow}; -#if LLVM_VERSION_MAJOR >= 14 - for 
(size_t i = 1; i < call->arg_size(); i++) -#else - for (size_t i = 1; i < call->getNumArgOperands(); i++) -#endif - { - args.push_back(F->arg_begin() + width + i); - } - - CallInst *CI = Free0Builder.CreateCall(FreeTy, Free, args); - CI->setAttributes(FreeAttributes); - CI->setCallingConv(CallingConvention); - - if (width > 1) { - Value *checkResult = nullptr; - BasicBlock *free1 = BasicBlock::Create(M.getContext(), "free1", F); - IRBuilder<> Free1Builder(free1); - - for (unsigned i = 0; i < width; i++) { - addFunctionNoCapture(F, i + 1); - Argument *shadow = F->arg_begin() + i + 1; - - if (i < width - 1) { - Argument *nextShadow = F->arg_begin() + i + 2; - Value *isNotEqual = Free0Builder.CreateICmpNE(shadow, nextShadow); - checkResult = checkResult - ? Free0Builder.CreateAnd(isNotEqual, checkResult) - : isNotEqual; - - args[0] = nextShadow; - CallInst *CI = Free1Builder.CreateCall(FreeTy, Free, args); - CI->setAttributes(FreeAttributes); - CI->setCallingConv(CallingConvention); - } - } - Free0Builder.CreateCondBr(checkResult, free1, end); - Free1Builder.CreateBr(end); - } else { - Free0Builder.CreateBr(end); - } - - EndBuilder.CreateRetVoid(); - - return F; -} - -/// Create function to computer nearest power of two -llvm::Value *nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V) { - assert(V->getType()->isIntegerTy()); - IntegerType *T = cast(V->getType()); - V = B.CreateAdd(V, ConstantInt::get(T, -1)); - for (size_t i = 1; i < T->getBitWidth(); i *= 2) { - V = B.CreateOr(V, B.CreateLShr(V, ConstantInt::get(T, i))); - } - V = B.CreateAdd(V, ConstantInt::get(T, 1)); - return V; -} - -llvm::Function *getOrInsertDifferentialWaitallSave(llvm::Module &M, - ArrayRef T, - PointerType *reqType) { - std::string name = "__enzyme_differential_waitall_save"; - FunctionType *FT = - FunctionType::get(PointerType::getUnqual(reqType), T, false); - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - - auto buff = F->arg_begin(); - buff->setName("count"); - Value *count = buff; - Value *req = buff + 1; - req->setName("req"); - Value *dreq = buff + 2; - dreq->setName("dreq"); - - IRBuilder<> B(entry); - count = B.CreateZExtOrTrunc(count, Type::getInt64Ty(entry->getContext())); - - auto ret = CreateAllocation(B, reqType, count); - - BasicBlock *loopBlock = BasicBlock::Create(M.getContext(), "loop", F); - BasicBlock *endBlock = BasicBlock::Create(M.getContext(), "end", F); - - B.CreateCondBr(B.CreateICmpEQ(count, ConstantInt::get(count->getType(), 0)), - endBlock, loopBlock); - - B.SetInsertPoint(loopBlock); - auto idx = B.CreatePHI(count->getType(), 2); - idx->addIncoming(ConstantInt::get(count->getType(), 0), entry); - auto inc = B.CreateAdd(idx, ConstantInt::get(count->getType(), 1)); - idx->addIncoming(inc, loopBlock); - - Type *reqT = reqType; // req->getType()->getPointerElementType(); - Value *idxs[] = {idx}; - Value *ireq = B.CreateInBoundsGEP(reqT, req, idxs); - Value *idreq = B.CreateInBoundsGEP(reqT, dreq, idxs); - Value *iout = B.CreateInBoundsGEP(reqType, ret, idxs); - Value *isNull = nullptr; - if (auto GV = M.getNamedValue("ompi_request_null")) { - Value *reql = - B.CreatePointerCast(ireq, PointerType::getUnqual(GV->getType())); - reql = B.CreateLoad(GV->getType(), reql); - isNull = B.CreateICmpEQ(reql, GV); - } - - idreq = 
B.CreatePointerCast(idreq, PointerType::getUnqual(reqType)); - Value *d_reqp = B.CreateLoad(reqType, idreq); - if (isNull) - d_reqp = B.CreateSelect(isNull, Constant::getNullValue(d_reqp->getType()), - d_reqp); - - B.CreateStore(d_reqp, iout); - - B.CreateCondBr(B.CreateICmpEQ(inc, count), endBlock, loopBlock); - - B.SetInsertPoint(endBlock); - B.CreateRet(ret); - return F; -} - -llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M, - ArrayRef T, - Type *reqType) { - llvm::SmallVector types(T.begin(), T.end()); - types.push_back(reqType); - std::string name = "__enzyme_differential_mpi_wait"; - FunctionType *FT = - FunctionType::get(Type::getVoidTy(M.getContext()), types, false); - Function *F = cast(M.getOrInsertFunction(name, FT).getCallee()); - - if (!F->empty()) - return F; - - F->setLinkage(Function::LinkageTypes::InternalLinkage); - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *isend = BasicBlock::Create(M.getContext(), "invertISend", F); - BasicBlock *irecv = BasicBlock::Create(M.getContext(), "invertIRecv", F); - -#if 0 - /*0 */getInt8PtrTy(call.getContext()) - /*1 */i64 - /*2 */getInt8PtrTy(call.getContext()) - /*3 */i64 - /*4 */i64 - /*5 */getInt8PtrTy(call.getContext()) - /*6 */Type::getInt8Ty(call.getContext()) -#endif - - auto buff = F->arg_begin(); - buff->setName("buf"); - Value *buf = buff; - Value *count = buff + 1; - count->setName("count"); - Value *datatype = buff + 2; - datatype->setName("datatype"); - Value *source = buff + 3; - source->setName("source"); - Value *tag = buff + 4; - tag->setName("tag"); - Value *comm = buff + 5; - comm->setName("comm"); - Value *fn = buff + 6; - fn->setName("fn"); - Value *d_req = buff + 7; - d_req->setName("d_req"); - - bool pmpi = true; - auto isendfn = M.getFunction("PMPI_Isend"); - if (!isendfn) { - isendfn = M.getFunction("MPI_Isend"); - pmpi = false; - } - assert(isendfn); - auto irecvfn = M.getFunction("PMPI_Irecv"); - if (!irecvfn) - irecvfn = M.getFunction("MPI_Irecv"); - if (!irecvfn) { - FunctionType *FuT = isendfn->getFunctionType(); - std::string name = pmpi ? 
"PMPI_Irecv" : "MPI_Irecv"; - irecvfn = cast(M.getOrInsertFunction(name, FuT).getCallee()); - } - assert(irecvfn); - - IRBuilder<> B(entry); - auto arg = isendfn->arg_begin(); - if (arg->getType()->isIntegerTy()) - buf = B.CreatePtrToInt(buf, arg->getType()); - arg++; - count = B.CreateZExtOrTrunc(count, arg->getType()); - arg++; - datatype = B.CreatePointerCast(datatype, arg->getType()); - arg++; - source = B.CreateZExtOrTrunc(source, arg->getType()); - arg++; - tag = B.CreateZExtOrTrunc(tag, arg->getType()); - arg++; - comm = B.CreatePointerCast(comm, arg->getType()); - arg++; - if (arg->getType()->isIntegerTy()) - d_req = B.CreatePtrToInt(d_req, arg->getType()); - Value *args[] = { - buf, count, datatype, source, tag, comm, d_req, - }; - - B.CreateCondBr(B.CreateICmpEQ(fn, ConstantInt::get(fn->getType(), - (int)MPI_CallType::ISEND)), - isend, irecv); - - { - B.SetInsertPoint(isend); - auto fcall = B.CreateCall(irecvfn, args); - fcall->setCallingConv(isendfn->getCallingConv()); - B.CreateRetVoid(); - } - - { - B.SetInsertPoint(irecv); - auto fcall = B.CreateCall(isendfn, args); - fcall->setCallingConv(isendfn->getCallingConv()); - B.CreateRetVoid(); - } - return F; -} - -llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, - llvm::Type *OpType, ConcreteType CT, - llvm::Type *intType, IRBuilder<> &B2) { - std::string name = "__enzyme_mpi_sum" + CT.str(); - assert(CT.isFloat()); - auto FlT = CT.isFloat(); - - if (auto Glob = M.getGlobalVariable(name)) { - return B2.CreateLoad(Glob->getValueType(), Glob); - } - - llvm::Type *types[] = {PointerType::getUnqual(FlT), - PointerType::getUnqual(FlT), - PointerType::getUnqual(intType), OpPtr}; - FunctionType *FuT = - FunctionType::get(Type::getVoidTy(M.getContext()), types, false); - Function *F = - cast(M.getOrInsertFunction(name + "_run", FuT).getCallee()); - - F->setLinkage(Function::LinkageTypes::InternalLinkage); -#if LLVM_VERSION_MAJOR >= 16 - F->setOnlyAccessesArgMemory(); -#else - F->addFnAttr(Attribute::ArgMemOnly); -#endif - F->addFnAttr(Attribute::NoUnwind); - F->addFnAttr(Attribute::AlwaysInline); - addFunctionNoCapture(F, 0); - F->addParamAttr(0, Attribute::ReadOnly); - addFunctionNoCapture(F, 1); - addFunctionNoCapture(F, 2); - F->addParamAttr(2, Attribute::ReadOnly); - addFunctionNoCapture(F, 3); - F->addParamAttr(3, Attribute::ReadNone); - - BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); - BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); - BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); - - auto src = F->arg_begin(); - src->setName("src"); - auto dst = src + 1; - dst->setName("dst"); - auto lenp = dst + 1; - lenp->setName("lenp"); - Value *len; - // TODO consider using datatype arg and asserting same size as assumed - // by type analysis - - { - IRBuilder<> B(entry); - len = B.CreateLoad(intType, lenp); - B.CreateCondBr(B.CreateICmpEQ(len, ConstantInt::get(len->getType(), 0)), - end, body); - } - - { - IRBuilder<> B(body); - B.setFastMathFlags(getFast()); - PHINode *idx = B.CreatePHI(len->getType(), 2, "idx"); - idx->addIncoming(ConstantInt::get(len->getType(), 0), entry); - - Value *dsti = B.CreateInBoundsGEP(FlT, dst, idx, "dst.i"); - LoadInst *dstl = B.CreateLoad(FlT, dsti, "dst.i.l"); - - Value *srci = B.CreateInBoundsGEP(FlT, src, idx, "src.i"); - LoadInst *srcl = B.CreateLoad(FlT, srci, "src.i.l"); - B.CreateStore(B.CreateFAdd(srcl, dstl), dsti); - - Value *next = - B.CreateNUWAdd(idx, ConstantInt::get(len->getType(), 1), "idx.next"); - 
idx->addIncoming(next, body); - B.CreateCondBr(B.CreateICmpEQ(len, next), end, body); - } - - { - IRBuilder<> B(end); - B.CreateRetVoid(); - } - - llvm::Type *rtypes[] = {getInt8PtrTy(M.getContext()), intType, OpPtr}; - FunctionType *RFT = FunctionType::get(intType, rtypes, false); - - Constant *RF = M.getNamedValue("MPI_Op_create"); - if (!RF) { - RF = - cast(M.getOrInsertFunction("MPI_Op_create", RFT).getCallee()); - } else { - RF = ConstantExpr::getBitCast(RF, PointerType::getUnqual(RFT)); - } - - GlobalVariable *GV = - new GlobalVariable(M, OpType, false, GlobalVariable::InternalLinkage, - UndefValue::get(OpType), name); - - Type *i1Ty = Type::getInt1Ty(M.getContext()); - GlobalVariable *initD = new GlobalVariable( - M, i1Ty, false, GlobalVariable::InternalLinkage, - ConstantInt::getFalse(M.getContext()), name + "_initd"); - - // Finish initializing mpi sum - // https://www.mpich.org/static/docs/v3.2/www3/MPI_Op_create.html - FunctionType *IFT = FunctionType::get(Type::getVoidTy(M.getContext()), - ArrayRef(), false); - Function *initializerFunction = cast( - M.getOrInsertFunction(name + "initializer", IFT).getCallee()); - - initializerFunction->setLinkage(Function::LinkageTypes::InternalLinkage); - initializerFunction->addFnAttr(Attribute::NoUnwind); - - { - BasicBlock *entry = - BasicBlock::Create(M.getContext(), "entry", initializerFunction); - BasicBlock *run = - BasicBlock::Create(M.getContext(), "run", initializerFunction); - BasicBlock *end = - BasicBlock::Create(M.getContext(), "end", initializerFunction); - IRBuilder<> B(entry); - - B.CreateCondBr(B.CreateLoad(initD->getValueType(), initD), end, run); - - B.SetInsertPoint(run); - Value *args[] = {ConstantExpr::getPointerCast(F, rtypes[0]), - ConstantInt::get(rtypes[1], 1, false), - ConstantExpr::getPointerCast(GV, rtypes[2])}; - B.CreateCall(RFT, RF, args); - B.CreateStore(ConstantInt::getTrue(M.getContext()), initD); - B.CreateBr(end); - B.SetInsertPoint(end); - B.CreateRetVoid(); - } - - B2.CreateCall(M.getFunction(name + "initializer")); - return B2.CreateLoad(GV->getValueType(), GV); -} - -void mayExecuteAfter(llvm::SmallVectorImpl &results, - llvm::Instruction *inst, - const llvm::SmallPtrSetImpl &stores, - const llvm::Loop *region) { - using namespace llvm; - std::map> maybeBlocks; - BasicBlock *instBlk = inst->getParent(); - for (auto store : stores) { - BasicBlock *storeBlk = store->getParent(); - if (instBlk == storeBlk) { - // if store doesn't come before, exit. - - if (store != inst) { - BasicBlock::const_iterator It = storeBlk->begin(); - for (; &*It != store && &*It != inst; ++It) - /*empty*/; - // if inst comes first (e.g. 
before store) in the - // block, return true - if (&*It == inst) { - results.push_back(store); - } - } - maybeBlocks[storeBlk].push_back(store); - } else { - maybeBlocks[storeBlk].push_back(store); - } - } - - if (maybeBlocks.size() == 0) - return; - - llvm::SmallVector todo; - for (auto B : successors(instBlk)) { - if (region && region->getHeader() == B) { - continue; - } - todo.push_back(B); - } - - SmallPtrSet seen; - while (todo.size()) { - auto cur = todo.back(); - todo.pop_back(); - if (seen.count(cur)) - continue; - seen.insert(cur); - auto found = maybeBlocks.find(cur); - if (found != maybeBlocks.end()) { - for (auto store : found->second) - results.push_back(store); - maybeBlocks.erase(found); - } - for (auto B : successors(cur)) { - if (region && region->getHeader() == B) { - continue; - } - todo.push_back(B); - } - } -} - -bool overwritesToMemoryReadByLoop( - llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, - llvm::Instruction *maybeReader, const llvm::SCEV *LoadStart, - const llvm::SCEV *LoadEnd, llvm::Instruction *maybeWriter, - const llvm::SCEV *StoreStart, const llvm::SCEV *StoreEnd, - llvm::Loop *scope) { - // The store may either occur directly after the load in the current loop - // nest, or prior to the load in a subsequent iteration of the loop nest - // Generally: - // L0 -> scope -> L1 -> L2 -> L3 -> load_L4 -> load_L5 ... Load - // \-> store_L4 -> store_L5 ... Store - // We begin by finding the common ancestor of the two loops, which may - // be none. - Loop *anc = getAncestor(LI.getLoopFor(maybeReader->getParent()), - LI.getLoopFor(maybeWriter->getParent())); - - // The surrounding scope must contain the ancestor - if (scope) { - assert(anc); - assert(scope == anc || scope->contains(anc)); - } - - // Consider the case where the load and store don't share any common loops. - // That is to say, there's no loops in [scope, ancestor) we need to consider - // having a store in a later iteration overwrite the load of a previous - // iteration. - // - // An example of this overwriting would be a "left shift" - // for (int j = 1; j visitedAncestors; - auto skipLoop = [&](const Loop *L) { - assert(L); - if (scope && L->contains(scope)) - return false; - - if (anc && (anc == L || anc->contains(L))) { - visitedAncestors.insert(L); - return true; - } - return false; - }; - - // Check the boounds of an [... endprev][startnext ...] for potential - // overlaps. The boolean EndIsStore is true of the EndPev represents - // the store and should have its loops expanded, or if that should - // apply to StartNed. - auto hasOverlap = [&](const SCEV *EndPrev, const SCEV *StartNext, - bool EndIsStore) { - for (auto slim = StartNext; slim != SE.getCouldNotCompute();) { - bool sskip = false; - if (!EndIsStore) - if (auto startL = dyn_cast(slim)) - if (skipLoop(startL->getLoop()) && - SE.isKnownNonPositive(startL->getStepRecurrence(SE))) { - sskip = true; - } - - if (!sskip) - for (auto elim = EndPrev; elim != SE.getCouldNotCompute();) { - { - - bool eskip = false; - if (EndIsStore) - if (auto endL = dyn_cast(elim)) { - if (skipLoop(endL->getLoop()) && - SE.isKnownNonNegative(endL->getStepRecurrence(SE))) { - eskip = true; - } - } - - // Moreover because otherwise SE cannot "groupScevByComplexity" - // we need to ensure that if both slim/elim are AddRecv - // they must be in the same loop, or one loop must dominate - // the other. 
- if (!eskip) { - - if (auto endL = dyn_cast(elim)) { - auto EH = endL->getLoop()->getHeader(); - if (auto startL = dyn_cast(slim)) { - auto SH = startL->getLoop()->getHeader(); - if (EH != SH && !DT.dominates(EH, SH) && - !DT.dominates(SH, EH)) - eskip = true; - } - } - } - if (!eskip) { - auto sub = SE.getMinusSCEV(slim, elim); - if (sub != SE.getCouldNotCompute() && SE.isKnownNonNegative(sub)) - return false; - } - } - - if (auto endL = dyn_cast(elim)) { - if (SE.isKnownNonPositive(endL->getStepRecurrence(SE))) { - elim = endL->getStart(); - continue; - } else if (SE.isKnownNonNegative(endL->getStepRecurrence(SE))) { -#if LLVM_VERSION_MAJOR >= 12 - auto ebd = SE.getSymbolicMaxBackedgeTakenCount(endL->getLoop()); -#else - auto ebd = SE.getBackedgeTakenCount(endL->getLoop()); -#endif - if (ebd == SE.getCouldNotCompute()) - break; - elim = endL->evaluateAtIteration(ebd, SE); - continue; - } - } - break; - } - - if (auto startL = dyn_cast(slim)) { - if (SE.isKnownNonNegative(startL->getStepRecurrence(SE))) { - slim = startL->getStart(); - continue; - } else if (SE.isKnownNonPositive(startL->getStepRecurrence(SE))) { -#if LLVM_VERSION_MAJOR >= 12 - auto sbd = SE.getSymbolicMaxBackedgeTakenCount(startL->getLoop()); -#else - auto sbd = SE.getBackedgeTakenCount(startL->getLoop()); -#endif - if (sbd == SE.getCouldNotCompute()) - break; - slim = startL->evaluateAtIteration(sbd, SE); - continue; - } - } - break; - } - return true; - }; - - // There is no overwrite if either the stores all occur before the loads - // [S, S+Size][start load, L+Size] - visitedAncestors.clear(); - if (!hasOverlap(StoreEnd, LoadStart, /*EndIsStore*/ true)) { - // We must have seen all common loops as induction variables - // to be legal, lest we have a repetition of the store. - bool legal = true; - for (const Loop *L = anc; anc != scope; anc = anc->getParentLoop()) { - if (!visitedAncestors.count(L)) - legal = false; - } - if (legal) - return false; - } - - // There is no overwrite if either the loads all occur before the stores - // [start load, L+Size] [S, S+Size] - visitedAncestors.clear(); - if (!hasOverlap(LoadEnd, StoreStart, /*EndIsStore*/ false)) { - // We must have seen all common loops as induction variables - // to be legal, lest we have a repetition of the store. 
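// (Worked example of the case this check guards against, restating the
//  "left shift" loop mentioned earlier:
//
//    for (int j = 1; j < n; j++)
//      a[j - 1] = a[j];
//
//  Here the store of iteration j+1 writes a[j], the very location read by the
//  load of iteration j, so comparing only the overall [start, end) ranges of
//  the load and the store is not enough. The range comparison is therefore
//  trusted only when every loop shared by the two accesses was peeled off as
//  an induction variable, which is what visitedAncestors records.)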
- bool legal = true; - for (const Loop *L = anc; anc != scope; anc = anc->getParentLoop()) { - if (!visitedAncestors.count(L)) - legal = false; - } - if (legal) - return false; - } - return true; -} - -bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, - llvm::TargetLibraryInfo &TLI, ScalarEvolution &SE, - llvm::LoopInfo &LI, llvm::DominatorTree &DT, - llvm::Instruction *maybeReader, - llvm::Instruction *maybeWriter, - llvm::Loop *scope) { - using namespace llvm; - if (!writesToMemoryReadBy(TR, AA, TLI, maybeReader, maybeWriter)) - return false; - const SCEV *LoadBegin = SE.getCouldNotCompute(); - const SCEV *LoadEnd = SE.getCouldNotCompute(); - - const SCEV *StoreBegin = SE.getCouldNotCompute(); - const SCEV *StoreEnd = SE.getCouldNotCompute(); - - Value *loadPtr = nullptr; - Value *storePtr = nullptr; - if (auto LI = dyn_cast(maybeReader)) { - loadPtr = LI->getPointerOperand(); - LoadBegin = SE.getSCEV(LI->getPointerOperand()); - if (LoadBegin != SE.getCouldNotCompute() && - !LoadBegin->getType()->isIntegerTy()) { - auto &DL = maybeWriter->getModule()->getDataLayout(); - auto width = cast(DL.getIndexType(LoadBegin->getType())) - ->getBitWidth(); -#if LLVM_VERSION_MAJOR >= 18 - auto TS = SE.getConstant( - APInt(width, (int64_t)DL.getTypeStoreSize(LI->getType()))); -#else - auto TS = SE.getConstant( - APInt(width, DL.getTypeStoreSize(LI->getType()).getFixedSize())); -#endif - LoadEnd = SE.getAddExpr(LoadBegin, TS); - } - } - if (auto SI = dyn_cast(maybeWriter)) { - storePtr = SI->getPointerOperand(); - StoreBegin = SE.getSCEV(SI->getPointerOperand()); - if (StoreBegin != SE.getCouldNotCompute() && - !StoreBegin->getType()->isIntegerTy()) { - auto &DL = maybeWriter->getModule()->getDataLayout(); - auto width = cast(DL.getIndexType(StoreBegin->getType())) - ->getBitWidth(); -#if LLVM_VERSION_MAJOR >= 18 - auto TS = - SE.getConstant(APInt(width, (int64_t)DL.getTypeStoreSize( - SI->getValueOperand()->getType()))); -#else - auto TS = SE.getConstant( - APInt(width, DL.getTypeStoreSize(SI->getValueOperand()->getType()) - .getFixedSize())); -#endif - StoreEnd = SE.getAddExpr(StoreBegin, TS); - } - } - if (auto MS = dyn_cast(maybeWriter)) { - storePtr = MS->getArgOperand(0); - StoreBegin = SE.getSCEV(MS->getArgOperand(0)); - if (StoreBegin != SE.getCouldNotCompute() && - !StoreBegin->getType()->isIntegerTy()) { - if (auto Len = dyn_cast(MS->getArgOperand(2))) { - auto &DL = MS->getModule()->getDataLayout(); - auto width = cast(DL.getIndexType(StoreBegin->getType())) - ->getBitWidth(); - auto TS = - SE.getConstant(APInt(width, Len->getValue().getLimitedValue())); - StoreEnd = SE.getAddExpr(StoreBegin, TS); - } - } - } - if (auto MS = dyn_cast(maybeWriter)) { - storePtr = MS->getArgOperand(0); - StoreBegin = SE.getSCEV(MS->getArgOperand(0)); - if (StoreBegin != SE.getCouldNotCompute() && - !StoreBegin->getType()->isIntegerTy()) { - if (auto Len = dyn_cast(MS->getArgOperand(2))) { - auto &DL = MS->getModule()->getDataLayout(); - auto width = cast(DL.getIndexType(StoreBegin->getType())) - ->getBitWidth(); - auto TS = - SE.getConstant(APInt(width, Len->getValue().getLimitedValue())); - StoreEnd = SE.getAddExpr(StoreBegin, TS); - } - } - } - if (auto MS = dyn_cast(maybeReader)) { - loadPtr = MS->getArgOperand(1); - LoadBegin = SE.getSCEV(MS->getArgOperand(1)); - if (LoadBegin != SE.getCouldNotCompute() && - !LoadBegin->getType()->isIntegerTy()) { - if (auto Len = dyn_cast(MS->getArgOperand(2))) { - auto &DL = MS->getModule()->getDataLayout(); - auto width = 
cast(DL.getIndexType(LoadBegin->getType())) - ->getBitWidth(); - auto TS = - SE.getConstant(APInt(width, Len->getValue().getLimitedValue())); - LoadEnd = SE.getAddExpr(LoadBegin, TS); - } - } - } - - if (loadPtr && storePtr) - if (auto alias = - arePointersGuaranteedNoAlias(TLI, AA, LI, loadPtr, storePtr, true)) - if (*alias) - return false; - - if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd, - maybeWriter, StoreBegin, StoreEnd, scope)) - return false; - - return true; -} - -/// Return whether maybeReader can read from memory written to by maybeWriter -bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, - llvm::TargetLibraryInfo &TLI, - llvm::Instruction *maybeReader, - llvm::Instruction *maybeWriter) { - assert(maybeReader->getParent()->getParent() == - maybeWriter->getParent()->getParent()); - using namespace llvm; - if (isa(maybeReader)) - return false; - if (auto call = dyn_cast(maybeWriter)) { - StringRef funcName = getFuncNameFromCall(call); - - if (isDebugFunction(call->getCalledFunction())) - return false; - - if (isCertainPrint(funcName) || isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - - if (isMemFreeLibMFunction(funcName)) { - return false; - } - if (funcName == "jl_array_copy" || funcName == "ijl_array_copy") - return false; - - if (funcName == "jl_genericmemory_copy_slice" || - funcName == "ijl_genericmemory_copy_slice") - return false; - - if (funcName == "jl_new_array" || funcName == "ijl_new_array") - return false; - - if (funcName == "julia.safepoint") - return false; - - if (funcName == "jl_idtable_rehash" || funcName == "ijl_idtable_rehash") - return false; - - // Isend only writes to inaccessible mem only - if (funcName == "MPI_Send" || funcName == "PMPI_Send") { - return false; - } - // Wait only overwrites memory in the status and request. - if (funcName == "MPI_Wait" || funcName == "PMPI_Wait" || - funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") { -#if LLVM_VERSION_MAJOR > 11 - auto loc = LocationSize::afterPointer(); -#else - auto loc = MemoryLocation::UnknownSize; -#endif - size_t off = (funcName == "MPI_Wait" || funcName == "PMPI_Wait") ? 0 : 1; - // No alias with status - if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(off + 1), - loc))) { - // No alias with request - if (!isRefSet(AA.getModRefInfo(maybeReader, - call->getArgOperand(off + 0), loc))) - return false; - auto R = parseTBAA( - *maybeReader, - maybeReader->getParent()->getParent()->getParent()->getDataLayout(), - nullptr)[{-1}]; - // Could still conflict with the mpi_request unless a non pointer - // type. - if (R != BaseType::Unknown && R != BaseType::Anything && - R != BaseType::Pointer) - return false; - } - } - // Isend only writes to inaccessible mem and request. - if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") { - auto R = parseTBAA( - *maybeReader, - maybeReader->getParent()->getParent()->getParent()->getDataLayout(), - nullptr)[{-1}]; - // Could still conflict with the mpi_request, unless either - // synchronous, or a non pointer type. 
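// (Example of the disambiguation above, with hypothetical argument names: for
//
//    MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, &req);
//
//  the only user-visible write is to `req`, argument 6. A reader whose TBAA
//  type is a known plain scalar such as double cannot be reading that
//  MPI_Request, so it is not reported as conflicting; only readers that may
//  hold a pointer fall through to the request-aliasing query.)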
- if (R != BaseType::Unknown && R != BaseType::Anything && - R != BaseType::Pointer) - return false; -#if LLVM_VERSION_MAJOR > 11 - if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), - LocationSize::afterPointer()))) - return false; -#else - if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), - MemoryLocation::UnknownSize))) - return false; -#endif - return false; - } - if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv" || - funcName == "MPI_Recv" || funcName == "PMPI_Recv") { - ConcreteType type(BaseType::Unknown); - if (Constant *C = dyn_cast(call->getArgOperand(2))) { - while (ConstantExpr *CE = dyn_cast(C)) { - C = CE->getOperand(0); - } - if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { - type = ConcreteType(Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { - type = ConcreteType(Type::getFloatTy(C->getContext())); - } - } - } - if (type.isKnown()) { - auto R = parseTBAA( - *maybeReader, - maybeReader->getParent()->getParent()->getParent()->getDataLayout(), - nullptr)[{-1}]; - if (R.isKnown() && type != R) { - // Could still conflict with the mpi_request, unless either - // synchronous, or a non pointer type. - if (funcName == "MPI_Recv" || funcName == "PMPI_Recv" || - (R != BaseType::Anything && R != BaseType::Pointer)) - return false; -#if LLVM_VERSION_MAJOR > 11 - if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), - LocationSize::afterPointer()))) - return false; -#else - if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), - MemoryLocation::UnknownSize))) - return false; -#endif - } - } - } - if (auto II = dyn_cast(call)) { - if (II->getIntrinsicID() == Intrinsic::stacksave) - return false; - if (II->getIntrinsicID() == Intrinsic::stackrestore) - return false; - if (II->getIntrinsicID() == Intrinsic::trap) - return false; -#if LLVM_VERSION_MAJOR >= 13 - if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl) - return false; -#endif - } - - if (auto iasm = dyn_cast(call->getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("exit")) - return false; - } - } - if (auto call = dyn_cast(maybeReader)) { - StringRef funcName = getFuncNameFromCall(call); - - if (isDebugFunction(call->getCalledFunction())) - return false; - - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - - if (isMemFreeLibMFunction(funcName)) { - return false; - } - - if (auto II = dyn_cast(call)) { - if (II->getIntrinsicID() == Intrinsic::stacksave) - return false; - if (II->getIntrinsicID() == Intrinsic::stackrestore) - return false; - if (II->getIntrinsicID() == Intrinsic::trap) - return false; -#if LLVM_VERSION_MAJOR >= 13 - if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl) - return false; -#endif - } - } - if (auto call = dyn_cast(maybeWriter)) { - StringRef funcName = getFuncNameFromCall(call); - - if (isDebugFunction(call->getCalledFunction())) - return false; - - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - - if (isMemFreeLibMFunction(funcName)) { - return false; - } - if (funcName == "jl_array_copy" || funcName == "ijl_array_copy") - return false; - - if (funcName == "jl_genericmemory_copy_slice" || - funcName == "ijl_genericmemory_copy_slice") - return false; - - if (funcName == "jl_idtable_rehash" || funcName == "ijl_idtable_rehash") - return false; - - if (auto iasm = dyn_cast(call->getCalledOperand())) { - 
if (StringRef(iasm->getAsmString()).contains("exit")) - return false; - } - } - if (auto call = dyn_cast(maybeReader)) { - StringRef funcName = getFuncNameFromCall(call); - - if (isDebugFunction(call->getCalledFunction())) - return false; - - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - - if (isMemFreeLibMFunction(funcName)) { - return false; - } - } - assert(maybeWriter->mayWriteToMemory()); - assert(maybeReader->mayReadFromMemory()); - - if (auto li = dyn_cast(maybeReader)) { - if (TR) { - auto TT = TR->query(li)[{-1}]; - if (TT != BaseType::Unknown && TT != BaseType::Anything) { - if (auto si = dyn_cast(maybeWriter)) { - auto TT2 = TR->query(si->getValueOperand())[{-1}]; - if (TT2 != BaseType::Unknown && TT2 != BaseType::Anything) { - if (TT != TT2) - return false; - } - auto &dl = li->getParent()->getParent()->getParent()->getDataLayout(); - auto len = - (dl.getTypeSizeInBits(si->getValueOperand()->getType()) + 7) / 8; - TT2 = TR->query(si->getPointerOperand()).Lookup(len, dl)[{-1}]; - if (TT2 != BaseType::Unknown && TT2 != BaseType::Anything) { - if (TT != TT2) - return false; - } - } - } - } - return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(li))); - } - if (auto rmw = dyn_cast(maybeReader)) { - return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(rmw))); - } - if (auto xch = dyn_cast(maybeReader)) { - return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(xch))); - } - if (auto mti = dyn_cast(maybeReader)) { - return isModSet( - AA.getModRefInfo(maybeWriter, MemoryLocation::getForSource(mti))); - } - - if (auto si = dyn_cast(maybeWriter)) { - return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(si))); - } - if (auto rmw = dyn_cast(maybeWriter)) { - return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(rmw))); - } - if (auto xch = dyn_cast(maybeWriter)) { - return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(xch))); - } - if (auto mti = dyn_cast(maybeWriter)) { - return isRefSet( - AA.getModRefInfo(maybeReader, MemoryLocation::getForDest(mti))); - } - - if (auto cb = dyn_cast(maybeReader)) { - return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb)); - } - if (auto cb = dyn_cast(maybeReader)) { - return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb)); - } - llvm::errs() << " maybeReader: " << *maybeReader - << " maybeWriter: " << *maybeWriter << "\n"; - llvm_unreachable("unknown inst2"); -} - -// Find the base pointer of ptr and the offset in bytes from the start of -// the returned base pointer to this value. 
-AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) {
-  offset = 0;
-  while (true) {
-    if (auto CI = dyn_cast<CastInst>(ptr)) {
-      ptr = CI->getOperand(0);
-      continue;
-    }
-    if (auto CI = dyn_cast<GetElementPtrInst>(ptr)) {
-      auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout();
-#if LLVM_VERSION_MAJOR >= 20
-      SmallMapVector<Value *, APInt, 4> VariableOffsets;
-#else
-      MapVector<Value *, APInt> VariableOffsets;
-#endif
-      auto width = sizeof(size_t) * 8;
-      APInt Offset(width, 0);
-      bool success = collectOffset(cast<GEPOperator>(CI), DL, width,
-                                   VariableOffsets, Offset);
-      if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) {
-        return nullptr;
-      }
-      offset += Offset.getZExtValue();
-      ptr = CI->getOperand(0);
-      continue;
-    }
-    if (isa<AllocaInst>(ptr)) {
-      break;
-    }
-    if (auto LI = dyn_cast<LoadInst>(ptr)) {
-      if (auto S = simplifyLoad(LI)) {
-        ptr = S;
-        continue;
-      }
-    }
-    return nullptr;
-  }
-  return cast<AllocaInst>(ptr);
-}
-
-// Find all user instructions of AI, returning tuples of <instruction, value
-// used, offset>. Unlike a simple get users, this will recurse through any
-// constant gep offsets and casts
-SmallVector<std::tuple<Instruction *, Value *, size_t>, 1>
-findAllUsersOf(Value *AI) {
-  SmallVector<std::pair<Value *, size_t>, 1> todo;
-  todo.emplace_back(AI, 0);
-
-  SmallVector<std::tuple<Instruction *, Value *, size_t>, 1> users;
-  while (todo.size()) {
-    auto pair = todo.pop_back_val();
-    Value *ptr = pair.first;
-    size_t suboff = pair.second;
-
-    for (auto U : ptr->users()) {
-      if (auto CI = dyn_cast<CastInst>(U)) {
-        todo.emplace_back(CI, suboff);
-        continue;
-      }
-      if (auto CI = dyn_cast<GetElementPtrInst>(U)) {
-        auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout();
-#if LLVM_VERSION_MAJOR >= 20
-        SmallMapVector<Value *, APInt, 4> VariableOffsets;
-#else
-        MapVector<Value *, APInt> VariableOffsets;
-#endif
-        auto width = sizeof(size_t) * 8;
-        APInt Offset(width, 0);
-        bool success = collectOffset(cast<GEPOperator>(CI), DL, width,
-                                     VariableOffsets, Offset);
-
-        if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) {
-          users.emplace_back(cast<Instruction>(U), ptr, suboff);
-          continue;
-        }
-        todo.emplace_back(CI, suboff + Offset.getZExtValue());
-        continue;
-      }
-      users.emplace_back(cast<Instruction>(U), ptr, suboff);
-      continue;
-    }
-  }
-  return users;
-}
-
-// Given a pointer, find all values of size `valSz` which could be loaded from
-// that pointer when indexed at offset.
If it is impossible to guarantee that -// the set contains all such values, set legal to false -SmallVector, 1> -getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz, - bool &legal) { - SmallVector, 1> options; - - auto todo = findAllUsersOf(ptr0); - std::set> seen; - - while (todo.size()) { - auto pair = todo.pop_back_val(); - if (seen.count(pair)) - continue; - seen.insert(pair); - Instruction *U = std::get<0>(pair); - Value *ptr = std::get<1>(pair); - size_t suboff = std::get<2>(pair); - - // Read only users do not set the memory inside of ptr - if (isa(U)) { - continue; - } - if (auto MTI = dyn_cast(U)) - if (MTI->getOperand(0) != ptr) { - continue; - } - if (auto I = dyn_cast(U)) { - if (!I->mayWriteToMemory() && I->getType()->isVoidTy()) - continue; - } - - if (auto SI = dyn_cast(U)) { - auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout(); - - // We are storing into the ptr - if (SI->getPointerOperand() == ptr) { - auto storeSz = - (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) / - 8; - // If store is before the load would start - if (storeSz + suboff <= offset) - continue; - // if store starts after load would start - if (offset + valSz <= suboff) - continue; - - if (valSz <= storeSz) { - assert(offset >= suboff); - options.emplace_back(SI->getValueOperand(), offset - suboff); - continue; - } - } - - // We capture our pointer of interest, if it is stored into an alloca, - // all loads of said alloca would potentially store into. - if (SI->getValueOperand() == ptr) { - if (suboff == 0) { - size_t mid_offset = 0; - if (auto AI2 = - getBaseAndOffset(SI->getPointerOperand(), mid_offset)) { - bool sublegal = true; - auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8; - auto subPtrs = - getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal); - if (!sublegal) { - legal = false; - return options; - } - for (auto &&[subPtr, subOff] : subPtrs) { - if (subOff != 0) - return options; - for (const auto &pair3 : findAllUsersOf(subPtr)) { - todo.emplace_back(std::move(pair3)); - } - } - continue; - } - } - } - } - - if (auto II = dyn_cast(U)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) - continue; - } - - // If we copy into the ptr at a location that includes the offset, consider - // all sub uses - if (auto MTI = dyn_cast(U)) { - if (auto CI = dyn_cast(MTI->getLength())) { - if (MTI->getOperand(0) == ptr) { - auto storeSz = CI->getValue(); - - // If store is before the load would start - if ((storeSz + suboff).ule(offset)) - continue; - - // if store starts after load would start - if (offset + valSz <= suboff) - continue; - - if (suboff == 0 && CI->getValue().uge(offset + valSz)) { - size_t midoffset = 0; - auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset); - if (!AI2) { - legal = false; - return options; - } - if (midoffset != 0) { - legal = false; - return options; - } - for (const auto &pair3 : findAllUsersOf(AI2)) { - todo.emplace_back(std::move(pair3)); - } - continue; - } - } - } - } - - legal = false; - return options; - } - - return options; -} - -// Perform mem2reg/sroa to identify the innermost value being represented. 
-Value *simplifyLoad(Value *V, size_t valSz, size_t preOffset) { - if (auto LI = dyn_cast(V)) { - if (valSz == 0) { - auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); - valSz = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8; - } - - Value *ptr = LI->getPointerOperand(); - size_t offset = 0; - - if (auto ptr2 = simplifyLoad(ptr)) { - ptr = ptr2; - } - auto AI = getBaseAndOffset(ptr, offset); - if (!AI) { - return nullptr; - } - offset += preOffset; - - bool legal = true; - auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); - - if (!legal) { - return nullptr; - } - std::set res; - for (auto &&[opt, startOff] : opts) { - Value *v2 = simplifyLoad(opt, valSz, startOff); - if (v2) - res.insert(v2); - else - res.insert(opt); - } - if (res.size() != 1) { - return nullptr; - } - Value *retval = *res.begin(); - return retval; - } - if (auto EVI = dyn_cast(V)) { - IRBuilder<> B(EVI); - auto em = - GradientUtils::extractMeta(B, EVI->getAggregateOperand(), - EVI->getIndices(), "", /*fallback*/ false); - if (em != nullptr) { - if (auto SL2 = simplifyLoad(em, valSz)) - em = SL2; - return em; - } - if (auto LI = dyn_cast(EVI->getAggregateOperand())) { - auto offset = preOffset; - - auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); - SmallVector vec; - vec.push_back(ConstantInt::get(Type::getInt64Ty(EVI->getContext()), 0)); - for (auto ind : EVI->getIndices()) { - vec.push_back( - ConstantInt::get(Type::getInt32Ty(EVI->getContext()), ind)); - } - auto ud = UndefValue::get( - PointerType::getUnqual(EVI->getOperand(0)->getType())); - auto g2 = - GetElementPtrInst::Create(EVI->getOperand(0)->getType(), ud, vec); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - offset += (size_t)ai.getLimitedValue(); - - if (valSz == 0) { - auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); - valSz = (DL.getTypeSizeInBits(EVI->getType()) + 7) / 8; - } - return simplifyLoad(LI, valSz, offset); - } - } - return nullptr; -} - -Value *GetFunctionValFromValue(Value *fn) { - while (!isa(fn)) { - if (auto ci = dyn_cast(fn)) { - fn = ci->getOperand(0); - continue; - } - if (auto ci = dyn_cast(fn)) { - if (ci->isCast()) { - fn = ci->getOperand(0); - continue; - } - } - if (auto ci = dyn_cast(fn)) { - fn = ci->getFunction(); - continue; - } - if (auto *GA = dyn_cast(fn)) { - fn = GA->getAliasee(); - continue; - } - if (auto *Call = dyn_cast(fn)) { - if (auto F = Call->getCalledFunction()) { - SmallPtrSet ret; - for (auto &BB : *F) { - if (auto RI = dyn_cast(BB.getTerminator())) { - ret.insert(RI->getReturnValue()); - } - } - if (ret.size() == 1) { - auto val = *ret.begin(); - val = GetFunctionValFromValue(val); - if (isa(val)) { - fn = val; - continue; - } - if (auto arg = dyn_cast(val)) { - fn = Call->getArgOperand(arg->getArgNo()); - continue; - } - } - } - } - if (auto *Call = dyn_cast(fn)) { - if (auto F = Call->getCalledFunction()) { - SmallPtrSet ret; - for (auto &BB : *F) { - if (auto RI = dyn_cast(BB.getTerminator())) { - ret.insert(RI->getReturnValue()); - } - } - if (ret.size() == 1) { - auto val = *ret.begin(); - while (isa(val)) { - auto v2 = simplifyLoad(val); - if (v2) { - val = v2; - continue; - } - break; - } - if (isa(val)) { - fn = val; - continue; - } - if (auto arg = dyn_cast(val)) { - fn = Call->getArgOperand(arg->getArgNo()); - continue; - } - } - } - } - if (auto S = simplifyLoad(fn)) { - fn = 
S; - continue; - } - break; - } - - return fn; -} - -Function *GetFunctionFromValue(Value *fn) { - return dyn_cast(GetFunctionValFromValue(fn)); -} - -#if LLVM_VERSION_MAJOR >= 16 -std::optional extractBLAS(llvm::StringRef in) -#else -llvm::Optional extractBLAS(llvm::StringRef in) -#endif -{ - const char *extractable[] = { - "dot", "scal", "axpy", "gemv", "gemm", "spmv", "syrk", "nrm2", - "trmm", "trmv", "symm", "potrf", "potrs", "copy", "spmv", "syr2k", - "potrs", "getrf", "getrs", "trtrs", "getri", "symv", - }; - const char *floatType[] = {"s", "d", "c", "z"}; - const char *prefixes[] = {"" /*Fortran*/, "cblas_"}; - const char *suffixes[] = {"", "_", "64_", "_64_"}; - for (auto t : floatType) { - for (auto f : extractable) { - for (auto p : prefixes) { - for (auto s : suffixes) { - if (in == (Twine(p) + t + f + s).str()) { - bool is64 = llvm::StringRef(s).contains("64"); - return BlasInfo{ - t, p, s, f, is64, - }; - } - } - } - } - } - // c interface to cublas - const char *cuCFloatType[] = {"S", "D", "C", "Z"}; - const char *cuFFloatType[] = {"s", "d", "c", "z"}; - const char *cuCPrefixes[] = {"cublas"}; - const char *cuSuffixes[] = {"", "_v2", "_64", "_v2_64"}; - for (auto t : llvm::enumerate(cuCFloatType)) { - for (auto f : extractable) { - for (auto p : cuCPrefixes) { - for (auto s : cuSuffixes) { - if (in == (Twine(p) + t.value() + f + s).str()) { - bool is64 = llvm::StringRef(s).contains("64"); - return BlasInfo{ - t.value(), p, s, f, is64, - }; - } - } - } - } - } - // Fortran interface to cublas - const char *cuFPrefixes[] = {"cublas_"}; - for (auto t : cuFFloatType) { - for (auto f : extractable) { - for (auto p : cuFPrefixes) { - if (in == (Twine(p) + t + f).str()) { - return BlasInfo{ - t, p, "", f, false, - }; - } - } - } - } - return {}; -} - -llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T, - bool forceZero) { - if (EnzymeUndefinedValueForType) - return cast( - unwrap(EnzymeUndefinedValueForType(wrap(&M), wrap(T), forceZero))); - else if (EnzymeZeroCache || forceZero) - return Constant::getNullValue(T); - else - return UndefValue::get(T); -} - -llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, - llvm::IRBuilder<> &BuilderM, - llvm::Value *mask) { - if (EnzymeSanitizeDerivatives) - return unwrap(EnzymeSanitizeDerivatives(wrap(val), wrap(toset), - wrap(&BuilderM), wrap(mask))); - return toset; -} - -llvm::FastMathFlags getFast() { - llvm::FastMathFlags f; - if (EnzymeFastMath) - f.set(); - return f; -} - -void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, - llvm::SmallVectorImpl &cacheValues, - llvm::IRBuilder<> &BuilderZ, const Twine &name) { - if (!cache_arg) - return; - if (!arg->getType()->isPointerTy()) { - assert(arg->getType() == ty); - cacheValues.push_back(arg); - return; - } -#if LLVM_VERSION_MAJOR < 17 - auto PT = cast(arg->getType()); -#if LLVM_VERSION_MAJOR <= 14 - if (PT->getElementType() != ty) - arg = BuilderZ.CreatePointerCast( - arg, PointerType::get(ty, PT->getAddressSpace()), "pcld." + name); -#else - auto PT2 = PointerType::get(ty, PT->getAddressSpace()); - if (!PT->isOpaqueOrPointeeTypeMatches(PT2)) - arg = BuilderZ.CreatePointerCast( - arg, PointerType::get(ty, PT->getAddressSpace()), "pcld." + name); -#endif -#endif - arg = BuilderZ.CreateLoad(ty, arg, "avld." 
+ name); - cacheValues.push_back(arg); -} - -// julia_decl null means not julia decl, otherwise it is the integer type needed -// to cast to -llvm::Value *to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, - bool cublas, IntegerType *julia_decl, - IRBuilder<> &entryBuilder, - llvm::Twine const &name) { - if (!byRef) - return V; - - Value *allocV = - entryBuilder.CreateAlloca(V->getType(), nullptr, "byref." + name); - B.CreateStore(V, allocV); - - if (julia_decl) - allocV = B.CreatePointerCast(allocV, getInt8PtrTy(V->getContext()), - "intcast." + name); - - return allocV; -} -llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, - Type *fpTy, IRBuilder<> &entryBuilder, - llvm::Twine const &name) { - if (!byRef) - return V; - - Value *allocV = - entryBuilder.CreateAlloca(V->getType(), nullptr, "byref." + name); - B.CreateStore(V, allocV); - - if (fpTy) - allocV = B.CreatePointerCast(allocV, fpTy, "fpcast." + name); - - return allocV; -} - -Value *is_lower(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas) { - if (cublas) { - Value *isNormal = nullptr; - isNormal = B.CreateICmpEQ( - uplo, ConstantInt::get(uplo->getType(), - /*cublasFillMode_t::CUBLAS_FILL_MODE_LOWER*/ 0)); - return isNormal; - } - if (auto CI = dyn_cast(uplo)) { - if (CI->getValue() == 'L' || CI->getValue() == 'l') - return ConstantInt::getTrue(B.getContext()); - if (CI->getValue() == 'U' || CI->getValue() == 'u') - return ConstantInt::getFalse(B.getContext()); - } - if (byRef) { - // can't inspect opaque ptr, so assume 8 (Julia) - IntegerType *charTy = IntegerType::get(uplo->getContext(), 8); - uplo = B.CreateLoad(charTy, uplo, "loaded.trans"); - - auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L')); - auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l')); - // fortran blas - return B.CreateOr(isl, isL); - } else { - // we can inspect scalars - auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 122)); - // TODO we really should just return capi, but for sake of consistency, - // we will accept either here. - auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L')); - auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l')); - return B.CreateOr(capi, B.CreateOr(isl, isL)); - } -} - -Value *is_nonunit(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas) { - if (cublas) { - Value *isNormal = nullptr; - isNormal = - B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), - /*CUBLAS_DIAG_NON_UNIT*/ 0)); - return isNormal; - } - if (auto CI = dyn_cast(uplo)) { - if (CI->getValue() == 'N' || CI->getValue() == 'n') - return ConstantInt::getTrue(B.getContext()); - if (CI->getValue() == 'U' || CI->getValue() == 'u') - return ConstantInt::getFalse(B.getContext()); - } - if (byRef) { - // can't inspect opaque ptr, so assume 8 (Julia) - IntegerType *charTy = IntegerType::get(uplo->getContext(), 8); - uplo = B.CreateLoad(charTy, uplo, "loaded.nonunit"); - - auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'N')); - auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'n')); - // fortran blas - return B.CreateOr(isl, isL); - } else { - // we can inspect scalars - auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 131)); - // TODO we really should just return capi, but for sake of consistency, - // we will accept either here. 
- auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'N')); - auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'n')); - return B.CreateOr(capi, B.CreateOr(isl, isL)); - } -} - -llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, - bool cublas) { - if (cublas) { - Value *isNormal = nullptr; - isNormal = B.CreateICmpEQ( - trans, ConstantInt::get(trans->getType(), - /*cublasOperation_t::CUBLAS_OP_N*/ 0)); - return isNormal; - } - // Explicitly support 'N' always, since we use in the rule infra - if (auto CI = dyn_cast(trans)) { - if (CI->getValue() == 'N' || CI->getValue() == 'n') - return ConstantInt::getTrue( - B.getContext()); //(Type::getInt1Ty(B.getContext()), true); - } - if (byRef) { - // can't inspect opaque ptr, so assume 8 (Julia) - IntegerType *charTy = IntegerType::get(trans->getContext(), 8); - trans = B.CreateLoad(charTy, trans, "loaded.trans"); - - auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')); - auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')); - // fortran blas - return B.CreateOr(isn, isN); - } else { - // TODO we really should just return capi, but for sake of consistency, - // we will accept either here. - // we can inspect scalars - auto capi = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111)); - auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')); - auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')); - // fortran blas - return B.CreateOr(capi, B.CreateOr(isn, isN)); - } -} - -llvm::Value *is_left(IRBuilder<> &B, llvm::Value *side, bool byRef, - bool cublas) { - if (cublas) { - Value *isNormal = nullptr; - isNormal = B.CreateICmpEQ( - side, ConstantInt::get(side->getType(), - /*cublasSideMode_t::CUBLAS_SIDE_LEFT*/ 0)); - return isNormal; - } - // Explicitly support 'L'/'R' always, since we use in the rule infra - if (auto CI = dyn_cast(side)) { - if (CI->getValue() == 'L' || CI->getValue() == 'l') - return ConstantInt::getTrue(B.getContext()); - if (CI->getValue() == 'R' || CI->getValue() == 'r') - return ConstantInt::getFalse(B.getContext()); - } - if (byRef) { - // can't inspect opaque ptr, so assume 8 (Julia) - IntegerType *charTy = IntegerType::get(side->getContext(), 8); - side = B.CreateLoad(charTy, side, "loaded.side"); - - auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'L')); - auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'l')); - // fortran blas - return B.CreateOr(isl, isL); - } else { - // TODO we really should just return capi, but for sake of consistency, - // we will accept either here. - // we can inspect scalars - auto capi = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 141)); - auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'L')); - auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'l')); - // fortran blas - return B.CreateOr(capi, B.CreateOr(isl, isL)); - } -} - -// Ok. Here we are. -// netlib declares trans args as something out of -// N,n,T,t,C,c, represented as 8 bit chars. 
-// However, if we ask openBlas c ABI, -// it is one of the following 32 bit integers values: -// enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; -llvm::Value *transpose(std::string floatType, IRBuilder<> &B, llvm::Value *V, - bool cublas) { - llvm::Type *T = V->getType(); - if (cublas) { - auto isT1 = B.CreateICmpEQ(V, ConstantInt::get(T, 1)); - auto isT0 = B.CreateICmpEQ(V, ConstantInt::get(T, 0)); - return B.CreateSelect(isT1, ConstantInt::get(V->getType(), 0), - B.CreateSelect(isT0, - ConstantInt::get(V->getType(), 1), - ConstantInt::get(V->getType(), 42))); - } else if (T->isIntegerTy(8)) { - if (floatType == "z" || floatType == "c") { - auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n')); - auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(), 'c'), - ConstantInt::get(V->getType(), 0)); - - auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N')); - auto sel2 = - B.CreateSelect(isN, ConstantInt::get(V->getType(), 'C'), sel1); - - auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'c')); - auto sel3 = - B.CreateSelect(ist, ConstantInt::get(V->getType(), 'n'), sel2); - - auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'C')); - return B.CreateSelect(isT, ConstantInt::get(V->getType(), 'N'), sel3); - } else { - // the base case here of 'C' or 'c' becomes simply 'N' - auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n')); - auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(), 't'), - ConstantInt::get(V->getType(), 'N')); - - auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N')); - auto sel2 = - B.CreateSelect(isN, ConstantInt::get(V->getType(), 'T'), sel1); - - auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't')); - auto sel3 = - B.CreateSelect(ist, ConstantInt::get(V->getType(), 'n'), sel2); - - auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T')); - return B.CreateSelect(isT, ConstantInt::get(V->getType(), 'N'), sel3); - } - - } else if (T->isIntegerTy(32)) { - auto is111 = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)); - auto sel1 = B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)), - ConstantInt::get(V->getType(), 111), ConstantInt::get(V->getType(), 0)); - return B.CreateSelect(is111, ConstantInt::get(V->getType(), 112), sel1); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot handle unknown trans blas value\n" << V; - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), nullptr, ErrorType::NoDerivative, - nullptr, nullptr, nullptr); - } else { - EmitFailure("unknown trans blas value", B.getCurrentDebugLocation(), - B.GetInsertBlock()->getParent(), ss.str()); - } - return V; - } -} - -// Implement the following logic to get the width of a matrix -// if (cache_A) { -// ld_A = (arg_transa == 'N') ? 
arg_k : arg_m; -// } else { -// ld_A = arg_lda; -// } -llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, - llvm::ArrayRef trans, - llvm::Value *arg_ld, llvm::Value *dim1, - llvm::Value *dim2, bool cacheMat, bool byRef, - bool cublas) { - if (!cacheMat) - return arg_ld; - - assert(trans.size() == 1); - - llvm::Value *width = - CreateSelect(B, is_normal(B, trans[0], byRef, cublas), dim2, dim1); - - return width; -} - -llvm::Value *transpose(std::string floatType, llvm::IRBuilder<> &B, - llvm::Value *V, bool byRef, bool cublas, - llvm::IntegerType *julia_decl, - llvm::IRBuilder<> &entryBuilder, - const llvm::Twine &name) { - - if (!byRef) { - // Explicitly support 'N' always, since we use in the rule infra - if (auto CI = dyn_cast(V)) { - if (floatType == "c" || floatType == "z") { - if (CI->getValue() == 'N') - return ConstantInt::get(CI->getType(), 'C'); - if (CI->getValue() == 'c') - return ConstantInt::get(CI->getType(), 'c'); - } else { - if (CI->getValue() == 'N') - return ConstantInt::get(CI->getType(), 'T'); - if (CI->getValue() == 'n') - return ConstantInt::get(CI->getType(), 't'); - } - } - - // cblas - if (!cublas) - return B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)), - ConstantInt::get(V->getType(), 112), - ConstantInt::get(V->getType(), 111)); - } - - if (byRef) { - auto charType = IntegerType::get(V->getContext(), 8); - V = B.CreateLoad(charType, V, "ld." + name); - } - - V = transpose(floatType, B, V, cublas); - - return to_blas_callconv(B, V, byRef, cublas, julia_decl, entryBuilder, - "transpose." + name); -} - -llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType, - llvm::Value *V, bool byRef) { - if (!byRef) - return V; - - if (V->getType()->isIntegerTy()) - V = B.CreateIntToPtr(V, PointerType::getUnqual(intType)); - else - V = B.CreatePointerCast( - V, PointerType::get( - intType, cast(V->getType())->getAddressSpace())); - return B.CreateLoad(intType, V); -} - -SmallVector get_blas_row(llvm::IRBuilder<> &B, - ArrayRef transA, - bool byRef, bool cublas) { - assert(transA.size() == 1); - auto trans = transA[0]; - if (byRef) { - auto charType = IntegerType::get(trans->getContext(), 8); - trans = B.CreateLoad(charType, trans, "ld.row.trans"); - } - - Value *cond = nullptr; - if (!cublas) { - - if (!byRef) { - cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111)); - } else { - auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')); - auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')); - cond = B.CreateOr(isN, isn); - } - } else { - // CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2 - // TODO: verify - cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0)); - } - return {cond}; -} -SmallVector get_blas_row(llvm::IRBuilder<> &B, - ArrayRef transA, - ArrayRef row, - ArrayRef col, - bool byRef, bool cublas) { - auto conds = get_blas_row(B, transA, byRef, cublas); - assert(row.size() == col.size()); - SmallVector toreturn; - for (size_t i = 0; i < row.size(); i++) { - auto lhs = row[i]; - auto rhs = col[i]; - if (lhs->getType() != rhs->getType()) - rhs = B.CreatePointerCast(rhs, lhs->getType()); - toreturn.push_back(B.CreateSelect(conds[0], lhs, rhs)); - } - return toreturn; -} - -// return how many Special pointers are in T (count > 0), -// and if there is anything else in T (all == false) -CountTrackedPointers::CountTrackedPointers(Type *T) { - if (isa(T)) { - if (isSpecialPtr(T)) { - count++; - if (T->getPointerAddressSpace() != AddressSpace::Tracked) - derived 
= true; - } - } else if (isa(T) || isa(T) || isa(T)) { - for (Type *ElT : T->subtypes()) { - auto sub = CountTrackedPointers(ElT); - count += sub.count; - all &= sub.all; - derived |= sub.derived; - } - if (isa(T)) - count *= cast(T)->getNumElements(); - else if (isa(T)) { -#if LLVM_VERSION_MAJOR >= 12 - count *= cast(T)->getElementCount().getKnownMinValue(); -#else - count *= cast(T)->getNumElements(); -#endif - } - } - if (count == 0) - all = false; -} - -#if LLVM_VERSION_MAJOR >= 20 -bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, - SmallMapVector &VariableOffsets, - APInt &ConstantOffset) -#else -bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, - MapVector &VariableOffsets, - APInt &ConstantOffset) -#endif -{ -#if LLVM_VERSION_MAJOR >= 13 - return gep->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset); -#else - assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) && - "The offset bit width does not match DL specification."); - - auto CollectConstantOffset = [&](APInt Index, uint64_t Size) { - Index = Index.sextOrTrunc(BitWidth); - APInt IndexedSize = APInt(BitWidth, Size); - ConstantOffset += Index * IndexedSize; - }; - - for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep); - GTI != GTE; ++GTI) { - // Scalable vectors are multiplied by a runtime constant. - bool ScalableType = isa(GTI.getIndexedType()); - - Value *V = GTI.getOperand(); - StructType *STy = GTI.getStructTypeOrNull(); - // Handle ConstantInt if possible. - if (auto ConstOffset = dyn_cast(V)) { - if (ConstOffset->isZero()) - continue; - // If the type is scalable and the constant is not zero (vscale * n * 0 = - // 0) bailout. - // TODO: If the runtime value is accessible at any point before DWARF - // emission, then we could potentially keep a forward reference to it - // in the debug value to be filled in later. - if (ScalableType) - return false; - // Handle a struct index, which adds its field offset to the pointer. - if (STy) { - unsigned ElementIdx = ConstOffset->getZExtValue(); - const StructLayout *SL = DL.getStructLayout(STy); - // Element offset is in bytes. - CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)), - 1); - continue; - } - CollectConstantOffset(ConstOffset->getValue(), - DL.getTypeAllocSize(GTI.getIndexedType())); - continue; - } - - if (STy || ScalableType) - return false; - APInt IndexedSize = - APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType())); - // Insert an initial offset of 0 for V iff none exists already, then - // increment the offset by IndexedSize. 
- if (IndexedSize != 0) { - VariableOffsets.insert({V, APInt(BitWidth, 0)}); - VariableOffsets[V] += IndexedSize; - } - } - return true; -#endif -} - -llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, - llvm::Intrinsic::ID ID, llvm::Type *RetTy, - llvm::ArrayRef Args, - llvm::Instruction *FMFSource, - const llvm::Twine &Name) { -#if LLVM_VERSION_MAJOR >= 16 - llvm::CallInst *nres = B.CreateIntrinsic(RetTy, ID, Args, FMFSource, Name); -#else - SmallVector Table; - Intrinsic::getIntrinsicInfoTableEntries(ID, Table); - ArrayRef TableRef(Table); - - SmallVector ArgTys; - ArgTys.reserve(Args.size()); - for (auto &I : Args) - ArgTys.push_back(I->getType()); - FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false); - SmallVector OverloadTys; - Intrinsic::MatchIntrinsicTypesResult Res = - matchIntrinsicSignature(FTy, TableRef, OverloadTys); - (void)Res; - assert(Res == Intrinsic::MatchIntrinsicTypes_Match && TableRef.empty() && - "Wrong types for intrinsic!"); - Function *Fn = Intrinsic::getDeclaration(B.GetInsertPoint()->getModule(), ID, - OverloadTys); - CallInst *nres = B.CreateCall(Fn, Args, {}, Name); - if (FMFSource) - nres->copyFastMathFlags(FMFSource); -#endif - return nres; -} - -/* Bithack to compute 1 ulp as follows: -double ulp(double res) { - double nres = res; - (*(uint64_t*)&nres) = 0x1 ^ *(uint64_t*)&nres; - return abs(nres - res); -} -*/ -llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res) { - auto ty = res->getType(); - unsigned tsize = builder.GetInsertBlock() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeSizeInBits(ty); - - auto ity = IntegerType::get(ty->getContext(), tsize); - - auto as_int = builder.CreateBitCast(res, ity); - auto masked = builder.CreateXor(as_int, ConstantInt::get(ity, 1)); - auto neighbor = builder.CreateBitCast(masked, ty); - - auto diff = builder.CreateFSub(res, neighbor); - - auto absres = builder.CreateIntrinsic(Intrinsic::fabs, - ArrayRef(diff->getType()), - ArrayRef(diff)); - - return absres; -} - -llvm::Value *EmitNoDerivativeError(const std::string &message, - llvm::Instruction &inst, - GradientUtils *gutils, - llvm::IRBuilder<> &Builder2, - llvm::Value *condition) { - if (CustomErrorHandler) { - return unwrap(CustomErrorHandler(message.c_str(), wrap(&inst), - ErrorType::NoDerivative, gutils, - wrap(condition), wrap(&Builder2))); - } else if (EnzymeRuntimeError) { - auto &M = *inst.getParent()->getParent()->getParent(); - FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()), - {getInt8PtrTy(M.getContext())}, false); - auto msg = getString(M, message); - auto PutsF = M.getOrInsertFunction("puts", FT); - Builder2.CreateCall(PutsF, msg); - - FunctionType *FT2 = - FunctionType::get(Type::getVoidTy(M.getContext()), - {Type::getInt32Ty(M.getContext())}, false); - - auto ExitF = M.getOrInsertFunction("exit", FT2); - Builder2.CreateCall(ExitF, - ConstantInt::get(Type::getInt32Ty(M.getContext()), 1)); - return nullptr; - } else { - if (StringRef(message).contains("cannot handle above cast")) { - gutils->TR.dump(); - } - EmitFailure("NoDerivative", inst.getDebugLoc(), &inst, message); - return nullptr; - } -} - -bool EmitNoDerivativeError(const std::string &message, Value *todiff, - RequestContext &context) { - Value *toshow = todiff; - if (context.req) { - toshow = context.req; - } - if (CustomErrorHandler) { - CustomErrorHandler(message.c_str(), wrap(toshow), ErrorType::NoDerivative, - nullptr, wrap(todiff), wrap(context.ip)); - return true; - } else if (context.ip && EnzymeRuntimeError) { - 
auto &M = *context.ip->GetInsertBlock()->getParent()->getParent(); - FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()), - {getInt8PtrTy(M.getContext())}, false); - auto msg = getString(M, message); - auto PutsF = M.getOrInsertFunction("puts", FT); - context.ip->CreateCall(PutsF, msg); - - FunctionType *FT2 = - FunctionType::get(Type::getVoidTy(M.getContext()), - {Type::getInt32Ty(M.getContext())}, false); - - auto ExitF = M.getOrInsertFunction("exit", FT2); - context.ip->CreateCall( - ExitF, ConstantInt::get(Type::getInt32Ty(M.getContext()), 1)); - return true; - } else if (context.req) { - EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, - message); - return true; - } else if (auto arg = dyn_cast(todiff)) { - auto loc = arg->getDebugLoc(); - EmitFailure("NoDerivative", loc, arg, message); - return true; - } - return false; -} - -void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, - GradientUtils *gutils, llvm::IRBuilder<> &Builder2) { - if (CustomErrorHandler) { - CustomErrorHandler(message.c_str(), wrap(&inst), ErrorType::NoType, - gutils->TR.analyzer, nullptr, wrap(&Builder2)); - } else if (EnzymeRuntimeError) { - auto &M = *inst.getParent()->getParent()->getParent(); - FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()), - {getInt8PtrTy(M.getContext())}, false); - auto msg = getString(M, message); - auto PutsF = M.getOrInsertFunction("puts", FT); - Builder2.CreateCall(PutsF, msg); - - FunctionType *FT2 = - FunctionType::get(Type::getVoidTy(M.getContext()), - {Type::getInt32Ty(M.getContext())}, false); - - auto ExitF = M.getOrInsertFunction("exit", FT2); - Builder2.CreateCall(ExitF, - ConstantInt::get(Type::getInt32Ty(M.getContext()), 1)); - } else { - std::string str; - raw_string_ostream ss(str); - ss << message << "\n"; - gutils->TR.dump(ss); - EmitFailure("CannotDeduceType", inst.getDebugLoc(), &inst, ss.str()); - } -} - -std::vector> -parseTrueType(const llvm::MDNode *md, DerivativeMode Mode, bool const_src) { - std::vector> parsed; - for (size_t i = 0; i < md->getNumOperands(); i += 2) { - ConcreteType base( - llvm::cast(md->getOperand(i))->getString(), - md->getContext()); - auto size = llvm::cast( - llvm::cast(md->getOperand(i + 1)) - ->getValue()) - ->getSExtValue(); - parsed.emplace_back(base, size); - } - - std::vector> toIterate; - size_t idx = 0; - while (idx < parsed.size()) { - - auto dt = parsed[idx].first; - size_t start = parsed[idx].second; - size_t end = 0x0fffffff; - for (idx = idx + 1; idx < parsed.size(); ++idx) { - bool Legal = true; - auto tmp = dt; - auto next = parsed[idx].first; - tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal); - // Prevent fusion of {Anything, Float} since anything is an int rule - // but float requires zeroing. - if ((dt == BaseType::Anything && - (next != BaseType::Anything && next.isKnown())) || - (next == BaseType::Anything && - (dt != BaseType::Anything && dt.isKnown()))) - Legal = false; - if (!Legal) { - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - // if both are floats (of any type), forward mode is the same. 
- // + [potentially zero if const, otherwise copy] - // if both are int/pointer (of any type), also the same - // + copy - // if known non-constant, also the same - // + copy - if ((parsed[idx].first.isFloat() == nullptr) == - (parsed[idx - 1].first.isFloat() == nullptr)) { - Legal = true; - } - if (const_src) { - Legal = true; - } - } - if (!Legal) { - end = parsed[idx].second; - break; - } - } else - dt = tmp; - } - assert(dt.isKnown()); - toIterate.emplace_back(dt.isFloat(), start, end - start); - } - return toIterate; -} - -void dumpModule(llvm::Module *mod) { llvm::errs() << *mod << "\n"; } - -void dumpValue(llvm::Value *val) { llvm::errs() << *val << "\n"; } - -void dumpBlock(llvm::BasicBlock *blk) { llvm::errs() << *blk << "\n"; } - -void dumpType(llvm::Type *ty) { llvm::errs() << *ty << "\n"; } - -void dumpTypeResults(TypeResults &TR) { TR.dump(); } - -bool isNVLoad(const llvm::Value *V) { - auto II = dyn_cast(V); - if (!II) - return false; - switch (II->getIntrinsicID()) { - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: -#if LLVM_VERSION_MAJOR < 20 - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: -#endif - return true; - default: - return false; - } - return false; -} - -bool notCapturedBefore(llvm::Value *V, Instruction *inst, - size_t checkLoadCaptures) { - Instruction *VI = dyn_cast(V); - if (!VI) - VI = &*inst->getParent()->getParent()->getEntryBlock().begin(); - else - VI = VI->getNextNode(); - SmallPtrSet regionBetween; - { - SmallVector todo; - todo.push_back(VI->getParent()); - while (todo.size()) { - auto cur = todo.pop_back_val(); - if (regionBetween.count(cur)) - continue; - regionBetween.insert(cur); - if (cur == inst->getParent()) - continue; - for (auto BB : successors(cur)) - todo.push_back(BB); - } - } - SmallVector, 1> todo; - for (auto U : V->users()) { - todo.emplace_back(cast(U), checkLoadCaptures, V); - } - std::set> seen; - while (todo.size()) { - auto pair = todo.pop_back_val(); - if (seen.count(pair)) - continue; - auto UI = std::get<0>(pair); - auto level = std::get<1>(pair); - auto prev = std::get<2>(pair); - if (!regionBetween.count(UI->getParent())) - continue; - if (UI->getParent() == VI->getParent()) { - if (UI->comesBefore(VI)) - continue; - } - if (UI->getParent() == inst->getParent()) - if (inst->comesBefore(UI)) - continue; - - if (isPointerArithmeticInst(UI, /*includephi*/ true, - /*includebin*/ true)) { - for (auto U2 : UI->users()) { - auto UI2 = cast(U2); - todo.emplace_back(UI2, level, UI); - } - continue; - } - - if (isa(UI)) - continue; - - if (isa(UI)) { - if (level == 0) - continue; - if (UI->getOperand(1) != prev) - continue; - } - - if (auto CI = dyn_cast(UI)) { -#if LLVM_VERSION_MAJOR >= 14 - for (size_t i = 0, size = CI->arg_size(); i < size; i++) -#else - for (size_t i = 0, size = CI->getNumArgOperands(); i < size; i++) -#endif - { - if (prev == CI->getArgOperand(i)) { - if (isNoCapture(CI, i) && level == 0) - continue; - return false; - } - } - return true; - } - - if (isa(UI)) { - continue; - } - if (isa(UI)) { - if (level) { - for (auto U2 : UI->users()) { - auto UI2 = cast(U2); - todo.emplace_back(UI2, level - 1, UI); - } - } - continue; - } - // storing into it. 
- if (auto SI = dyn_cast(UI)) { - if (SI->getValueOperand() != prev) { - continue; - } - } - return false; - } - return true; -} - -// Return true if guaranteed not to alias -// Return false if guaranteed to alias [with possible offset depending on flag]. -// Return {} if no information is given. -#if LLVM_VERSION_MAJOR >= 16 -std::optional -#else -llvm::Optional -#endif -arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, - llvm::LoopInfo &LI, llvm::Value *op0, - llvm::Value *op1, bool offsetAllowed) { - auto lhs = getBaseObject(op0, offsetAllowed); - auto rhs = getBaseObject(op1, offsetAllowed); - - if (lhs == rhs) { - return false; - } - if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy()) - return {}; - - bool noalias_lhs = isNoAlias(lhs); - bool noalias_rhs = isNoAlias(rhs); - - bool noalias[2] = {noalias_lhs, noalias_rhs}; - - for (int i = 0; i < 2; i++) { - Value *start = (i == 0) ? lhs : rhs; - Value *end = (i == 0) ? rhs : lhs; - if (noalias[i]) { - if (noalias[1 - i]) { - return true; - } - if (isa(end)) { - return true; - } - if (auto endi = dyn_cast(end)) { - if (notCapturedBefore(start, endi, 0)) { - return true; - } - } - } - if (auto ld = dyn_cast(start)) { - auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false); - if (isAllocationCall(base, TLI)) { - if (isa(end)) - return true; - if (auto endi = dyn_cast(end)) - if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) { - Instruction *starti = dyn_cast(start); - if (!starti) { - if (!isa(start)) - continue; - starti = - &cast(start)->getParent()->getEntryBlock().front(); - } - - bool overwritten = false; - allInstructionsBetween( - LI, starti, endi, [&](Instruction *I) -> bool { - if (!I->mayWriteToMemory()) - return /*earlyBreak*/ false; - - if (writesToMemoryReadBy(nullptr, AA, TLI, - /*maybeReader*/ ld, - /*maybeWriter*/ I)) { - overwritten = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - - if (!overwritten) { - return true; - } - } - } - } - } - - return {}; -} - -bool isTargetNVPTX(llvm::Module &M) { -#if LLVM_VERSION_MAJOR > 20 - return M.getTargetTriple().getArch() == Triple::ArchType::nvptx || - M.getTargetTriple().getArch() == Triple::ArchType::nvptx64; -#else - return M.getTargetTriple().find("nvptx") != std::string::npos; -#endif -} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 8538f55b1916..c5a8ca2d6415 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -66,8 +66,6 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "TypeAnalysis/ConcreteType.h" - class TypeResults; namespace llvm { @@ -907,54 +905,6 @@ allUnsyncdPredecessorsOf(llvm::Instruction *inst, llvm::function_ref f, llvm::function_ref preEntry) { - for (auto uinst = inst->getPrevNode(); uinst != nullptr; - uinst = uinst->getPrevNode()) { - if (auto II = llvm::dyn_cast(uinst)) { - if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0 || - II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) { - return; - } - } - if (f(uinst)) - return; - } - - std::deque todo; - std::set done; - for (auto suc : llvm::predecessors(inst->getParent())) { - todo.push_back(suc); - } - while (todo.size()) { - auto BB = todo.front(); - todo.pop_front(); - if (done.count(BB)) - continue; - done.insert(BB); - - bool syncd = false; - llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); - for (; I != E; ++I) { - if (auto II = llvm::dyn_cast(&*I)) { - if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0 || - 
II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) { - syncd = true; - break; - } - } - if (f(&*I)) - return; - if (&*I == inst) - break; - } - if (!syncd) { - for (auto suc : llvm::predecessors(BB)) { - todo.push_back(suc); - } - if (&BB->getParent()->getEntryBlock() == BB) { - preEntry(); - } - } - } } #include "llvm/Analysis/LoopInfo.h" @@ -1108,10 +1058,6 @@ static inline llvm::Value *getMPIMemberPtr(llvm::IRBuilder<> &B, llvm::Value *V, } } -llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, - llvm::Type *OpType, ConcreteType CT, - llvm::Type *intType, llvm::IRBuilder<> &B2); - class AssertingReplacingVH final : public llvm::CallbackVH { public: AssertingReplacingVH() = default; @@ -1681,88 +1627,6 @@ static inline bool isNoAlias(const llvm::Value *val) { } static inline bool isNoEscapingAllocation(const llvm::Function *F) { - if (F->hasFnAttribute("enzyme_no_escaping_allocation")) - return true; - using namespace llvm; - switch (F->getIntrinsicID()) { - case Intrinsic::memset: - case Intrinsic::memcpy: - case Intrinsic::memmove: -#if LLVM_VERSION_MAJOR >= 12 - case Intrinsic::experimental_noalias_scope_decl: -#endif - case Intrinsic::objectsize: - case Intrinsic::floor: - case Intrinsic::ceil: - case Intrinsic::trunc: - case Intrinsic::rint: - case Intrinsic::lrint: - case Intrinsic::llrint: - case Intrinsic::nearbyint: - case Intrinsic::round: - case Intrinsic::roundeven: - case Intrinsic::lround: - case Intrinsic::llround: - case Intrinsic::nvvm_barrier0: - case Intrinsic::nvvm_barrier0_popc: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: - case Intrinsic::nvvm_membar_cta: - case Intrinsic::nvvm_membar_gl: - case Intrinsic::nvvm_membar_sys: - case Intrinsic::amdgcn_s_barrier: - case Intrinsic::assume: - case Intrinsic::lifetime_start: - case Intrinsic::lifetime_end: -#if LLVM_VERSION_MAJOR <= 16 - case Intrinsic::dbg_addr: -#endif - - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: - case Intrinsic::dbg_label: - case Intrinsic::invariant_start: - case Intrinsic::invariant_end: - case Intrinsic::var_annotation: - case Intrinsic::ptr_annotation: - case Intrinsic::annotation: - case Intrinsic::codeview_annotation: - case Intrinsic::expect: - case Intrinsic::type_test: - case Intrinsic::donothing: - case Intrinsic::prefetch: - case Intrinsic::trap: - case Intrinsic::is_constant: -#if LLVM_VERSION_MAJOR >= 12 - case Intrinsic::smax: - case Intrinsic::smin: - case Intrinsic::umax: - case Intrinsic::umin: -#endif - case Intrinsic::ctlz: - case Intrinsic::cttz: - case Intrinsic::sadd_with_overflow: - case Intrinsic::ssub_with_overflow: -#if LLVM_VERSION_MAJOR >= 12 - case Intrinsic::abs: -#endif - case Intrinsic::sqrt: - case Intrinsic::exp: - case Intrinsic::cos: - case Intrinsic::sin: -#if LLVM_VERSION_MAJOR >= 19 - case Intrinsic::tanh: - case Intrinsic::cosh: - case Intrinsic::sinh: -#endif - case Intrinsic::copysign: - case Intrinsic::fabs: - return true; - default: - break; - } - // if (F->empty()) - // llvm::errs() << " may escape:" << F->getName() << "\n"; return false; } static inline bool isNoEscapingAllocation(const llvm::CallBase *call) { From 5f9f95fd513bcc7e1ea8064b9d38feb48a1ce1de Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 9 Jun 2025 23:04:05 -0500 Subject: [PATCH 03/13] starting dlopen --- enzyme/Enzyme/Enzyme.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index a2e981fa27a2..d572cf1213a5 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -27,6 +27,7 @@ #include #include "llvm/ADT/StringRef.h" + #include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -88,6 +89,10 @@ using namespace llvm; #endif #define DEBUG_TYPE "lower-reactant-intrinsic" +llvm::cl::opt + Passes("raising-plugin-path", cl::init(""), cl::Hidden, + cl::desc("Print before and after fns for autodiff")); + namespace { constexpr char cudaLaunchSymbolName[] = "cudaLaunchKernel"; @@ -426,7 +431,21 @@ class ReactantBase { GlobalOptPass().run(M, MAM); } + llvm::errs() << "M: " << M << "\n"; + + auto lib = dlopen(Passes.c_str(), RTLD_LAZY | RTLD_DEEPBIND); + auto sym = dlsym(lib, "runLLVMToMLIRRoundTrip"); + + auto runLLVMToMLIRRoundTrip = (std::string (*)(std::string))sym; + if (runLLVMToMLIRRoundTrip) { + std::string MStr; + llvm::raw_string_ostream ss(MStr); + ss << M; + auto newMod = runLLVMToMLIRRoundTrip(MStr); + llvm::errs() << " newMod: " << newMod << "\n"; + } + return changed; } }; From a20a1ba42668e30dd48f2214973df20ec9d95433 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 10 Jun 2025 13:14:55 -0500 Subject: [PATCH 04/13] ctors --- enzyme/Enzyme/Enzyme.cpp | 56 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index d572cf1213a5..60004943ef5d 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -126,6 +126,7 @@ void fixup(Module &M) { auto GridDim2 = CI->getArgOperand(2); auto BlockDim1 = CI->getArgOperand(3); auto BlockDim2 = CI->getArgOperand(4); + auto ArgPtr = CI->getArgOperand(5); auto SharedMemSize = CI->getArgOperand(6); auto StreamPtr = CI->getArgOperand(7); SmallVector Args = { @@ -133,8 +134,15 @@ void fixup(Module &M) { BlockDim2, SharedMemSize, StreamPtr, }; auto StubFunc = cast(CI->getArgOperand(0)); - for (auto &Arg : StubFunc->args()) - Args.push_back(&Arg); + + size_t idx = 0; + for (auto &Arg : StubFunc->args()) { + auto gep = Builder.CreateConstInBoundsGEP1_64(llvm::PointerType::getUnqual(CI->getContext()), ArgPtr, idx); + auto ld = Builder.CreateLoad(llvm::PointerType::getUnqual(CI->getContext()), gep); + ld = Builder.CreateLoad(Arg.getType(), ld); + Args.push_back(ld); + idx++; + } SmallVector ArgTypes; for (Value *V : Args) ArgTypes.push_back(V->getType()); @@ -417,6 +425,50 @@ class ReactantBase { F->eraseFromParent(); } } + + if (auto GV = M.getGlobalVariable("llvm.global_ctors")) { + ConstantArray *CA = dyn_cast(GV->getInitializer()); + if (CA) { + + bool changed = false; + SmallVector newOperands; + for (Use &OP : CA->operands()) { + if (isa(OP)) { + changed = true; + continue; + } + ConstantStruct *CS = cast(OP); + if (isa(CS->getOperand(1))) { + changed = true; + continue; + } + newOperands.push_back(CS); + } + if (changed) { + if (newOperands.size() == 0) { + GV->eraseFromParent(); + } else { + auto EltTy = newOperands[0]->getType(); + ArrayType *NewType = ArrayType::get(EltTy, newOperands.size()); + auto CT = ConstantArray::get(NewType, newOperands); + + // Create the new global variable. 
+ GlobalVariable *NG = new GlobalVariable( + M, NewType, GV->isConstant(), GV->getLinkage(), + /*init*/ CT, /*name*/ "", GV, GV->getThreadLocalMode(), + GV->getAddressSpace()); + + NG->copyAttributesFrom(GV); + NG->takeName(GV); + GV->replaceAllUsesWith(NG); + GV->eraseFromParent(); + + } + + } + } + } + { PassBuilder PB; LoopAnalysisManager LAM; From 94fcc3ee5e0916e35dbb6c948e1c4488c57a71a9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 11 Jun 2025 14:55:29 -0500 Subject: [PATCH 05/13] fix --- enzyme/Enzyme/Enzyme.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 60004943ef5d..41d3e02d0ee7 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -365,6 +365,7 @@ class ReactantBase { if (!F.empty()) F.setLinkage(Function::LinkageTypes::InternalLinkage); } + SmallVector toInternalize; if (auto RF = M.getFunction("__cudaRegisterFunction")) { for (auto U : make_early_inc_range(RF->users())) { if (auto CI = dyn_cast(U)) { @@ -395,8 +396,10 @@ class ReactantBase { F22->deleteBody(); MF->setCallingConv(llvm::CallingConv::C); MF->setLinkage(Function::LinkageTypes::LinkOnceODRLinkage); + toInternalize.push_back(MF->getName().str()); + CI->eraseFromParent(); + llvm::errs() << " replacing: " << nameVal << "\n"; } - CI->eraseFromParent(); } } } @@ -407,6 +410,10 @@ class ReactantBase { Linker L(M); L.linkInModule(std::move(mod2)); M.getContext().setDiagnosticHandler(std::move(handler)); + for (auto name : toInternalize) + if (auto F = M.getFunction(name)) { + F->setLinkage(Function::LinkageTypes::InternalLinkage); + } } for (Function &F : make_early_inc_range(M)) { @@ -487,8 +494,13 @@ class ReactantBase { llvm::errs() << "M: " << M << "\n"; auto lib = dlopen(Passes.c_str(), RTLD_LAZY | RTLD_DEEPBIND); + if (!lib) { + llvm::errs() << " could not open " << Passes.c_str() << " - " << dlerror() << "\n"; + } auto sym = dlsym(lib, "runLLVMToMLIRRoundTrip"); - + if (!sym) { + llvm::errs() << " could not find sym\n"; + } auto runLLVMToMLIRRoundTrip = (std::string (*)(std::string))sym; if (runLLVMToMLIRRoundTrip) { std::string MStr; From b0b82e82028c2ef915e91395d9a0f8c695045021 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 12 Jun 2025 20:39:34 -0500 Subject: [PATCH 06/13] fix --- enzyme/Enzyme/Enzyme.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 41d3e02d0ee7..1ad22e662e60 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -23,6 +23,9 @@ // the function passed as the first argument. 
// //===----------------------------------------------------------------------===// +#define private public +#include "llvm/IR/Module.h" +#undef private #include #include @@ -508,6 +511,27 @@ class ReactantBase { ss << M; auto newMod = runLLVMToMLIRRoundTrip(MStr); llvm::errs() << " newMod: " << newMod << "\n"; + M.dropAllReferences(); + + M.getGlobalList().clear(); + M.getFunctionList().clear(); + M.getAliasList().clear(); + M.getIFuncList().clear(); + + + llvm::SMDiagnostic Err; + auto llvmModule = + llvm::parseIR(llvm::MemoryBufferRef(newMod, "conversion"), Err, M.getContext()); + + if (!llvmModule) { + Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); + exit(1); + } + //auto handler = M.getContext().getDiagnosticHandler(); + Linker L(M); + L.linkInModule(std::move(llvmModule)); + // M.getContext().setDiagnosticHandler(std::move(handler)); + } return changed; From a0d35e9e78bfd24d26877b0c63a42176798b33e1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 12 Jun 2025 22:08:12 -0500 Subject: [PATCH 07/13] fix --- enzyme/Enzyme/Enzyme.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 1ad22e662e60..0e21bfc48fdc 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -527,10 +527,10 @@ class ReactantBase { Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); exit(1); } - //auto handler = M.getContext().getDiagnosticHandler(); + auto handler = M.getContext().getDiagnosticHandler(); Linker L(M); L.linkInModule(std::move(llvmModule)); - // M.getContext().setDiagnosticHandler(std::move(handler)); + M.getContext().setDiagnosticHandler(std::move(handler)); } From abd88c2d1dba704b77890c801a09ad02a0d25bbe Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 12 Jun 2025 22:08:22 -0500 Subject: [PATCH 08/13] fmt --- enzyme/Enzyme/Enzyme.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 0e21bfc48fdc..79e4a3588ca2 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -513,25 +513,23 @@ class ReactantBase { llvm::errs() << " newMod: " << newMod << "\n"; M.dropAllReferences(); - M.getGlobalList().clear(); - M.getFunctionList().clear(); - M.getAliasList().clear(); - M.getIFuncList().clear(); - + M.getGlobalList().clear(); + M.getFunctionList().clear(); + M.getAliasList().clear(); + M.getIFuncList().clear(); - llvm::SMDiagnostic Err; - auto llvmModule = - llvm::parseIR(llvm::MemoryBufferRef(newMod, "conversion"), Err, M.getContext()); + llvm::SMDiagnostic Err; + auto llvmModule = llvm::parseIR( + llvm::MemoryBufferRef(newMod, "conversion"), Err, M.getContext()); if (!llvmModule) { - Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); - exit(1); + Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); + exit(1); } - auto handler = M.getContext().getDiagnosticHandler(); + auto handler = M.getContext().getDiagnosticHandler(); Linker L(M); L.linkInModule(std::move(llvmModule)); M.getContext().setDiagnosticHandler(std::move(handler)); - } return changed; From 1ffa3a2231739249e521a32f72d47f927d5b6d3e Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 16 Jun 2025 18:51:54 -0400 Subject: [PATCH 09/13] fewer prints --- enzyme/Enzyme/Enzyme.cpp | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 79e4a3588ca2..8255ed7e78b5 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -304,13 +304,9 @@ class ReactantBase { bool run(Module &M) { bool changed = true; - llvm::errs() << " pre: " << M << "\n"; fixup(M); - llvm::errs() << "M: " << M << "\n"; for (auto bin : gpubins) { - llvm::errs() << " gpubin: " << bin << "\n"; - SMDiagnostic Err; auto mod2 = llvm::parseIRFile(bin + ".re_export", Err, M.getContext()); if (!mod2) { @@ -401,14 +397,11 @@ class ReactantBase { MF->setLinkage(Function::LinkageTypes::LinkOnceODRLinkage); toInternalize.push_back(MF->getName().str()); CI->eraseFromParent(); - llvm::errs() << " replacing: " << nameVal << "\n"; } } } } - llvm::errs() << " mod2: " << *mod2 << "\n"; - auto handler = M.getContext().getDiagnosticHandler(); Linker L(M); L.linkInModule(std::move(mod2)); @@ -494,8 +487,6 @@ class ReactantBase { GlobalOptPass().run(M, MAM); } - llvm::errs() << "M: " << M << "\n"; - auto lib = dlopen(Passes.c_str(), RTLD_LAZY | RTLD_DEEPBIND); if (!lib) { llvm::errs() << " could not open " << Passes.c_str() << " - " << dlerror() << "\n"; @@ -510,7 +501,6 @@ class ReactantBase { llvm::raw_string_ostream ss(MStr); ss << M; auto newMod = runLLVMToMLIRRoundTrip(MStr); - llvm::errs() << " newMod: " << newMod << "\n"; M.dropAllReferences(); M.getGlobalList().clear(); @@ -523,6 +513,7 @@ class ReactantBase { llvm::MemoryBufferRef(newMod, "conversion"), Err, M.getContext()); if (!llvmModule) { + llvm::errs() << " newMod: " << newMod << "\n"; Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); exit(1); } @@ -553,11 +544,9 @@ class ReactantNewPM final : public ReactantBase, public: using Result = llvm::PreservedAnalyses; ReactantNewPM(const std::vector &gpubins) : ReactantBase(gpubins) { - llvm::errs() << " constructing new pm\n"; } Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { -llvm::errs() << " running on module: " << M << "\n"; return ReactantBase::run(M) ? PreservedAnalyses::none() : PreservedAnalyses::all(); } @@ -588,7 +577,6 @@ class ExporterNewPM final : public AnalysisInfoMixin { } file << M; - llvm::errs() << " exported to: " << filename << "\n"; return PreservedAnalyses::all(); } From a79ca88540de1bd8b669e27e665a07fbd4d4439e Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Tue, 17 Jun 2025 17:27:59 -0500 Subject: [PATCH 10/13] fix --- enzyme/Enzyme/Enzyme.cpp | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 8255ed7e78b5..a9124814dcde 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -237,24 +237,16 @@ void fixup(Module &M) { Load->replaceAllUsesWith(NewLoad); } CoercedKernels.insert(KernelLaunch); - - It = &*PopCall->getParent()->getPrevNode()->getFirstNonPHIOrDbg(); - CallInst *PushCall = dyn_cast(It); - while (!It->isTerminator() && - !(PushCall && PushCall->getCalledFunction() && - PushCall->getCalledFunction()->getName() == cudaPushConfigName)) { - It = It->getNextNonDebugInstruction(); - PushCall = dyn_cast(It); - } - - assert(!It->isTerminator()); - // Replace with success - PushCall->replaceAllUsesWith(IRB.getInt32(0)); - PushCall->eraseFromParent(); PopCall->replaceAllUsesWith(IRB.getInt32(0)); PopCall->eraseFromParent(); } + + for (CallInst *PushCall : gatherCallers(PushConfigFunc)) { + // Replace with success + PushCall->replaceAllUsesWith(ConstantInt::get(IntegerType::get(PushCall->getContext(), 32), 0)); + PushCall->eraseFromParent(); + } for (CallInst *CI : CoercedKernels) { IRBuilder<> Builder(CI); auto FuncPtr = CI->getArgOperand(0); @@ -631,6 +623,16 @@ extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector MPM.addPass(ReactantNewPM(gpubinaries)); }; + PB.registerPipelineParsingCallback( + [=](llvm::StringRef Name, llvm::ModulePassManager &MPM, + llvm::ArrayRef) { + if (Name == "reactant") { + MPM.addPass(ReactantNewPM(gpubinaries)); + return true; + } + return false; + }); + // TODO need for perf reasons to move Enzyme pass to the pre vectorization. PB.registerOptimizerEarlyEPCallback(loadPass); @@ -644,3 +646,13 @@ extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector }; PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO); } + +extern "C" void registerReactant2(llvm::PassBuilder &PB) { + registerReactant(PB, {}); +} + +extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK +llvmGetPassPluginInfo() { + return {LLVM_PLUGIN_API_VERSION, "ReactantNewPM", "v0.1", registerReactant2}; +} + From 483f6b2e41d4650624bb25e6820debd5fbd5ca08 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Wed, 25 Jun 2025 20:54:36 -0400 Subject: [PATCH 11/13] fix release mode --- enzyme/Enzyme/Enzyme.cpp | 44 ++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index a9124814dcde..14d873893ede 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -105,12 +105,12 @@ constexpr char kernelCoercedPrefix[] = "__mlir_launch_coerced_kernel_"; constexpr char cudaPushConfigName[] = "__cudaPushCallConfiguration"; constexpr char cudaPopConfigName[] = "__cudaPopCallConfiguration"; -SmallVector gatherCallers(Function *F) { +SmallVector gatherCallers(Function *F) { if (!F) return {}; - SmallVector ToHandle; + SmallVector ToHandle; for (auto User : F->users()) - if (auto CI = dyn_cast(User)) + if (auto CI = dyn_cast(User)) if (CI->getCalledFunction() == F) ToHandle.push_back(CI); return ToHandle; @@ -121,8 +121,8 @@ void fixup(Module &M) { if (!LaunchKernelFunc) return; - SmallPtrSet CoercedKernels; - for (CallInst *CI : gatherCallers(LaunchKernelFunc)) { + SmallPtrSet CoercedKernels; + for (CallBase *CI : gatherCallers(LaunchKernelFunc)) { IRBuilder<> Builder(CI); auto FuncPtr = CI->getArgOperand(0); auto GridDim1 = CI->getArgOperand(1); @@ -156,14 +156,18 @@ void fixup(Module &M) { kernelCoercedPrefix + StubFunc->getName(), M); CoercedKernels.insert(Builder.CreateCall(MlirLaunchFunc, Args)); + if (auto II = dyn_cast(CI)) { + Builder.CreateBr(II->getNormalDest()); + II->getUnwindDest()->removePredecessor(II->getParent()); + } CI->eraseFromParent(); } SmallVector InlinedStubs; - for (CallInst *CI : CoercedKernels) { + for (CallBase *CI : CoercedKernels) { Function *StubFunc = cast(CI->getArgOperand(0)); for (User *callee : StubFunc->users()) { - if (auto *CI = dyn_cast(callee)) { + if (auto *CI = dyn_cast(callee)) { if (CI->getCalledFunction() == StubFunc) { InlineFunctionInfo IFI; InlineResult Res = @@ -184,7 +188,7 @@ void fixup(Module &M) { CoercedKernels.clear(); DenseMap> FuncAllocas; auto PushConfigFunc = M.getFunction(cudaPushConfigName); - for (CallInst *CI : gatherCallers(PushConfigFunc)) { + for (CallBase *CI : gatherCallers(PushConfigFunc)) { Function *TheFunc = CI->getFunction(); IRBuilder<> IRB(&TheFunc->getEntryBlock(), TheFunc->getEntryBlock().getFirstNonPHIOrDbgOrAlloca()); @@ -202,21 +206,25 @@ void fixup(Module &M) { IRB.CreateAlloca(IRB.getInt64Ty(), nullptr, "shmem_size")); Allocas.push_back(IRB.CreateAlloca(IRB.getPtrTy(), nullptr, "stream")); FuncAllocas.insert_or_assign(TheFunc, Allocas); + llvm::errs() <<" CI: making allocas for " << *CI << "\n"; } IRB.SetInsertPoint(CI); + if (CI->arg_size() != Allocas.size()) { + llvm::errs() << " size mismatch on: " << *CI << "\n"; + } for (auto [Arg, Alloca] : - llvm::zip_equal(llvm::drop_end(CI->operand_values()), Allocas)) + llvm::zip_equal(CI->args(), Allocas)) IRB.CreateStore(Arg, Alloca); } auto PopConfigFunc = M.getFunction(cudaPopConfigName); - for (CallInst *PopCall : gatherCallers(PopConfigFunc)) { + for (CallBase *PopCall : gatherCallers(PopConfigFunc)) { Function *TheFunc = PopCall->getFunction(); auto Allocas = FuncAllocas.lookup(TheFunc); if (Allocas.empty()) { continue; } - CallInst *KernelLaunch = PopCall; + CallBase *KernelLaunch = PopCall; Instruction *It = PopCall; do { It = It->getNextNonDebugInstruction(); @@ -242,12 +250,12 @@ void fixup(Module &M) { PopCall->eraseFromParent(); } - for (CallInst *PushCall : gatherCallers(PushConfigFunc)) { + for (CallBase *PushCall : 
gatherCallers(PushConfigFunc)) { // Replace with success PushCall->replaceAllUsesWith(ConstantInt::get(IntegerType::get(PushCall->getContext(), 32), 0)); PushCall->eraseFromParent(); } - for (CallInst *CI : CoercedKernels) { + for (CallBase *CI : CoercedKernels) { IRBuilder<> Builder(CI); auto FuncPtr = CI->getArgOperand(0); auto GridDim1 = CI->getArgOperand(1); @@ -296,7 +304,13 @@ class ReactantBase { bool run(Module &M) { bool changed = true; + if (getenv("DEBUG_REACTANT")) + llvm::errs() <<" pre fix: " << M << "\n"; fixup(M); + auto discard = M.getContext().shouldDiscardValueNames(); + M.getContext().setDiscardValueNames(false); + if (getenv("DEBUG_REACTANT")) + llvm::errs() <<" post fix: " << M << "\n"; for (auto bin : gpubins) { SMDiagnostic Err; @@ -356,6 +370,7 @@ class ReactantBase { if (!F.empty()) F.setLinkage(Function::LinkageTypes::InternalLinkage); } + llvm::errs() << " mod2: " << *mod2 << "\n"; SmallVector toInternalize; if (auto RF = M.getFunction("__cudaRegisterFunction")) { for (auto U : make_early_inc_range(RF->users())) { @@ -404,6 +419,8 @@ class ReactantBase { } } + llvm::errs() << "post link: " << M << "\n"; + for (Function &F : make_early_inc_range(M)) { if (!F.empty()) continue; if (F.getName() == "cudaMalloc") { @@ -514,6 +531,7 @@ class ReactantBase { L.linkInModule(std::move(llvmModule)); M.getContext().setDiagnosticHandler(std::move(handler)); } + M.getContext().setDiscardValueNames(discard); return changed; } From 7efc20af4e33c2178acb2c00f0c69fa33a33e1f6 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 26 Jun 2025 21:34:09 -0400 Subject: [PATCH 12/13] fixup debug info --- enzyme/Enzyme/Enzyme.cpp | 372 +++++++++++++++++++-------------------- 1 file changed, 186 insertions(+), 186 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 14d873893ede..597d78545d33 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -30,7 +30,7 @@ #include #include "llvm/ADT/StringRef.h" - #include +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -52,8 +52,8 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Passes/PassBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -80,11 +80,10 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" - #include "llvm/Transforms/IPO/GlobalOpt.h" -#include "llvm/Linker/Linker.h" #include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" using namespace llvm; #ifdef DEBUG_TYPE @@ -94,11 +93,11 @@ using namespace llvm; llvm::cl::opt Passes("raising-plugin-path", cl::init(""), cl::Hidden, - cl::desc("Print before and after fns for autodiff")); + cl::desc("Print before and after fns for autodiff")); namespace { - constexpr char cudaLaunchSymbolName[] = "cudaLaunchKernel"; +constexpr char cudaLaunchSymbolName[] = "cudaLaunchKernel"; constexpr char kernelPrefix[] = "__mlir_launch_kernel_"; constexpr char kernelCoercedPrefix[] = "__mlir_launch_coerced_kernel_"; @@ -137,11 +136,13 @@ void fixup(Module &M) { BlockDim2, SharedMemSize, StreamPtr, }; auto StubFunc = cast(CI->getArgOperand(0)); - + size_t idx = 0; for (auto &Arg : StubFunc->args()) { - auto gep = Builder.CreateConstInBoundsGEP1_64(llvm::PointerType::getUnqual(CI->getContext()), ArgPtr, idx); - auto ld = Builder.CreateLoad(llvm::PointerType::getUnqual(CI->getContext()), gep); + auto gep = Builder.CreateConstInBoundsGEP1_64( + 
llvm::PointerType::getUnqual(CI->getContext()), ArgPtr, idx); + auto ld = Builder.CreateLoad( + llvm::PointerType::getUnqual(CI->getContext()), gep); ld = Builder.CreateLoad(Arg.getType(), ld); Args.push_back(ld); idx++; @@ -206,14 +207,13 @@ void fixup(Module &M) { IRB.CreateAlloca(IRB.getInt64Ty(), nullptr, "shmem_size")); Allocas.push_back(IRB.CreateAlloca(IRB.getPtrTy(), nullptr, "stream")); FuncAllocas.insert_or_assign(TheFunc, Allocas); - llvm::errs() <<" CI: making allocas for " << *CI << "\n"; + llvm::errs() << " CI: making allocas for " << *CI << "\n"; } IRB.SetInsertPoint(CI); if (CI->arg_size() != Allocas.size()) { llvm::errs() << " size mismatch on: " << *CI << "\n"; } - for (auto [Arg, Alloca] : - llvm::zip_equal(CI->args(), Allocas)) + for (auto [Arg, Alloca] : llvm::zip_equal(CI->args(), Allocas)) IRB.CreateStore(Arg, Alloca); } auto PopConfigFunc = M.getFunction(cudaPopConfigName); @@ -252,7 +252,8 @@ void fixup(Module &M) { for (CallBase *PushCall : gatherCallers(PushConfigFunc)) { // Replace with success - PushCall->replaceAllUsesWith(ConstantInt::get(IntegerType::get(PushCall->getContext(), 32), 0)); + PushCall->replaceAllUsesWith( + ConstantInt::get(IntegerType::get(PushCall->getContext(), 32), 0)); PushCall->eraseFromParent(); } for (CallBase *CI : CoercedKernels) { @@ -298,188 +299,193 @@ void fixup(Module &M) { class ReactantBase { public: std::vector gpubins; - ReactantBase(const std::vector &gpubins) : gpubins(gpubins) { - } + ReactantBase(const std::vector &gpubins) : gpubins(gpubins) {} bool run(Module &M) { bool changed = true; if (getenv("DEBUG_REACTANT")) - llvm::errs() <<" pre fix: " << M << "\n"; + llvm::errs() << " pre fix: " << M << "\n"; fixup(M); auto discard = M.getContext().shouldDiscardValueNames(); M.getContext().setDiscardValueNames(false); if (getenv("DEBUG_REACTANT")) - llvm::errs() <<" post fix: " << M << "\n"; - + llvm::errs() << " post fix: " << M << "\n"; + for (auto bin : gpubins) { SMDiagnostic Err; auto mod2 = llvm::parseIRFile(bin + ".re_export", Err, M.getContext()); if (!mod2) { - Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); - exit(1); + Err.print(/*ProgName=*/"LLVMToMLIR", llvm::errs()); + exit(1); } for (std::string T : {"", "f"}) { - for (std::string name : - {"sin", "cos", "tan", "log2", "exp", "exp2", - "exp10", "cosh", "sinh", "tanh", "atan2", "atan", - "asin", "acos", "log", "log10", "log1p", "acosh", - "asinh", "atanh", "expm1", "hypot", "rhypot", "norm3d", - "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", "cbrt", - "rcbrt", "j0", "j1", "y0", "y1", "yn", - "jn", "erf", "erfinv", "erfc", "erfcx", "erfcinv", - "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", - "modf", "fmod", "remainder", "remquo", "powi", "tgamma", - "round", "fdim", "ilogb", "logb", "isinf", "pow", - "sqrt", "finite", "fabs", "fmax"}) { - std::string nvname = "__nv_" + name; - std::string llname = "llvm." 
+ name + "."; - std::string mathname = name; - - if (T == "f") { - mathname += "f"; - nvname += "f"; - llname += "f32"; - } else { - llname += "f64"; - } - - if (auto F = mod2->getFunction(llname)) { - F->deleteBody(); - } - } + for (std::string name : + {"sin", "cos", "tan", "log2", "exp", + "exp2", "exp10", "cosh", "sinh", "tanh", + "atan2", "atan", "asin", "acos", "log", + "log10", "log1p", "acosh", "asinh", "atanh", + "expm1", "hypot", "rhypot", "norm3d", "rnorm3d", + "norm4d", "rnorm4d", "norm", "rnorm", "cbrt", + "rcbrt", "j0", "j1", "y0", "y1", + "yn", "jn", "erf", "erfinv", "erfc", + "erfcx", "erfcinv", "normcdfinv", "normcdf", "lgamma", + "ldexp", "scalbn", "frexp", "modf", "fmod", + "remainder", "remquo", "powi", "tgamma", "round", + "fdim", "ilogb", "logb", "isinf", "pow", + "sqrt", "finite", "fabs", "fmax"}) { + std::string nvname = "__nv_" + name; + std::string llname = "llvm." + name + "."; + std::string mathname = name; + + if (T == "f") { + mathname += "f"; + nvname += "f"; + llname += "f32"; + } else { + llname += "f64"; + } + + if (auto F = mod2->getFunction(llname)) { + F->deleteBody(); + } + } } - { - - - PassBuilder PB; - LoopAnalysisManager LAM; - FunctionAnalysisManager FAM; - CGSCCAnalysisManager CGAM; - ModuleAnalysisManager MAM; - PB.registerModuleAnalyses(MAM); - PB.registerFunctionAnalyses(FAM); - PB.registerLoopAnalyses(LAM); - PB.registerCGSCCAnalyses(CGAM); - PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); - - GlobalOptPass().run(*mod2, MAM); + { + + PassBuilder PB; + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + PB.registerModuleAnalyses(MAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.registerCGSCCAnalyses(CGAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + GlobalOptPass().run(*mod2, MAM); } for (auto &F : *mod2) { - if (!F.empty()) - F.setLinkage(Function::LinkageTypes::InternalLinkage); + if (!F.empty()) + F.setLinkage(Function::LinkageTypes::InternalLinkage); } - llvm::errs() << " mod2: " << *mod2 << "\n"; + if (getenv("DEBUG_REACTANT")) + llvm::errs() << " mod2: " << *mod2 << "\n"; + SmallVector toInternalize; if (auto RF = M.getFunction("__cudaRegisterFunction")) { for (auto U : make_early_inc_range(RF->users())) { - if (auto CI = dyn_cast(U)) { - if (CI->getCalledFunction() != RF) continue; - - Value *F2 = CI->getArgOperand(1); - Value *name = CI->getArgOperand(2); - while (auto CE = dyn_cast(F2)) { - F2 = CE->getOperand(0); - } - while (auto CE = dyn_cast(name)) { - name = CE->getOperand(0); - } - StringRef nameVal; - if (auto GV = dyn_cast(name)) - if (GV->isConstant()) - if (auto C = GV->getInitializer()) - if (auto CA = dyn_cast(C)) - if (CA->getType()->getElementType()->isIntegerTy(8) && - CA->isCString()) - nameVal = CA->getAsCString(); - auto F22 = dyn_cast(F2); - if (!F22) continue; - - if (nameVal.size()) - if (auto MF = mod2->getFunction(nameVal)) { - MF->setName(F22->getName()); - F22->deleteBody(); - MF->setCallingConv(llvm::CallingConv::C); - MF->setLinkage(Function::LinkageTypes::LinkOnceODRLinkage); - toInternalize.push_back(MF->getName().str()); - CI->eraseFromParent(); - } - } - } + if (auto CI = dyn_cast(U)) { + if (CI->getCalledFunction() != RF) + continue; + + Value *F2 = CI->getArgOperand(1); + Value *name = CI->getArgOperand(2); + while (auto CE = dyn_cast(F2)) { + F2 = CE->getOperand(0); + } + while (auto CE = dyn_cast(name)) { + name = CE->getOperand(0); + } + StringRef nameVal; + if (auto GV = dyn_cast(name)) + if 
(GV->isConstant()) + if (auto C = GV->getInitializer()) + if (auto CA = dyn_cast(C)) + if (CA->getType()->getElementType()->isIntegerTy(8) && + CA->isCString()) + nameVal = CA->getAsCString(); + auto F22 = dyn_cast(F2); + if (!F22) + continue; + + if (nameVal.size()) + if (auto MF = mod2->getFunction(nameVal)) { + MF->setName(F22->getName()); + F22->deleteBody(); + MF->setCallingConv(llvm::CallingConv::C); + MF->setLinkage(Function::LinkageTypes::LinkOnceODRLinkage); + toInternalize.push_back(MF->getName().str()); + CI->eraseFromParent(); + } + } + } } - auto handler = M.getContext().getDiagnosticHandler(); + auto handler = M.getContext().getDiagnosticHandler(); Linker L(M); L.linkInModule(std::move(mod2)); M.getContext().setDiagnosticHandler(std::move(handler)); for (auto name : toInternalize) - if (auto F = M.getFunction(name)) { - F->setLinkage(Function::LinkageTypes::InternalLinkage); - } + if (auto F = M.getFunction(name)) { + F->setLinkage(Function::LinkageTypes::InternalLinkage); + } } - llvm::errs() << "post link: " << M << "\n"; + if (getenv("DEBUG_REACTANT")) + llvm::errs() << "post link: " << M << "\n"; for (Function &F : make_early_inc_range(M)) { - if (!F.empty()) continue; + if (!F.empty()) + continue; if (F.getName() == "cudaMalloc") { - continue; + continue; auto entry = BasicBlock::Create(F.getContext(), "entry", &F); IRBuilder<> B(entry); } } fixup(M); - for (auto todel : {"__cuda_register_globals", "__cuda_module_ctor", "__cuda_module_dtor"}) { - if (auto F = M.getFunction(todel)) { - F->replaceAllUsesWith(Constant::getNullValue(F->getType())); - F->eraseFromParent(); - } + for (auto todel : {"__cuda_register_globals", "__cuda_module_ctor", + "__cuda_module_dtor"}) { + if (auto F = M.getFunction(todel)) { + F->replaceAllUsesWith(Constant::getNullValue(F->getType())); + F->eraseFromParent(); + } } - + if (auto GV = M.getGlobalVariable("llvm.global_ctors")) { ConstantArray *CA = dyn_cast(GV->getInitializer()); if (CA) { - bool changed = false; - SmallVector newOperands; - for (Use &OP : CA->operands()) { - if (isa(OP)) { - changed = true; - continue; - } - ConstantStruct *CS = cast(OP); - if (isa(CS->getOperand(1))) { - changed = true; - continue; - } - newOperands.push_back(CS); - } - if (changed) { - if (newOperands.size() == 0) { - GV->eraseFromParent(); - } else { - auto EltTy = newOperands[0]->getType(); - ArrayType *NewType = ArrayType::get(EltTy, newOperands.size()); - auto CT = ConstantArray::get(NewType, newOperands); - - // Create the new global variable. - GlobalVariable *NG = new GlobalVariable( - M, NewType, GV->isConstant(), GV->getLinkage(), - /*init*/ CT, /*name*/ "", GV, GV->getThreadLocalMode(), - GV->getAddressSpace()); - - NG->copyAttributesFrom(GV); - NG->takeName(GV); - GV->replaceAllUsesWith(NG); - GV->eraseFromParent(); - - } - + bool changed = false; + SmallVector newOperands; + for (Use &OP : CA->operands()) { + if (isa(OP)) { + changed = true; + continue; + } + ConstantStruct *CS = cast(OP); + if (isa(CS->getOperand(1))) { + changed = true; + continue; + } + newOperands.push_back(CS); + } + if (changed) { + if (newOperands.size() == 0) { + GV->eraseFromParent(); + } else { + auto EltTy = newOperands[0]->getType(); + ArrayType *NewType = ArrayType::get(EltTy, newOperands.size()); + auto CT = ConstantArray::get(NewType, newOperands); + + // Create the new global variable. 
+ GlobalVariable *NG = new GlobalVariable( + M, NewType, GV->isConstant(), GV->getLinkage(), + /*init*/ CT, /*name*/ "", GV, GV->getThreadLocalMode(), + GV->getAddressSpace()); + + NG->copyAttributesFrom(GV); + NG->takeName(GV); + GV->replaceAllUsesWith(NG); + GV->eraseFromParent(); + } + } } } - } { PassBuilder PB; @@ -494,21 +500,22 @@ class ReactantBase { PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); GlobalOptPass().run(M, MAM); - } + } auto lib = dlopen(Passes.c_str(), RTLD_LAZY | RTLD_DEEPBIND); if (!lib) { - llvm::errs() << " could not open " << Passes.c_str() << " - " << dlerror() << "\n"; + llvm::errs() << " could not open " << Passes.c_str() << " - " << dlerror() + << "\n"; } auto sym = dlsym(lib, "runLLVMToMLIRRoundTrip"); - if (!sym) { + if (!sym) { llvm::errs() << " could not find sym\n"; } - auto runLLVMToMLIRRoundTrip = (std::string (*)(std::string))sym; + auto runLLVMToMLIRRoundTrip = (std::string(*)(std::string))sym; if (runLLVMToMLIRRoundTrip) { std::string MStr; llvm::raw_string_ostream ss(MStr); - ss << M; + ss << M; auto newMod = runLLVMToMLIRRoundTrip(MStr); M.dropAllReferences(); @@ -537,7 +544,7 @@ class ReactantBase { } }; -} +} // namespace #include #include @@ -545,7 +552,7 @@ class ReactantBase { #include "llvm/Passes/PassPlugin.h" class ReactantNewPM final : public ReactantBase, - public AnalysisInfoMixin { + public AnalysisInfoMixin { friend struct llvm::AnalysisInfoMixin; private: @@ -553,12 +560,12 @@ class ReactantNewPM final : public ReactantBase, public: using Result = llvm::PreservedAnalyses; - ReactantNewPM(const std::vector &gpubins) : ReactantBase(gpubins) { -} + ReactantNewPM(const std::vector &gpubins) + : ReactantBase(gpubins) {} Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { return ReactantBase::run(M) ? PreservedAnalyses::none() - : PreservedAnalyses::all(); + : PreservedAnalyses::all(); } static bool isRequired() { return true; } @@ -578,15 +585,15 @@ class ExporterNewPM final : public AnalysisInfoMixin { Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { std::string filename = firstfile + ".re_export"; - std::error_code EC; - llvm::raw_fd_ostream file(filename, EC);//, llvm::sys::fs::OF_Text); + std::error_code EC; + llvm::raw_fd_ostream file(filename, EC); //, llvm::sys::fs::OF_Text); - if (EC) { - llvm::errs() << "Error opening file: " << EC.message() << "\n"; - exit(1); - } + if (EC) { + llvm::errs() << "Error opening file: " << EC.message() << "\n"; + exit(1); + } - file << M; + file << M; return PreservedAnalyses::all(); } @@ -601,20 +608,17 @@ AnalysisKey ExporterNewPM::Key; extern "C" void registerExporter(llvm::PassBuilder &PB, std::string file) { #if LLVM_VERSION_MAJOR >= 20 - auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level, - ThinOrFullLTOPhase) + auto loadPass = + [=](ModulePassManager &MPM, OptimizationLevel Level, ThinOrFullLTOPhase) #else auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level) #endif - { - MPM.addPass(ExporterNewPM(file)); - }; + { MPM.addPass(ExporterNewPM(file)); }; // TODO need for perf reasons to move Enzyme pass to the pre vectorization. 
PB.registerOptimizerEarlyEPCallback(loadPass); - auto loadLTO = [loadPass](ModulePassManager &MPM, - OptimizationLevel Level) { + auto loadLTO = [loadPass](ModulePassManager &MPM, OptimizationLevel Level) { #if LLVM_VERSION_MAJOR >= 20 loadPass(MPM, Level, ThinOrFullLTOPhase::None); #else @@ -625,37 +629,34 @@ extern "C" void registerExporter(llvm::PassBuilder &PB, std::string file) { } extern "C" void registerReactantAndPassPipeline(llvm::PassBuilder &PB, - bool augment = false) { -} + bool augment = false) {} -extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector gpubinaries) { +extern "C" void registerReactant(llvm::PassBuilder &PB, + std::vector gpubinaries) { llvm::errs() << " registering reactant\n"; #if LLVM_VERSION_MAJOR >= 20 - auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level, - ThinOrFullLTOPhase) + auto loadPass = + [=](ModulePassManager &MPM, OptimizationLevel Level, ThinOrFullLTOPhase) #else auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level) #endif - { - MPM.addPass(ReactantNewPM(gpubinaries)); - }; + { MPM.addPass(ReactantNewPM(gpubinaries)); }; PB.registerPipelineParsingCallback( [=](llvm::StringRef Name, llvm::ModulePassManager &MPM, - llvm::ArrayRef) { + llvm::ArrayRef) { if (Name == "reactant") { MPM.addPass(ReactantNewPM(gpubinaries)); return true; } - return false; - }); + return false; + }); // TODO need for perf reasons to move Enzyme pass to the pre vectorization. PB.registerOptimizerEarlyEPCallback(loadPass); - auto loadLTO = [loadPass](ModulePassManager &MPM, - OptimizationLevel Level) { + auto loadLTO = [loadPass](ModulePassManager &MPM, OptimizationLevel Level) { #if LLVM_VERSION_MAJOR >= 20 loadPass(MPM, Level, ThinOrFullLTOPhase::None); #else @@ -673,4 +674,3 @@ extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK llvmGetPassPluginInfo() { return {LLVM_PLUGIN_API_VERSION, "ReactantNewPM", "v0.1", registerReactant2}; } - From 5477278f22f8acd2fb06a29bb0aaf5eb94b27e7d Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 27 Jun 2025 19:09:21 -0400 Subject: [PATCH 13/13] link --- enzyme/Enzyme/Clang/EnzymeClang.cpp | 5 +++-- enzyme/Enzyme/Enzyme.cpp | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index f32eca62e03d..fda59cbc8a03 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -87,7 +87,7 @@ struct Visitor : public RecursiveASTVisitor { } }; -extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector gpubins); +extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector gpubins, std::string outfile); extern "C" void registerExporter(llvm::PassBuilder &PB, std::string file); @@ -124,8 +124,9 @@ class EnzymePlugin final : public clang::ASTConsumer { gpubins.push_back(inFile); //gpubins.push_back(CGOpts.CudaGpuBinaryFileName); } + std::string file = CI.getFrontendOpts().OutputFile; CGOpts.PassBuilderCallbacks.push_back([=](llvm::PassBuilder &PB) { - registerReactant(PB, gpubins); + registerReactant(PB, gpubins, file); }); } diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 597d78545d33..0c35f2653a57 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -299,7 +299,8 @@ void fixup(Module &M) { class ReactantBase { public: std::vector gpubins; - ReactantBase(const std::vector &gpubins) : gpubins(gpubins) {} + std::string outfile; + ReactantBase(const std::vector &gpubins, std::string outfile) : gpubins(gpubins), outfile(outfile) {} bool run(Module &M) { bool changed = true; @@ -511,12 +512,12 @@ class ReactantBase { if (!sym) { llvm::errs() << " could not find sym\n"; } - auto runLLVMToMLIRRoundTrip = (std::string(*)(std::string))sym; + auto runLLVMToMLIRRoundTrip = (std::string(*)(std::string, std::string))sym; if (runLLVMToMLIRRoundTrip) { std::string MStr; llvm::raw_string_ostream ss(MStr); ss << M; - auto newMod = runLLVMToMLIRRoundTrip(MStr); + auto newMod = runLLVMToMLIRRoundTrip(MStr, outfile); M.dropAllReferences(); M.getGlobalList().clear(); @@ -560,8 +561,8 @@ class ReactantNewPM final : public ReactantBase, public: using Result = llvm::PreservedAnalyses; - ReactantNewPM(const std::vector &gpubins) - : ReactantBase(gpubins) {} + ReactantNewPM(const std::vector &gpubins, std::string outfile) + : ReactantBase(gpubins, outfile) {} Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { return ReactantBase::run(M) ? 
PreservedAnalyses::none() @@ -632,7 +633,7 @@ extern "C" void registerReactantAndPassPipeline(llvm::PassBuilder &PB, bool augment = false) {} extern "C" void registerReactant(llvm::PassBuilder &PB, - std::vector gpubinaries) { + std::vector gpubinaries, std::string outfile) { llvm::errs() << " registering reactant\n"; #if LLVM_VERSION_MAJOR >= 20 @@ -641,13 +642,13 @@ extern "C" void registerReactant(llvm::PassBuilder &PB, #else auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level) #endif - { MPM.addPass(ReactantNewPM(gpubinaries)); }; + { MPM.addPass(ReactantNewPM(gpubinaries, outfile)); }; PB.registerPipelineParsingCallback( [=](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { if (Name == "reactant") { - MPM.addPass(ReactantNewPM(gpubinaries)); + MPM.addPass(ReactantNewPM(gpubinaries, outfile)); return true; } return false; @@ -667,7 +668,7 @@ extern "C" void registerReactant(llvm::PassBuilder &PB, } extern "C" void registerReactant2(llvm::PassBuilder &PB) { - registerReactant(PB, {}); + registerReactant(PB, {}, ""); } extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
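
The final patches expose ReactantNewPM in three ways: at the optimizer-early extension point, under the textual pipeline name "reactant" via registerPipelineParsingCallback, and through a weak llvmGetPassPluginInfo hook so the library can also be loaded as a standalone new-pass-manager plugin. Below is a minimal, self-contained sketch of that same registration pattern; the SketchPass class, the "reactant-sketch" pipeline name, and the "SketchPlugin" plugin name are illustrative placeholders, not anything defined by the patches above.

// Minimal sketch of a new-pass-manager plugin that registers a module pass
// under a textual pipeline name. All names here (SketchPass, "reactant-sketch",
// "SketchPlugin") are placeholders for illustration only.
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

namespace {
struct SketchPass : PassInfoMixin<SketchPass> {
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
    // Stand-in for the real work (the fixup / link / LLVM-to-MLIR round trip
    // performed by ReactantBase::run in the patches above).
    errs() << "SketchPass visiting module " << M.getName() << "\n";
    return PreservedAnalyses::all();
  }
  static bool isRequired() { return true; }
};
} // namespace

extern "C" ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() {
  return {LLVM_PLUGIN_API_VERSION, "SketchPlugin", "v0.1",
          [](PassBuilder &PB) {
            // Make the pass schedulable by name from a -passes= pipeline
            // string, mirroring the "reactant" registration above.
            PB.registerPipelineParsingCallback(
                [](StringRef Name, ModulePassManager &MPM,
                   ArrayRef<PassBuilder::PipelineElement>) {
                  if (Name == "reactant-sketch") {
                    MPM.addPass(SketchPass());
                    return true;
                  }
                  return false;
                });
          }};
}

Built into a shared library, such a plugin would typically be driven with something like: opt -load-pass-plugin=./SketchPlugin.so -passes=reactant-sketch input.ll -S (file names illustrative). The "reactant" pipeline name registered in the patches can be scheduled the same way once the pass library is loaded.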