Skip to content

[HLSL][SPRIV] Handle signed RWBuffer correctly #144774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions clang/lib/CodeGen/Targets/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
const SmallVector<int32_t> *Packoffsets = nullptr) const override;
llvm::Type *getSPIRVImageTypeFromHLSLResource(
const HLSLAttributedResourceType::Attributes &attributes,
llvm::Type *ElementType, llvm::LLVMContext &Ctx) const;
QualType SampledType, CodeGenModule &CGM) const;
void
setOCLKernelStubCallingConvention(const FunctionType *&FT) const override;
};
Expand Down Expand Up @@ -483,12 +483,12 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
assert(!ResAttrs.IsROV &&
"Rasterizer order views not implemented for SPIR-V yet");

llvm::Type *ElemType = CGM.getTypes().ConvertType(ContainedTy);
if (!ResAttrs.RawBuffer) {
// convert element type
return getSPIRVImageTypeFromHLSLResource(ResAttrs, ElemType, Ctx);
return getSPIRVImageTypeFromHLSLResource(ResAttrs, ContainedTy, CGM);
}

llvm::Type *ElemType = CGM.getTypes().ConvertType(ContainedTy);
llvm::ArrayType *RuntimeArrayType = llvm::ArrayType::get(ElemType, 0);
uint32_t StorageClass = /* StorageBuffer storage class */ 12;
bool IsWritable = ResAttrs.ResourceClass == llvm::dxil::ResourceClass::UAV;
Expand Down Expand Up @@ -516,13 +516,18 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
}

llvm::Type *CommonSPIRTargetCodeGenInfo::getSPIRVImageTypeFromHLSLResource(
const HLSLAttributedResourceType::Attributes &attributes,
llvm::Type *ElementType, llvm::LLVMContext &Ctx) const {
const HLSLAttributedResourceType::Attributes &attributes, QualType Ty,
CodeGenModule &CGM) const {
llvm::LLVMContext &Ctx = CGM.getLLVMContext();

if (ElementType->isVectorTy())
ElementType = ElementType->getScalarType();
Ty = Ty->getCanonicalTypeUnqualified();
if (const VectorType *V = dyn_cast<VectorType>(Ty))
Ty = V->getElementType();
assert(!Ty->isVectorType() && "We still have a vector type.");

assert((ElementType->isIntegerTy() || ElementType->isFloatingPointTy()) &&
llvm::Type *SampledType = CGM.getTypes().ConvertType(Ty);

assert((SampledType->isIntegerTy() || SampledType->isFloatingPointTy()) &&
"The element type for a SPIR-V resource must be a scalar integer or "
"floating point type.");

Expand All @@ -531,6 +536,9 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getSPIRVImageTypeFromHLSLResource(
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeImage.
SmallVector<unsigned, 6> IntParams(6, 0);

const char *Name =
Ty->isSignedIntegerType() ? "spirv.SignedImage" : "spirv.Image";

// Dim
// For now we assume everything is a buffer.
IntParams[0] = 5;
Expand All @@ -553,7 +561,9 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getSPIRVImageTypeFromHLSLResource(
// Setting to unknown for now.
IntParams[5] = 0;

return llvm::TargetExtType::get(Ctx, "spirv.Image", {ElementType}, IntParams);
llvm::TargetExtType *ImageType =
llvm::TargetExtType::get(Ctx, Name, {SampledType}, IntParams);
return ImageType;
}

std::unique_ptr<TargetCodeGenInfo>
Expand Down
10 changes: 5 additions & 5 deletions clang/test/CodeGenHLSL/builtins/RWBuffer-elementtype.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
// DXIL: %"class.hlsl::RWBuffer.11" = type { target("dx.TypedBuffer", <3 x float>, 1, 0, 0) }
// DXIL: %"class.hlsl::RWBuffer.12" = type { target("dx.TypedBuffer", <4 x i32>, 1, 0, 1) }

// SPIRV: %"class.hlsl::RWBuffer" = type { target("spirv.Image", i16, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer" = type { target("spirv.SignedImage", i16, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.0" = type { target("spirv.Image", i16, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.1" = type { target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.1" = type { target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.2" = type { target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.3" = type { target("spirv.Image", i64, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.3" = type { target("spirv.SignedImage", i64, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.4" = type { target("spirv.Image", i64, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.5" = type { target("spirv.Image", half, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.6" = type { target("spirv.Image", float, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.7" = type { target("spirv.Image", double, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.8" = type { target("spirv.Image", i16, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.8" = type { target("spirv.SignedImage", i16, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.9" = type { target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.10" = type { target("spirv.Image", half, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.11" = type { target("spirv.Image", float, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.12" = type { target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) }
// SPIRV: %"class.hlsl::RWBuffer.12" = type { target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) }

RWBuffer<int16_t> BufI16;
RWBuffer<uint16_t> BufU16;
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CodeGenHLSL/builtins/RWBuffer-subscript.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ void main(unsigned GI : SV_GroupIndex) {
// CHECK: define void @main()

// DXC: %[[INPTR:.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[INPTR:.*]] = call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_i32_5_2_0_0_2_0t(target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[INPTR:.*]] = call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// CHECK: %[[LOAD:.*]] = load i32, ptr {{.*}}%[[INPTR]]
// DXC: %[[OUTPTR:.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[OUTPTR:.*]] = call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_i32_5_2_0_0_2_0t(target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[OUTPTR:.*]] = call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// CHECK: store i32 %[[LOAD]], ptr {{.*}}%[[OUTPTR]]
Out[GI] = In[GI];

// DXC: %[[INPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[INPTR:.*]] = call ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_i32_5_2_0_0_2_0t(target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[INPTR:.*]] = call ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// CHECK: %[[LOAD:.*]] = load i32, ptr {{.*}}%[[INPTR]]
// DXC: %[[OUTPTR:.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[OUTPTR:.*]] = call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_i32_5_2_0_0_2_0t(target("spirv.Image", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// SPIRV: %[[OUTPTR:.*]] = call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %{{.*}}, i32 %{{.*}})
// CHECK: store i32 %[[LOAD]], ptr {{.*}}%[[OUTPTR]]
Out[GI] = In.Load(GI);
}
7 changes: 7 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ using target extension types and are represented as follows:
SPIR-V Type LLVM type name LLVM type arguments
================== ======================= ===========================================================================================
OpTypeImage ``spirv.Image`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeImage ``spirv.SignedImage`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeSampler ``spirv.Sampler`` (none)
OpTypeSampledImage ``spirv.SampledImage`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeEvent ``spirv.Event`` (none)
Expand All @@ -275,6 +276,12 @@ parameters of its underlying image type, so that a sampled image for the
previous type has the representation
``target("spirv.SampledImage, void, 1, 1, 0, 0, 0, 0, 0)``.

The differences between ``spirv.Image`` and ``spirv.SignedImage`` is that the
backend will generate code assuming that the format of the image is a signed
integer instead of unsigned. This is required because llvm-ir will create the
same sampled type for signed and unsigned integers. If the image format is
unknown, the backend cannot distinguish the two case.

See `wg-hlsl proposal 0018 <https://github.com/llvm/wg-hlsl/blob/main/proposals/0018-spirv-resource-representation.md>`_
for details on ``spirv.VulkanBuffer``.

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ struct TargetTypeInfo {
static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
LLVMContext &C = Ty->getContext();
StringRef Name = Ty->getName();
if (Name == "spirv.Image")
if (Name == "spirv.Image" || Name == "spirv.SignedImage")
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal,
TargetExtType::CanBeLocal);
if (Name == "spirv.Type") {
Expand Down
38 changes: 3 additions & 35 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3086,43 +3086,11 @@ static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
ExtensionType->getIntParameter(3), true);
}

static SPIRVType *
getImageType(const TargetExtType *ExtensionType,
const SPIRV::AccessQualifier::AccessQualifier Qualifier,
MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V image builtin type must have sampled type parameter!");
const SPIRVType *SampledType =
GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
assert((ExtensionType->getNumIntParameters() == 7 ||
ExtensionType->getNumIntParameters() == 6) &&
"Invalid number of parameters for SPIR-V image builtin!");

SPIRV::AccessQualifier::AccessQualifier accessQualifier =
SPIRV::AccessQualifier::None;
if (ExtensionType->getNumIntParameters() == 7) {
accessQualifier = Qualifier == SPIRV::AccessQualifier::WriteOnly
? SPIRV::AccessQualifier::WriteOnly
: SPIRV::AccessQualifier::AccessQualifier(
ExtensionType->getIntParameter(6));
}

// Create or get an existing type from GlobalRegistry.
return GR->getOrCreateOpTypeImage(
MIRBuilder, SampledType,
SPIRV::Dim::Dim(ExtensionType->getIntParameter(0)),
ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
ExtensionType->getIntParameter(3), ExtensionType->getIntParameter(4),
SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(5)),
accessQualifier);
}

static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
SPIRVType *OpaqueImageType = getImageType(
OpaqueType, SPIRV::AccessQualifier::ReadOnly, MIRBuilder, GR);
SPIRVType *OpaqueImageType = GR->getImageType(
OpaqueType, SPIRV::AccessQualifier::ReadOnly, MIRBuilder);
// Create or get an existing type from GlobalRegistry.
return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
}
Expand Down Expand Up @@ -3293,7 +3261,7 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,

switch (TypeRecord->Opcode) {
case SPIRV::OpTypeImage:
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
TargetType = GR->getImageType(BuiltinType, AccessQual, MIRBuilder);
break;
case SPIRV::OpTypePipe:
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,7 @@ def : BuiltinType<"spirv.Event", OpTypeEvent>;
def : BuiltinType<"spirv.Sampler", OpTypeSampler>;
def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>;
def : BuiltinType<"spirv.Image", OpTypeImage>;
def : BuiltinType<"spirv.SignedImage", OpTypeImage>;
def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>;
def : BuiltinType<"spirv.Pipe", OpTypePipe>;
def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>;
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,8 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
auto *II = dyn_cast<IntrinsicInst>(I);
if (II && II->getIntrinsicID() == Intrinsic::spv_resource_getpointer) {
auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
if (HandleType->getTargetExtName() == "spirv.Image") {
if (HandleType->getTargetExtName() == "spirv.Image" ||
HandleType->getTargetExtName() == "spirv.SignedImage") {
if (II->hasOneUse()) {
auto *U = *II->users().begin();
Ty = cast<Instruction>(U)->getAccessType();
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,40 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateLayoutType(
return SPIRVStructType;
}

SPIRVType *SPIRVGlobalRegistry::getImageType(
const TargetExtType *ExtensionType,
const SPIRV::AccessQualifier::AccessQualifier Qualifier,
MachineIRBuilder &MIRBuilder) {
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V image builtin type must have sampled type parameter!");
const SPIRVType *SampledType =
getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
assert((ExtensionType->getNumIntParameters() == 7 ||
ExtensionType->getNumIntParameters() == 6) &&
"Invalid number of parameters for SPIR-V image builtin!");

SPIRV::AccessQualifier::AccessQualifier accessQualifier =
SPIRV::AccessQualifier::None;
if (ExtensionType->getNumIntParameters() == 7) {
accessQualifier = Qualifier == SPIRV::AccessQualifier::WriteOnly
? SPIRV::AccessQualifier::WriteOnly
: SPIRV::AccessQualifier::AccessQualifier(
ExtensionType->getIntParameter(6));
}

// Create or get an existing type from GlobalRegistry.
SPIRVType *R = getOrCreateOpTypeImage(
MIRBuilder, SampledType,
SPIRV::Dim::Dim(ExtensionType->getIntParameter(0)),
ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
ExtensionType->getIntParameter(3), ExtensionType->getIntParameter(4),
SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(5)),
accessQualifier);
SPIRVToLLVMType[R] = ExtensionType;
return R;
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
Expand Down
15 changes: 10 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,13 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
MachineIRBuilder &MIRBuilder);
bool hasBlockDecoration(SPIRVType *Type) const;

SPIRVType *
getOrCreateOpTypeImage(MachineIRBuilder &MIRBuilder, SPIRVType *SampledType,
SPIRV::Dim::Dim Dim, uint32_t Depth, uint32_t Arrayed,
uint32_t Multisampled, uint32_t Sampled,
SPIRV::ImageFormat::ImageFormat ImageFormat,
SPIRV::AccessQualifier::AccessQualifier AccQual);

public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
Expand Down Expand Up @@ -607,11 +614,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
const TargetExtType *T, bool EmitIr = false);

SPIRVType *
getOrCreateOpTypeImage(MachineIRBuilder &MIRBuilder, SPIRVType *SampledType,
SPIRV::Dim::Dim Dim, uint32_t Depth, uint32_t Arrayed,
uint32_t Multisampled, uint32_t Sampled,
SPIRV::ImageFormat::ImageFormat ImageFormat,
SPIRV::AccessQualifier::AccessQualifier AccQual);
getImageType(const TargetExtType *ExtensionType,
const SPIRV::AccessQualifier::AccessQualifier Qualifier,
MachineIRBuilder &MIRBuilder);

SPIRVType *getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder);

Expand Down
Loading