Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions unified-runtime/source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,66 @@ ur_result_t urProgramCreateWithNativeHandle(
return UR_RESULT_SUCCESS;
}

// Helper function to validate if a specialization constant ID exists in SPIR-V
static bool isValidSpecConstantId(const uint8_t *spirvCode, size_t spirvSize,
uint32_t specId) {
if (!spirvCode || spirvSize < 20) {
return false; // Invalid SPIR-V
}

// Check SPIR-V magic number
const uint32_t *words = reinterpret_cast<const uint32_t *>(spirvCode);
if (words[0] != 0x07230203) {
return false; // Invalid SPIR-V magic number
}

// Parse SPIR-V header
// words[0] = magic number
// words[1] = version
// words[2] = generator magic number
// words[3] = bound on all ids
// words[4] = schema (0)
size_t headerSize = 5;
if (spirvSize < headerSize * sizeof(uint32_t)) {
return false;
}

// Parse instructions looking for OpDecorate instructions with SpecId
// decoration
size_t pos = headerSize;
const size_t totalWords = spirvSize / sizeof(uint32_t);

while (pos < totalWords) {
if (pos >= totalWords)
break;

uint32_t instruction = words[pos];
uint16_t opcode = instruction & 0xFFFF;
uint16_t length = instruction >> 16;

if (length == 0 || pos + length > totalWords) {
break; // Invalid instruction
}

// OpDecorate = 71, and we need decoration SpecId = 1
if (opcode == 71 && length >= 4) {
// OpDecorate with at least target_id,
// decoration, and extra operand
uint32_t decoration = words[pos + 2];
if (decoration == 1) { // SpecId decoration
uint32_t actualSpecId = words[pos + 3];
if (actualSpecId == specId) {
return true;
}
}
}

pos += length;
}

return false;
}

ur_result_t urProgramSetSpecializationConstants(
/// [in] handle of the Program object
ur_program_handle_t Program,
Expand All @@ -1004,6 +1064,17 @@ ur_result_t urProgramSetSpecializationConstants(
const ur_specialization_constant_info_t *SpecConstants) {
std::scoped_lock<ur_shared_mutex> Guard(Program->Mutex);

// Validate each specialization constant ID against the SPIR-V program
for (uint32_t SpecIt = 0; SpecIt < Count; SpecIt++) {
uint32_t SpecId = SpecConstants[SpecIt].id;

// Validate the spec constant ID exists in the SPIR-V binary
if (!isValidSpecConstantId(Program->getCode(), Program->getCodeSize(),
SpecId)) {
return UR_RESULT_ERROR_INVALID_SPEC_ID;
}
}

// Remember the value of this specialization constant until the program is
// built. Note that we only save the pointer to the buffer that contains the
// value. The caller is responsible for maintaining storage for this buffer.
Expand Down
Loading