From bc330e0a3cbc7ae1702a6b1139901d4404e0bbd7 Mon Sep 17 00:00:00 2001 From: "Zhang, Winston" Date: Tue, 9 Sep 2025 09:44:46 -0700 Subject: [PATCH 1/3] [UR][L0] urProgramSetSpecializationConstants to returns error Now urProgramSpcializationConstants will return UR_RESULT_ERROR_INVALID_SPEC_ID when the incorrect id is used. Signed-off-by: Zhang, Winston --- .../source/adapters/level_zero/program.cpp | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 6b8fa9fff2db2..4df40cb0ce437 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -994,6 +994,67 @@ ur_result_t urProgramCreateWithNativeHandle( return UR_RESULT_SUCCESS; } +// Helper function to validate if a specialization constant ID exists in SPIR-V +static bool isValidSpecConstantId(const uint8_t *spirvCode, size_t spirvSize, + uint32_t specId) { + if (!spirvCode || spirvSize < 20) { + return false; // Invalid SPIR-V + } + + // Check SPIR-V magic number + const uint32_t *words = reinterpret_cast(spirvCode); + if (words[0] != 0x07230203) { + return false; // Invalid SPIR-V magic number + } + + // Parse SPIR-V header + // words[0] = magic number + // words[1] = version + // words[2] = generator magic number + // words[3] = bound on all ids + // words[4] = schema (0) + size_t headerSize = 5; + if (spirvSize < headerSize * sizeof(uint32_t)) { + return false; + } + + // Parse instructions looking for OpSpecConstant* instructions + size_t pos = headerSize; + const uint32_t *end = words + (spirvSize / sizeof(uint32_t)); + + while (pos < (spirvSize / sizeof(uint32_t))) { + if (pos >= (end - words)) + break; + + uint32_t instruction = words[pos]; + uint16_t opcode = instruction & 0xFFFF; + uint16_t length = instruction >> 16; + + if (length == 0 || pos + length > (end - words)) { + break; // Invalid instruction + } + + // OpSpecConstantTrue = 48, OpSpecConstantFalse = 49, OpSpecConstant = 50 + // OpSpecConstantComposite = 51, OpSpecConstantOp = 52 + if (opcode >= 48 && opcode <= 52) { + if (length >= + 3) { // All OpSpecConstant* instructions have at least 3 words + // words[pos + 0] = instruction header + // words[pos + 1] = result type id + // words[pos + 2] = result id (this is the specialization constant id) + uint32_t resultId = words[pos + 2]; + if (resultId == specId) { + return true; + } + } + } + + pos += length; + } + + return false; +} + ur_result_t urProgramSetSpecializationConstants( /// [in] handle of the Program object ur_program_handle_t Program, @@ -1004,6 +1065,17 @@ ur_result_t urProgramSetSpecializationConstants( const ur_specialization_constant_info_t *SpecConstants) { std::scoped_lock Guard(Program->Mutex); + // Validate each specialization constant ID against the SPIR-V program + for (uint32_t SpecIt = 0; SpecIt < Count; SpecIt++) { + uint32_t SpecId = SpecConstants[SpecIt].id; + + // Validate the spec constant ID exists in the SPIR-V binary + if (!isValidSpecConstantId(Program->getCode(), Program->getCodeSize(), + SpecId)) { + return UR_RESULT_ERROR_INVALID_SPEC_ID; + } + } + // Remember the value of this specialization constant until the program is // built. Note that we only save the pointer to the buffer that contains the // value. The caller is responsible for maintaining storage for this buffer. From c9638d88c8f0aa351c97f6bc7e362ba9c8e52216 Mon Sep 17 00:00:00 2001 From: "Zhang, Winston" Date: Tue, 9 Sep 2025 10:04:42 -0700 Subject: [PATCH 2/3] int to size_t Signed-off-by: Zhang, Winston --- .../source/adapters/level_zero/program.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 4df40cb0ce437..8b394d9aac60b 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -1020,25 +1020,25 @@ static bool isValidSpecConstantId(const uint8_t *spirvCode, size_t spirvSize, // Parse instructions looking for OpSpecConstant* instructions size_t pos = headerSize; - const uint32_t *end = words + (spirvSize / sizeof(uint32_t)); + const size_t totalWords = spirvSize / sizeof(uint32_t); - while (pos < (spirvSize / sizeof(uint32_t))) { - if (pos >= (end - words)) + while (pos < totalWords) { + if (pos >= totalWords) break; uint32_t instruction = words[pos]; uint16_t opcode = instruction & 0xFFFF; uint16_t length = instruction >> 16; - if (length == 0 || pos + length > (end - words)) { + if (length == 0 || pos + length > totalWords) { break; // Invalid instruction } // OpSpecConstantTrue = 48, OpSpecConstantFalse = 49, OpSpecConstant = 50 // OpSpecConstantComposite = 51, OpSpecConstantOp = 52 if (opcode >= 48 && opcode <= 52) { - if (length >= - 3) { // All OpSpecConstant* instructions have at least 3 words + if (length >= 3) { + // All OpSpecConstant* instructions have at least 3 words // words[pos + 0] = instruction header // words[pos + 1] = result type id // words[pos + 2] = result id (this is the specialization constant id) From 03e16d2e80c3adbf44dc3ced63e2247462a44f62 Mon Sep 17 00:00:00 2001 From: "Zhang, Winston" Date: Tue, 9 Sep 2025 10:12:57 -0700 Subject: [PATCH 3/3] check specId Signed-off-by: Zhang, Winston --- .../source/adapters/level_zero/program.cpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 8b394d9aac60b..86e4c6fa52c92 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -1018,7 +1018,8 @@ static bool isValidSpecConstantId(const uint8_t *spirvCode, size_t spirvSize, return false; } - // Parse instructions looking for OpSpecConstant* instructions + // Parse instructions looking for OpDecorate instructions with SpecId + // decoration size_t pos = headerSize; const size_t totalWords = spirvSize / sizeof(uint32_t); @@ -1034,16 +1035,14 @@ static bool isValidSpecConstantId(const uint8_t *spirvCode, size_t spirvSize, break; // Invalid instruction } - // OpSpecConstantTrue = 48, OpSpecConstantFalse = 49, OpSpecConstant = 50 - // OpSpecConstantComposite = 51, OpSpecConstantOp = 52 - if (opcode >= 48 && opcode <= 52) { - if (length >= 3) { - // All OpSpecConstant* instructions have at least 3 words - // words[pos + 0] = instruction header - // words[pos + 1] = result type id - // words[pos + 2] = result id (this is the specialization constant id) - uint32_t resultId = words[pos + 2]; - if (resultId == specId) { + // OpDecorate = 71, and we need decoration SpecId = 1 + if (opcode == 71 && length >= 4) { + // OpDecorate with at least target_id, + // decoration, and extra operand + uint32_t decoration = words[pos + 2]; + if (decoration == 1) { // SpecId decoration + uint32_t actualSpecId = words[pos + 3]; + if (actualSpecId == specId) { return true; } }