From 63464140963512a5fb2f0d7e7b37a93b5e4d31f7 Mon Sep 17 00:00:00 2001 From: Zhang Xiangze Date: Fri, 7 Jun 2024 16:55:45 +0530 Subject: [PATCH 1/3] Load and convert scales/zeros with groupsize --- src/common/transformer_ctx.h | 10 +- src/layers/attention.h | 32 +-- src/layers/decoder_block.h | 223 ++++-------------- src/layers/dist_linear.h | 8 +- src/layers/mlp_llama.h | 41 ++-- src/layers/mlp_standard.h | 8 +- src/models/common_decoder.h | 7 +- src/utils/matmul_helper.h | 60 ++--- src/utils/weight_util.h | 88 ++++++- src/xfastertransformer/tools/qwen2_convert.py | 3 +- 10 files changed, 223 insertions(+), 257 deletions(-) diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 27b777bc..31b0fa78 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -86,6 +86,9 @@ struct DecoderContext { float attFactor; float epsilon; + // quantization configuration + int groupsize; + // rope scaling parameters RopeParams *ropeParamsPtr; @@ -132,7 +135,7 @@ struct DecoderContext { DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act, float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength, int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, - bool _useLogN = true, bool _useNTK = true, int numThreads = 0) + bool _useLogN = true, bool _useNTK = true, int numThreads = 0, int _groupsize = -1) : layers(_layers) , hiddenSize(_hiddenSize) , attHeadSize(_headSize) @@ -153,7 +156,8 @@ struct DecoderContext { , ppRank(_ppRank) , tpSize(_splits) , tpRank(_splitIdx) - , epsilon(epsilon) { + , epsilon(epsilon) + , groupsize(_groupsize) { if (attHeadNum != 0) { this->attFactor = 1 / sqrtf(attHeadSize); } @@ -325,4 +329,4 @@ struct DecoderContext { if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); #endif } -}; \ No newline at end of file +}; diff --git a/src/layers/attention.h b/src/layers/attention.h index 3c8ecb96..ad3544cf 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -123,23 +123,27 @@ class Attention { float *concatScale = nullptr; float *concatZero = nullptr; if constexpr (std::is_same_v || std::is_same_v) { - concatScale = (float *)malloc(responsibleCols * sizeof(float)); - concatZero = (float *)malloc(responsibleCols * sizeof(float)); - memcpy(concatScale, queryScale + this->startQHead * headSize, qResponsibleCols * sizeof(float)); - memcpy(concatScale + qResponsibleCols, keyScale + this->startKVHead * headSize, + int qkvStride = (ctx->attHeadNum + ctx->kvHeadNum + ctx->kvHeadNum) * ctx->attHeadSize; + int groups = ctx->groupsize == -1 ? 
1 : hiddenSize / ctx->groupsize; + concatScale = (float *)malloc(groups * responsibleCols * sizeof(float)); + concatZero = (float *)malloc(groups * responsibleCols * sizeof(float)); + for (int i = 0; i < groups; ++i) { + memcpy(concatScale + i * responsibleCols, queryScale + i * qkvStride + this->startQHead * headSize, qResponsibleCols * sizeof(float)); + memcpy(concatScale + i * responsibleCols + qResponsibleCols, keyScale + i * qkvStride + this->startKVHead * headSize, kvResponsibleCols * sizeof(float)); - memcpy(concatScale + qResponsibleCols + kvResponsibleCols, valueScale + this->startKVHead * headSize, + memcpy(concatScale + i * responsibleCols + qResponsibleCols + kvResponsibleCols, valueScale + i * qkvStride + this->startKVHead * headSize, kvResponsibleCols * sizeof(float)); - memcpy(concatZero, queryZero + this->startQHead * headSize, qResponsibleCols * sizeof(float)); - memcpy(concatZero + qResponsibleCols, keyZero + this->startKVHead * headSize, + memcpy(concatZero + i * responsibleCols, queryZero + i * qkvStride + this->startQHead * headSize, qResponsibleCols * sizeof(float)); + memcpy(concatZero + i * responsibleCols + qResponsibleCols, keyZero + i * qkvStride + this->startKVHead * headSize, kvResponsibleCols * sizeof(float)); - memcpy(concatZero + qResponsibleCols + kvResponsibleCols, valueZero + this->startKVHead * headSize, + memcpy(concatZero + i * responsibleCols + qResponsibleCols + kvResponsibleCols, valueZero + i * qkvStride + this->startKVHead * headSize, kvResponsibleCols * sizeof(float)); + } } xft::Matrix convertedqkvWeight; ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero, - convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum); + convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum, ctx->groupsize); #ifdef XFT_GPU xft::Matrix qkvWeightT; @@ -182,7 +186,7 @@ class Attention { xft::Matrix convertedOutWeight; ctx->mmHelper->convertWeight(trans, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, attnOutWeight, attnOutScale, attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedOutWeight, - attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true); + attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, ctx->groupsize, true); #ifdef XFT_GPU xft::Matrix outWeightT; @@ -1183,15 +1187,15 @@ class Attention { // query, key, value weighs xft::Matrix qkvWeight; - xft::Vector qkvWeightScale; // if weight is int8 - xft::Vector qkvWeightZero; // if weight is int8 + xft::Matrix qkvWeightScale; // if weight is int8 + xft::Matrix qkvWeightZero; // if weight is int8 xft::Vector qkvWeightSum; // if weight is int8 // query, key, value bias xft::Vector qkvBias; xft::Matrix attnOutputWeight; - xft::Vector attnOutputWeightScale; // if weight is int8 - xft::Vector attnOutputWeightZero; // if weight is int8 + xft::Matrix attnOutputWeightScale; // if weight is int8 + xft::Matrix attnOutputWeightZero; // if weight is int8 xft::Vector attnOutputWeightSum; // if weight is int8 xft::Vector attnOutputBias; diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h index b810c105..35a84384 100644 --- a/src/layers/decoder_block.h +++ b/src/layers/decoder_block.h @@ -127,16 +127,14 @@ class DecoderBlock { } private: - static bool fileExists(const std::string &filename) { - std::ifstream file(filename); - return file.good(); - } - // OriWeiT: float, int8_t or uint4x2_t template void setDecoderWeights(DecoderContext *ctx, DECODER *pdecoder, const std::string &modelPath, 
int layerIdx) { using xft::DataType; using xft::loadWeight; + using xft::Weight; + using xft::fileExists; + using xft::getGroupSize; const int hiddenSize = ctx->hiddenSize; const int imSize = ctx->intermediateSize; @@ -147,196 +145,71 @@ class DecoderBlock { int qSize = attHeadSize * attHeadNum; int kvSize = attHeadSize * kvHeadNum; int qkvSize = qSize + 2 * kvSize; - -#define ALLOC(size, alignment) xft::alloc((size), nullptr, (alignment)) - OriWeiT *qkvWeight = (OriWeiT *)ALLOC(hiddenSize * qkvSize * sizeof(OriWeiT), 64); - float *qkvScales = nullptr; - float *qkvZeros = nullptr; - float *qkvBias = (float *)ALLOC(qkvSize * sizeof(float), 64); - - OriWeiT *attnOutWeight = (OriWeiT *)ALLOC(qSize * hiddenSize * sizeof(OriWeiT), 64); - float *attnOutScales = nullptr; - float *attnOutZeros = nullptr; - float *attnOutBias = (float *)ALLOC(hiddenSize * sizeof(float), 64); - - OriWeiT *fc1Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * mlpFactor * sizeof(OriWeiT), 64); - float *fc1Scales = nullptr; - float *fc1Zeros = nullptr; - float *fc1Bias = (float *)ALLOC(imSize * sizeof(float), 64); - - OriWeiT *fc2Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * sizeof(OriWeiT), 64); - float *fc2Scales = nullptr; - float *fc2Zeros = nullptr; - float *fc2Bias = (float *)ALLOC(hiddenSize * sizeof(float), 64); - - float *ln1Gamma = (float *)ALLOC(hiddenSize * sizeof(float), 64); - float *ln1Beta = (float *)ALLOC(hiddenSize * sizeof(float), 64); - float *ln2Gamma = (float *)ALLOC(hiddenSize * sizeof(float), 64); - float *ln2Beta = (float *)ALLOC(hiddenSize * sizeof(float), 64); - - OriWeiT *fc3Weight = nullptr; - float *fc3Scales = nullptr; - float *fc3Zeros = nullptr; - - // INT8/INT4 quant, wbits = 8/4, qweight dtype: int8_t/uint4x2_t - if constexpr (std::is_same_v || std::is_same_v) { - DataType dt = std::is_same_v ? DataType::int8 : DataType::int4; - - qkvZeros = (float *)ALLOC(qkvSize * sizeof(float), 64); - qkvScales = (float *)ALLOC(qkvSize * sizeof(float), 64); - attnOutZeros = (float *)ALLOC(hiddenSize * sizeof(float), 64); - attnOutScales = (float *)ALLOC(hiddenSize * sizeof(float), 64); - fc1Zeros = (float *)ALLOC(imSize * mlpFactor * sizeof(float), 64); - fc1Scales = (float *)ALLOC(imSize * mlpFactor * sizeof(float), 64); - fc2Zeros = (float *)ALLOC(imSize * sizeof(float), 64); - fc2Scales = (float *)ALLOC(imSize * sizeof(float), 64); - - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) - + ".attention.query_key_value.qweight.0.bin", - qkvWeight, hiddenSize * qkvSize, dt); - loadWeight( - modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.zeros.0.bin", - qkvZeros, qkvSize, DataType::fp32); - loadWeight( - modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.scales.0.bin", - qkvScales, qkvSize, DataType::fp32); - - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.qweight.0.bin", - attnOutWeight, qSize * hiddenSize, dt); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.zeros.0.bin", - attnOutZeros, hiddenSize, DataType::fp32); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.scales.0.bin", - attnOutScales, hiddenSize, DataType::fp32); - - // Stardard 2 layer MLP - if (fileExists( - modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin")) { - loadWeight(modelPath + "/model.layers." 
+ std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin", - fc1Weight, hiddenSize * imSize * mlpFactor, dt); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.zeros.0.bin", - fc1Zeros, imSize * mlpFactor, DataType::fp32); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.scales.0.bin", - fc1Scales, imSize * mlpFactor, DataType::fp32); - - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.qweight.0.bin", - fc2Weight, hiddenSize * imSize, dt); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.zeros.0.bin", - fc2Zeros, hiddenSize, DataType::fp32); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.scales.0.bin", - fc2Scales, hiddenSize, DataType::fp32); - } - // gate, up, down weights for Llama like model - else { - fc3Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * sizeof(OriWeiT), 64); - fc3Zeros = (float *)ALLOC(hiddenSize * sizeof(float), 64); - fc3Scales = (float *)ALLOC(hiddenSize * sizeof(float), 64); - - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.qweight.0.bin", - fc1Weight, hiddenSize * imSize * mlpFactor, dt); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.zeros.0.bin", - fc1Zeros, imSize * mlpFactor, DataType::fp32); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.scales.0.bin", - fc1Scales, imSize * mlpFactor, DataType::fp32); - - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.qweight.0.bin", - fc2Weight, hiddenSize * imSize, dt); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.zeros.0.bin", - fc2Zeros, imSize, DataType::fp32); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.scales.0.bin", - fc2Scales, imSize, DataType::fp32); - - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.qweight.0.bin", - fc3Weight, hiddenSize * imSize, dt); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.zeros.0.bin", - fc3Zeros, hiddenSize, DataType::fp32); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.scales.0.bin", - fc3Scales, hiddenSize, DataType::fp32); - } - - } else if constexpr (std::is_same_v) { - loadWeight( - modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.weight.0.bin", - qkvWeight, hiddenSize * qkvSize); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.weight.0.bin", - attnOutWeight, qSize * hiddenSize); - - // Stardard 2 layer MLP - if (fileExists( - modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.weight.0.bin")) { - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.weight.0.bin", - fc1Weight, hiddenSize * imSize * mlpFactor); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.weight.0.bin", - fc2Weight, hiddenSize * imSize); - } - // gate, up, down weights for Llama like model - else { - fc3Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * sizeof(OriWeiT), 64); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.weight.0.bin", - fc1Weight, hiddenSize * imSize * mlpFactor); - loadWeight(modelPath + "/model.layers." 
+ std::to_string(layerIdx) + ".mlp.up_proj.weight.0.bin", - fc2Weight, hiddenSize * imSize); - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.weight.0.bin", - fc3Weight, hiddenSize * imSize); - } + int groupsize = getGroupSize(modelPath + "config.ini"); + + Weight qkvWeight, attnOutWeight, fc1Weight, fc2Weight, fc3Weight; + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.weight.0.bin", + qkvWeight, hiddenSize, qkvSize); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.weight.0.bin", + attnOutWeight, qSize, hiddenSize); + + bool standard_mlp = (fileExists( modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.weight.0.bin") + || fileExists( modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin")); + // Stardard 2 layer MLP + if (standard_mlp) { + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.weight.0.bin", + fc1Weight, hiddenSize, imSize * mlpFactor); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.weight.0.bin", + fc2Weight, imSize, hiddenSize); + } + // gate, up, down weights for Llama like model + else { + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.weight.0.bin", + fc1Weight, hiddenSize, imSize * mlpFactor); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.weight.0.bin", + fc2Weight, hiddenSize, imSize); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.weight.0.bin", + fc3Weight, imSize, hiddenSize); } - loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".input_layernorm.weight.bin", + float *qkvBias = nullptr, *attnOutBias = nullptr, *fc1Bias = nullptr, *fc2Bias = nullptr; + float *ln1Gamma = nullptr, *ln1Beta = nullptr, *ln2Gamma = nullptr, *ln2Beta = nullptr; + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".input_layernorm.weight.bin", ln1Gamma, hiddenSize); - loadWeight( + loadWeight( modelPath + "/model.layers." + std::to_string(layerIdx) + ".post_attention_layernorm.weight.bin", ln2Gamma, hiddenSize); -#define READ_OPTIONAL(filename, addr, size, errmsg) \ +#define READ_OPTIONAL(filename, addr, size) \ { \ - int ret = loadWeight((filename), (addr), (size), DataType::unknown, false); \ - if (ret == 0) { \ - free(addr); \ - addr = nullptr; \ - } else { \ - if (ret != (size)) { \ - printf("%s\n", (errmsg)); \ - exit(-1); \ - } \ + if (fileExists(filename)) { \ + loadWeight((filename), (addr), (size)); \ } \ } // The bias is optional READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.bias.0.bin", - qkvBias, qkvSize, "read QKV bias error"); + qkvBias, qkvSize); READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.bias.bin", - attnOutBias, hiddenSize, "read attn dense bias error"); + attnOutBias, hiddenSize); READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".input_layernorm.bias.bin", ln1Beta, - hiddenSize, "read LN1 beta error"); + hiddenSize); READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".post_attention_layernorm.bias.bin", - ln2Beta, hiddenSize, "read LN2 beta error"); + ln2Beta, hiddenSize); READ_OPTIONAL(modelPath + "/model.layers." 
+ std::to_string(layerIdx) + ".mlp.dense_h_to_4h.bias.0.bin", - fc1Bias, imSize, "read FC1 bias error"); + fc1Bias, imSize); READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.bias.bin", fc2Bias, - hiddenSize, "read FC2 bias error"); + hiddenSize); constexpr int sizeFactor = std::is_same_v ? 2 : 1; - pdecoder->setWeights(ctx, qkvWeight, qkvScales, qkvZeros, qkvBias, qkvWeight + qSize / sizeFactor, - qkvScales + qSize, qkvZeros + qSize, qkvBias + qSize, - qkvWeight + qSize / sizeFactor + kvSize / sizeFactor, qkvScales + qSize + kvSize, - qkvZeros + qSize + kvSize, qkvBias + qSize + kvSize, attnOutWeight, attnOutScales, attnOutZeros, - attnOutBias, ln1Gamma, ln1Beta, fc1Weight, fc1Scales, fc1Zeros, fc1Bias, fc2Weight, fc2Scales, fc2Zeros, - fc2Bias, ln2Gamma, ln2Beta, fc3Weight, fc3Scales, fc3Zeros, false); + pdecoder->setWeights(ctx, qkvWeight.w, qkvWeight.s, qkvWeight.z, qkvBias, qkvWeight.w + qSize / sizeFactor, + qkvWeight.s + qSize, qkvWeight.z + qSize, qkvBias + qSize, + qkvWeight.w + qSize / sizeFactor + kvSize / sizeFactor, qkvWeight.s + qSize + kvSize, + qkvWeight.z + qSize + kvSize, qkvBias + qSize + kvSize, attnOutWeight.w, attnOutWeight.s, attnOutWeight.z, + attnOutBias, ln1Gamma, ln1Beta, fc1Weight.w, fc1Weight.s, fc1Weight.z, fc1Bias, fc2Weight.w, fc2Weight.s, fc2Weight.z, + fc2Bias, ln2Gamma, ln2Beta, fc3Weight.w, fc3Weight.s, fc3Weight.z, false); - free(qkvWeight); - free(attnOutWeight); - free(fc1Weight); - free(fc2Weight); - free(fc3Weight); - free(qkvZeros); - free(attnOutZeros); - free(fc1Zeros); - free(fc2Zeros); - free(fc3Zeros); - free(qkvScales); - free(attnOutScales); - free(fc1Scales); - free(fc2Scales); - free(fc3Scales); free(qkvBias); free(attnOutBias); free(fc1Bias); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index b118b5fb..7dc4ed4f 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -60,8 +60,8 @@ class DistLinear { int K = inputSize; int N = this->splitSize; - scaleWeight.Resize(N); - zeroWeight.Resize(N); + scaleWeight.Resize(1, N); + zeroWeight.Resize(1, N); xft::Matrix quantizedWeight; ctx->mmHelper->convertWeight( @@ -120,8 +120,8 @@ class DistLinear { int splitOffset; xft::Matrix weight; - xft::Vector scaleWeight; // if weight is int8 - xft::Vector zeroWeight; // if weight is int8 + xft::Matrix scaleWeight; // if weight is int8 + xft::Matrix zeroWeight; // if weight is int8 xft::Vector sumWeight; // if weight is int8 float *bias = nullptr; }; diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 334644bb..ee2caf80 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -326,13 +326,13 @@ class LlamaMLP { } void catGateUpWeights(xft::Matrix &gateWeight, xft::Matrix &upWeight, - xft::Vector &gateWeightScale, xft::Vector &gateWeightZero, xft::Vector &gateWeightSum, - xft::Vector &upWeightScale, xft::Vector &upWeightZero, xft::Vector &upWeightSum, - xft::Matrix &catWeights, xft::Vector &catWeightsScale, xft::Vector &catWeightsZero, + xft::Matrix &gateWeightScale, xft::Matrix &gateWeightZero, xft::Vector &gateWeightSum, + xft::Matrix &upWeightScale, xft::Matrix &upWeightZero, xft::Vector &upWeightSum, + xft::Matrix &catWeights, xft::Matrix &catWeightsScale, xft::Matrix &catWeightsZero, xft::Vector &catWeightsSum) { catWeights.Resize(gateWeight.Rows(), gateWeight.Cols() + upWeight.Cols()); - catWeightsScale.Resize(gateWeightScale.Size() + upWeightScale.Size()); - catWeightsZero.Resize(gateWeightZero.Size() + upWeightZero.Size()); + 
catWeightsScale.Resize(gateWeightScale.Rows(), gateWeightScale.Cols() + upWeightScale.Cols()); + catWeightsZero.Resize(gateWeightZero.Rows(), gateWeightZero.Cols() + upWeightZero.Cols()); catWeightsSum.Resize(gateWeightSum.Size() + upWeightSum.Size()); int M = catWeights.Rows(); @@ -349,12 +349,15 @@ class LlamaMLP { memcpy(catWeights.Data() + i * Stride + N, upWeight.Data() + i * N, N * sizeof(WeiT)); } - M = gateWeightScale.Size(); - N = upWeightScale.Size(); - memcpy(catWeightsScale.Data(), gateWeightScale.Data(), M * sizeof(float)); - memcpy(catWeightsScale.Data() + M, upWeightScale.Data(), N * sizeof(float)); - memcpy(catWeightsZero.Data(), gateWeightZero.Data(), M * sizeof(float)); - memcpy(catWeightsZero.Data() + M, upWeightZero.Data(), N * sizeof(float)); + M = gateWeightScale.Rows(); + Stride = catWeightsScale.Cols(); + N = gateWeightScale.Cols(); + for (uint64_t i = 0; i < M; ++i) { + memcpy(catWeightsScale.Data() + i * Stride, gateWeightScale.Data() + i * N, N * sizeof(float)); + memcpy(catWeightsScale.Data() + i * Stride + N, upWeightScale.Data() + i * N, N * sizeof(float)); + memcpy(catWeightsZero.Data() + i * Stride, gateWeightZero.Data() + i * N, N * sizeof(float)); + memcpy(catWeightsZero.Data() + i * Stride + N, upWeightZero.Data() + i * N, N * sizeof(float)); + } M = gateWeightSum.Size(); N = upWeightSum.Size(); memcpy(catWeightsSum.Data(), gateWeightSum.Data(), M * sizeof(float)); @@ -363,20 +366,20 @@ class LlamaMLP { protected: xft::Matrix gateWeight; - xft::Vector gateWeightScale; // For int8_t weight - xft::Vector gateWeightZero; // For int8_t weight + xft::Matrix gateWeightScale; // For int8_t weight + xft::Matrix gateWeightZero; // For int8_t weight xft::Vector gateWeightSum; // For int8_t weight xft::Matrix upWeight; - xft::Vector upWeightScale; // For int8_t weight - xft::Vector upWeightZero; // For int8_t weight + xft::Matrix upWeightScale; // For int8_t weight + xft::Matrix upWeightZero; // For int8_t weight xft::Vector upWeightSum; // For int8_t weight xft::Matrix catWeights; - xft::Vector catWeightsScale; // For int8_t weight - xft::Vector catWeightsZero; // For int8_t weight + xft::Matrix catWeightsScale; // For int8_t weight + xft::Matrix catWeightsZero; // For int8_t weight xft::Vector catWeightsSum; // For int8_t weight xft::Matrix downWeight; - xft::Vector downWeightScale; // For int8_t weight - xft::Vector downWeightZero; // For int8_t weight + xft::Matrix downWeightScale; // For int8_t weight + xft::Matrix downWeightZero; // For int8_t weight xft::Vector downWeightSum; // For int8_t weight // LlamaRMSNorm param diff --git a/src/layers/mlp_standard.h b/src/layers/mlp_standard.h index a4d65170..e44c2aea 100644 --- a/src/layers/mlp_standard.h +++ b/src/layers/mlp_standard.h @@ -225,14 +225,14 @@ class MLP { // private: xft::Matrix intermediateWeight; - xft::Vector intermediateWeightScale; - xft::Vector intermediateWeightZero; + xft::Matrix intermediateWeightScale; + xft::Matrix intermediateWeightZero; xft::Vector intermediateWeightSum; xft::Vector intermediateBias; xft::Matrix outputWeight; - xft::Vector outputWeightScale; - xft::Vector outputWeightZero; + xft::Matrix outputWeightScale; + xft::Matrix outputWeightZero; xft::Vector outputWeightSum; xft::Vector outputBias; diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 8ca71991..2e0eb82f 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -216,7 +216,6 @@ class CommonDecoder : public AbstractDecoder { dt = quantQweightDataType == "int8" ? 
DataType::int8 : DataType::int4; REQUIRES(quantScalesDataType == "fp32", "scales should be fp32 data type."); REQUIRES(quantZerosDataType == "fp32", "zeros should be fp32 data type."); - REQUIRES(quantGroupsize == -1, "Quantization with groupsize is not supported."); } // Buffer related (not initialized) @@ -228,7 +227,7 @@ class CommonDecoder : public AbstractDecoder { // Context DecoderContext *ctx = getDecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, useLogN, useNTK, - ropeParamsPtr); + ropeParamsPtr, quantGroupsize); ctx->ResetConfigReader(configPath); @@ -738,7 +737,7 @@ class CommonDecoder : public AbstractDecoder { DecoderContext *getDecoderContext(int layers, const int hiddenSize, const int headSize, const int attHeadNum, const int kvHeadNum, const int imSize, const std::string &act, const float epsilon, int vocabSize, int embeddingSize, int maxPositions, int maxPosEmbed, int maxSeqLength, bool useLogN, bool useNTK, - RopeParams *ropeParamsPtr) { + RopeParams *ropeParamsPtr, int groupsize) { Env &env = Env::getInstance(); int tpSize = messenger.getSize(); int tpRank = messenger.getRank(); @@ -769,7 +768,7 @@ class CommonDecoder : public AbstractDecoder { #endif this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, - this->mmHelper.get(), this->device.get(), ppSize, ppRank, ropeParamsPtr, useLogN, useNTK)); + this->mmHelper.get(), this->device.get(), ppSize, ppRank, ropeParamsPtr, useLogN, useNTK, 0, groupsize)); } return this->context.get(); diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 536160a0..6d1010c4 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -107,7 +107,7 @@ class MMHelper { template void convertWeight(bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, const float *zeros, int splitOffset, int splitSize, bool verticalSplit, xft::Matrix &convertedWeight, - xft::Vector &scaleWeight, xft::Vector &zeroWeight, xft::Vector &sumWeight, + xft::Matrix &scaleWeight, xft::Matrix &zeroWeight, xft::Vector &sumWeight, int groupsize, bool unused) { // transform trans cases to no trans cases if (trans) { @@ -163,8 +163,8 @@ class MMHelper { // FP32 -> INT8/W8A8 else if constexpr (std::is_same_v && (std::is_same_v || std::is_same_v)) { - scaleWeight.Resize(trans ? rowSize : colSize); - zeroWeight.Resize(trans ? rowSize : colSize); + scaleWeight.Resize(1, trans ? rowSize : colSize); + zeroWeight.Resize(1, trans ? rowSize : colSize); const float *src = weight + rowOffset * cols + colOffset; #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 xdnn_sgemm_f32s8f32_quantize(trans, trans ? rowSize : colSize, trans ? colSize : rowSize, src, cols, @@ -183,8 +183,8 @@ class MMHelper { // FP32 -> UINT4 else if constexpr (std::is_same_v && std::is_same_v) { - scaleWeight.Resize(trans ? rowSize : colSize); - zeroWeight.Resize(trans ? rowSize : colSize); + scaleWeight.Resize(1, trans ? rowSize : colSize); + zeroWeight.Resize(1, trans ? rowSize : colSize); const float *src = weight + rowOffset * cols + colOffset; #ifdef AVX512_FP32_WEIGHT_ONLY_INT4 xdnn_sgemm_f32u4f32_quantize(trans, trans ? rowSize : colSize, trans ? colSize : rowSize, src, cols, @@ -202,8 +202,8 @@ class MMHelper { // FP32 -> NF4 else if constexpr (std::is_same_v && std::is_same_v) { - scaleWeight.Resize(trans ? 
rowSize : colSize); - zeroWeight.Resize(trans ? rowSize : colSize); + scaleWeight.Resize(1, trans ? rowSize : colSize); + zeroWeight.Resize(1, trans ? rowSize : colSize); const float *src = weight + rowOffset * cols + colOffset; #ifdef AVX512_FP32_WEIGHT_ONLY_NF4 xdnn_sgemm_f32nf4f32_quantize(trans, trans ? rowSize : colSize, trans ? colSize : rowSize, src, cols, @@ -222,12 +222,15 @@ class MMHelper { // INT8 -> INT8/W8A8 else if constexpr (std::is_same_v && (std::is_same_v || std::is_same_v)) { - int size = trans ? rowSize : colSize; - int offset = trans ? rowOffset : colOffset; - scaleWeight.Resize(size); - zeroWeight.Resize(size); - if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); - if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); + int size = colSize; + int groups = groupsize == -1 ? 1 : rowSize / groupsize; + int offset = colOffset; + scaleWeight.Resize(groups, size); + zeroWeight.Resize(groups, size); + for (int i = 0; i < groups; i++) { + memcpy(scaleWeight.Data() + i * size, scales + i * cols + offset, size * sizeof(float)); + memcpy(zeroWeight.Data() + i * size, zeros + i * cols + offset, size * sizeof(float)); + } #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride(); @@ -238,12 +241,15 @@ class MMHelper { // UINT4 -> UINT4 else if constexpr (std::is_same_v && std::is_same_v) { - int size = trans ? rowSize : colSize; - int offset = trans ? rowOffset : colOffset; - scaleWeight.Resize(size); - zeroWeight.Resize(size); - if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); - if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); + int size = colSize; + int groups = groupsize == -1 ? 1 : rowSize / groupsize; + int offset = colOffset; + scaleWeight.Resize(groups, size); + zeroWeight.Resize(groups, size); + for (int i = 0; i < groups; i++) { + memcpy(scaleWeight.Data() + i * size, scales + i * cols + offset, size * sizeof(float)); + memcpy(zeroWeight.Data() + i * size, zeros + i * cols + offset, size * sizeof(float)); + } #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride() / 2; @@ -304,7 +310,7 @@ class MMHelper { template void convertWeight(bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, const float *zeros, int numSplit, int splitIdx, bool verticalSplit, xft::Matrix &quantizedWeight, - xft::Vector &scaleWeight, xft::Vector &zeroWeight, xft::Vector &sumWeight) { + xft::Matrix &scaleWeight, xft::Matrix &zeroWeight, xft::Vector &sumWeight, int groupsize = -1) { int totalSize = verticalSplit ? 
cols : rows; std::pair range = SplitUtil::getTaskRange(totalSize, numSplit, splitIdx); @@ -312,23 +318,23 @@ class MMHelper { int splitOffset = range.first; convertWeight(trans, rows, cols, weight, scales, zeros, splitOffset, splitSize, verticalSplit, quantizedWeight, - scaleWeight, zeroWeight, sumWeight, true); + scaleWeight, zeroWeight, sumWeight, groupsize, true); } template void convertWeight(bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, const float *zeros, - xft::Matrix &quantizedWeight, xft::Vector &scaleWeight, xft::Vector &zeroWeight, - xft::Vector &sumWeight) { + xft::Matrix &quantizedWeight, xft::Matrix &scaleWeight, xft::Matrix &zeroWeight, + xft::Vector &sumWeight, int groupsize = -1) { convertWeight(trans, rows, cols, weight, scales, zeros, 1, 0, true, quantizedWeight, scaleWeight, zeroWeight, - sumWeight); + sumWeight, groupsize); } template void convertWeight(DecoderContext *ctx, bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, - const float *zeros, bool verticalSplit, xft::Matrix &quantizedWeight, xft::Vector &scaleWeight, - xft::Vector &zeroWeight, xft::Vector &sumWeight) { + const float *zeros, bool verticalSplit, xft::Matrix &quantizedWeight, xft::Matrix &scaleWeight, + xft::Matrix &zeroWeight, xft::Vector &sumWeight) { convertWeight(trans, rows, cols, weight, scales, zeros, ctx->numSplit, ctx->splitIdx, verticalSplit, - quantizedWeight, scaleWeight, zeroWeight, sumWeight); + quantizedWeight, scaleWeight, zeroWeight, sumWeight, ctx->groupsize); } template diff --git a/src/utils/weight_util.h b/src/utils/weight_util.h index 0ac594c1..12177104 100644 --- a/src/utils/weight_util.h +++ b/src/utils/weight_util.h @@ -32,6 +32,47 @@ namespace xft { +inline bool fileExists(const std::string &filename) { + std::ifstream file(filename); + return file.good(); +} + +template +struct Weight { + WeiT* w; + float* s; + float* z; + + Weight() { + w = nullptr; + s = nullptr; + z = nullptr; + } + Weight (const Weight&) = delete; + Weight& operator= (const Weight&) = delete; + ~Weight() { + free(w); + free(s); + free(z); + } +}; + +inline int getGroupSize(const std::string &ini_file, std::string section_name = "") { + INIReader reader = INIReader(ini_file); + if (reader.ParseError() == 0 ) { + return -1; + } else { + if (section_name == "") { + if (!reader.Sections().empty()) { + section_name = *(reader.Sections().begin()); + } else { + return -1; + } + } + return reader.GetInteger(section_name, "quant_groupsize", -1); + } +} + inline DataType getWeightType(const std::string &ini_file, std::string section_name = "") { DataType w_type; INIReader reader = INIReader(ini_file); @@ -108,13 +149,13 @@ int loadWeightWithConvert(T *ptr, int size, const std::string &filename, bool re if constexpr (std::is_same_v == true) { // If T and WT are the same, directly read the file file_size = readFile(filename, ptr, size); - if (required) REQUIRES(file_size == size, "read %s failed!", filename.c_str()); + REQUIRES(file_size == size, "read %s failed!", filename.c_str()); } else { // If T and WT are different types, perform dynamic type conversion WT *w_ptr = nullptr; w_ptr = (WT *)xft::alloc(sizeof(WT) * size); file_size = readFile(filename, w_ptr, size); - if (required) REQUIRES(file_size == size, "read %s failed!", filename.c_str()); + REQUIRES(file_size == size, "read %s failed!", filename.c_str()); if constexpr ((std::is_same_v && std::is_same_v) || (std::is_same_v && std::is_same_v) @@ -154,7 +195,6 @@ int loadWeightWithConvert(T *ptr, int size, 
const std::string &filename, bool re template int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataType::unknown, bool required = true) { - // By default, read the config.ini configuration file // in the same directory as the model file to determine the data type of the file. if (w_type == DataType::unknown) { @@ -163,8 +203,6 @@ int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataTy std::string configFilePath = dirPath + "/config.ini"; w_type = getWeightType(configFilePath); } - //1 uint4x2 stores 2 uint4 value, so load size is halfed. - if constexpr (std::is_same_v) { size = size / 2; } if (!ptr) { ptr = (T *)xft::alloc(size * sizeof(T)); } int file_size = 0; switch (w_type) { @@ -178,6 +216,46 @@ int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataTy return file_size; } +// load weight or qweight/scales/zeros to Weight according to config.ini +template +void loadWeight(std::string filename, Weight &w, int rows, int cols) { + int size = rows * cols; + + if constexpr (std::is_same_v || std::is_same_v) { + // INT8/INT4 quant, wbits = 8/4, qweight dtype: int8_t/uint4x2_t + std::size_t pos = filename.find_last_of("/\\"); + std::string dirPath = filename.substr(0, pos); + std::string configFilePath = dirPath + "/config.ini"; + int groupsize = getGroupSize(configFilePath); + int groups = groupsize == -1 ? 1 : rows / groupsize; + + auto replace = [](std::string s, std::string a, std::string b) { return s.replace(s.find(a), a.length(), b);}; + + DataType t = DataType::int8; + if constexpr (std::is_same_v) { + t = DataType::int4; + + //1 uint4x2 stores 2 uint4 value, so load size is halfed. + size = size / 2; + } + + auto qweightName = replace(filename, "weight", "qweight"); + w.w = (T *)xft::alloc(size * sizeof(T)); + loadWeight(qweightName, w.w, size, t); + + auto scalesName = replace(filename, "weight", "scales"); + w.s = (float *)xft::alloc(cols * sizeof(float)); + loadWeight(scalesName, w.s, groups * cols, DataType::fp32); + + auto zerosName = replace(filename, "weight", "zeros"); + w.z = (float *)xft::alloc(cols * sizeof(float)); + loadWeight(zerosName, w.z, groups * cols, DataType::fp32); + } else { + w.w = (T *)xft::alloc(size * sizeof(T)); + loadWeight(filename, w.w, size); + } +} + template int loadWeightWithConvert(float *, int, const std::string &, bool); template int loadWeightWithConvert(float16_t *, int, const std::string &, bool); template int loadWeightWithConvert(bfloat16_t *, int, const std::string &, bool); diff --git a/src/xfastertransformer/tools/qwen2_convert.py b/src/xfastertransformer/tools/qwen2_convert.py index 63ca2a87..c59dbdcf 100644 --- a/src/xfastertransformer/tools/qwen2_convert.py +++ b/src/xfastertransformer/tools/qwen2_convert.py @@ -311,7 +311,6 @@ def split_and_convert_quantized_model(self, input_dir, output_dir, dtype, proces config[sec_name]["quant_qweight_data_type"] = "int8" if self.wbits == 8 else "uint4" config[sec_name]["quant_scales_data_type"] = "fp32" config[sec_name]["quant_zeros_data_type"] = "fp32" - assert quantize_config["group_size"] == -1, "Only column wise quantization is supported." 
config[sec_name]["quant_groupsize"] = str(quantize_config["group_size"]) # config[sec-name]["quant_scheme"] = "sym" if quantize_config["sym"] == True else "asym" @@ -409,7 +408,7 @@ def split_and_convert_quantized_model(self, input_dir, output_dir, dtype, proces # for uint4, zeros = - scales * qzeros if self.wbits == 8: qzeros = qzeros - 128 # uint8 to int8 - qzeros = torch.flatten(qzeros).float() + qzeros = qzeros.reshape(qzeros.shape[0], -1).float() scales = state_dict["model." + name.replace("qzeros", "scales")].float() zeros = -scales * qzeros model_named_parameters[name] = zeros From d8ddd8429a916cb2bdf93b45aac2af0d301973f9 Mon Sep 17 00:00:00 2001 From: Zhang Xiangze Date: Tue, 25 Jun 2024 16:03:17 +0530 Subject: [PATCH 2/3] Bug fix --- src/utils/weight_util.h | 6 +++--- src/xfastertransformer/tools/llama_convert.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/utils/weight_util.h b/src/utils/weight_util.h index 12177104..5a01df86 100644 --- a/src/utils/weight_util.h +++ b/src/utils/weight_util.h @@ -59,7 +59,7 @@ struct Weight { inline int getGroupSize(const std::string &ini_file, std::string section_name = "") { INIReader reader = INIReader(ini_file); - if (reader.ParseError() == 0 ) { + if (reader.ParseError() != 0 ) { return -1; } else { if (section_name == "") { @@ -244,11 +244,11 @@ void loadWeight(std::string filename, Weight &w, int rows, int cols) { loadWeight(qweightName, w.w, size, t); auto scalesName = replace(filename, "weight", "scales"); - w.s = (float *)xft::alloc(cols * sizeof(float)); + w.s = (float *)xft::alloc(groups * cols * sizeof(float)); loadWeight(scalesName, w.s, groups * cols, DataType::fp32); auto zerosName = replace(filename, "weight", "zeros"); - w.z = (float *)xft::alloc(cols * sizeof(float)); + w.z = (float *)xft::alloc(groups * cols * sizeof(float)); loadWeight(zerosName, w.z, groups * cols, DataType::fp32); } else { w.w = (T *)xft::alloc(size * sizeof(T)); diff --git a/src/xfastertransformer/tools/llama_convert.py b/src/xfastertransformer/tools/llama_convert.py index e62b28c9..30a1c5b3 100644 --- a/src/xfastertransformer/tools/llama_convert.py +++ b/src/xfastertransformer/tools/llama_convert.py @@ -286,7 +286,6 @@ def split_and_convert_quantized_model(self, input_dir, output_dir, dtype, proces config["llama"]["quant_qweight_data_type"] = "int8" if self.wbits == 8 else "uint4" config["llama"]["quant_scales_data_type"] = "fp32" config["llama"]["quant_zeros_data_type"] = "fp32" - assert quantize_config["group_size"] == -1, "Only column wise quantization is supported." config["llama"]["quant_groupsize"] = str(quantize_config["group_size"]) # config["llama"]["quant_scheme"] = "sym" if quantize_config["sym"] == True else "asym" @@ -365,7 +364,7 @@ def split_and_convert_quantized_model(self, input_dir, output_dir, dtype, proces # for uint4, zeros = - scales * qzeros if self.wbits == 8: qzeros = qzeros - 128 # uint8 to int8 - qzeros = torch.flatten(qzeros).float() + qzeros = qzeros.reshape(qzeros.shape[0], -1).float() scales = state_dict["model." 
+ name.replace("qzeros", "scales")].float() zeros = -scales * qzeros model_named_parameters[name] = zeros From 8b56ed9674e552a61105a0cb25dae9d8b59a9b06 Mon Sep 17 00:00:00 2001 From: Zhang Xiangze Date: Wed, 26 Jun 2024 20:04:08 +0530 Subject: [PATCH 3/3] Change xdnn api --- src/layers/attention.h | 24 +++++------ src/layers/dist_linear.h | 4 +- src/layers/mlp_llama.h | 14 +++---- src/layers/mlp_standard.h | 12 +++--- src/utils/matmul_helper.h | 87 +++++++++++++++++++++------------------ 5 files changed, 74 insertions(+), 67 deletions(-) diff --git a/src/layers/attention.h b/src/layers/attention.h index ad3544cf..63cdc458 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -293,11 +293,11 @@ class Attention { if (qkvBias.Size() == 0) { ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(), - qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride()); + qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->groupsize); } else { ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(), - qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data()); + qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data(), ctx->groupsize); } t2.release(); @@ -409,26 +409,26 @@ class Attention { ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, - outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride()); + outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize); } else { float *pbias = attnOutputBias.Data(); if (attnOutputBias.Size() == 0) { pbias = nullptr; } ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), - outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride()); + outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize); } } else { if (attnOutputBias.Size() == 0) { ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), - outBuffer.Stride()); + outBuffer.Stride(), ctx->groupsize); } else { ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), - outBuffer.Stride(), attnOutputBias.Data()); + outBuffer.Stride(), attnOutputBias.Data(), ctx->groupsize); } } t5.release(); @@ -499,11 +499,11 @@ class Attention { if (qkvBias.Size() == 0) { ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, 
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(), - qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride()); + qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->groupsize); } else { ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(), - qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data()); + qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data(), ctx->groupsize); } t2.release(); @@ -592,26 +592,26 @@ class Attention { ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, - outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride()); + outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize); } else { float *pbias = attnOutputBias.Data(); if (attnOutputBias.Size() == 0) { pbias = nullptr; } ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), - outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride()); + outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize); } } else { if (attnOutputBias.Size() == 0) { ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), - outBuffer.Stride()); + outBuffer.Stride(), ctx->groupsize); } else { ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), - outBuffer.Stride(), attnOutputBias.Data()); + outBuffer.Stride(), attnOutputBias.Data(), ctx->groupsize); } } t5.release(); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index 7dc4ed4f..79625a2b 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -92,11 +92,11 @@ class DistLinear { TimeLine t("DistLinear.forward"); if (bias) { ctx->mmHelper->compute_bias(false, batchSize, splitSize, inputSize, 1.0f, input, inputSize, weight.Data(), - scaleWeight.Data(), zeroWeight.Data(), sumWeight.Data(), 0.0f, output, splitSize, bias); + scaleWeight.Data(), zeroWeight.Data(), sumWeight.Data(), 0.0f, output, splitSize, bias, ctx->groupsize); } else { ctx->mmHelper->compute(false, batchSize, splitSize, inputSize, 1.0f, input, inputSize, weight.Data(), - scaleWeight.Data(), zeroWeight.Data(), sumWeight.Data(), 0.0f, output, splitSize); + scaleWeight.Data(), zeroWeight.Data(), sumWeight.Data(), 0.0f, output, splitSize, ctx->groupsize); } } diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index ee2caf80..c280ab7a 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -234,11 +234,11 
@@ class LlamaMLP { ImT *C = output.Data(); if (ctx->actType == DecoderContext::SILU) { - ctx->mmHelper->compute_silu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); + ctx->mmHelper->compute_silu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, ctx->groupsize); } else if (ctx->actType == DecoderContext::SWIGLU) { // chatglm2/3 - ctx->mmHelper->compute_silu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); + ctx->mmHelper->compute_silu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, ctx->groupsize); } else if (ctx->actType == DecoderContext::GELU) { // gemma - ctx->mmHelper->compute_gelu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); + ctx->mmHelper->compute_gelu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, ctx->groupsize); } else { printf("ERROR: unsupported activation in MLP.\n"); exit(-1); @@ -262,7 +262,7 @@ class LlamaMLP { const float *sumB = upWeightSum.Data(); ImT *C = output.Data(); - ctx->mmHelper->compute_resmul(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, C, ldc); + ctx->mmHelper->compute_resmul(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, C, ldc, ctx->groupsize); } void downProj(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output, @@ -286,9 +286,9 @@ class LlamaMLP { if (isMaster) { ctx->mmHelper->compute_residential( - false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, NULL, R, ldr); + false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, NULL, R, ldr, ctx->groupsize); } else { - ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); + ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, ctx->groupsize); } } @@ -310,7 +310,7 @@ class LlamaMLP { const float *sumB = catWeightsSum.Data(); T2 *C = output.Data(); - ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); + ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, ctx->groupsize); // Compute silu on the left half and then add it with the right half if (ctx->actType == DecoderContext::SILU) { diff --git a/src/layers/mlp_standard.h b/src/layers/mlp_standard.h index e44c2aea..b214d755 100644 --- a/src/layers/mlp_standard.h +++ b/src/layers/mlp_standard.h @@ -126,26 +126,26 @@ class MLP { ctx->mmHelper->compute_residential(false, imBuffer.Rows(), outputWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), imBuffer.Stride(), outputWeight.Data(), outputWeightScale.Data(), outputWeightZero.Data(), outputWeightSum.Data(), 0.0f, resultBuffer1.Data(), - resultBuffer1.Stride(), pbias, resultBuffer2.Data(), resultBuffer2.Stride()); + resultBuffer1.Stride(), pbias, resultBuffer2.Data(), resultBuffer2.Stride(), ctx->groupsize); } else { float *pbias = outputBias.Data(); if (outputBias.Size() == 0) { pbias = nullptr; } ctx->mmHelper->compute_resext(false, imBuffer.Rows(), outputWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), imBuffer.Stride(), outputWeight.Data(), outputWeightScale.Data(), outputWeightZero.Data(), outputWeightSum.Data(), 0.0f, resultBuffer1.Data(), - resultBuffer1.Stride(), pbias, gamma, resultBuffer2.Data(), resultBuffer2.Stride()); + resultBuffer1.Stride(), pbias, gamma, resultBuffer2.Data(), resultBuffer2.Stride(), ctx->groupsize); } } else { if (outputBias.Size() == 0) { ctx->mmHelper->compute(false, imBuffer.Rows(), outputWeight.Cols(), imBuffer.Cols(), 
1.0f, imBuffer.Data(), imBuffer.Stride(), outputWeight.Data(), outputWeightScale.Data(), outputWeightZero.Data(), outputWeightSum.Data(), 0.0f, resultBuffer1.Data(), - resultBuffer1.Stride()); + resultBuffer1.Stride(), ctx->groupsize); } else { ctx->mmHelper->compute_bias(false, imBuffer.Rows(), outputWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), imBuffer.Stride(), outputWeight.Data(), outputWeightScale.Data(), outputWeightZero.Data(), outputWeightSum.Data(), 0.0f, resultBuffer1.Data(), - resultBuffer1.Stride(), outputBias.Data()); + resultBuffer1.Stride(), outputBias.Data(), ctx->groupsize); } } @@ -168,13 +168,13 @@ class MLP { ctx->mmHelper->compute_biasadd_relu(false, input.Rows(), output.Cols(), input.Cols(), 1.0f, input.Data(), input.Stride(), intermediateWeight.Data(), intermediateWeightScale.Data(), intermediateWeightZero.Data(), intermediateWeightSum.Data(), 0.0f, output.Data(), output.Stride(), - intermediateBias.Data()); + intermediateBias.Data(), ctx->groupsize); } void intermediate_gelu(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output) { ctx->mmHelper->compute(false, input.Rows(), output.Cols(), input.Cols(), 1.0f, input.Data(), input.Stride(), intermediateWeight.Data(), intermediateWeightScale.Data(), intermediateWeightZero.Data(), - intermediateWeightSum.Data(), 0.0f, output.Data(), output.Stride()); + intermediateWeightSum.Data(), 0.0f, output.Data(), output.Stride(), ctx->groupsize); float *pbias = intermediateBias.Data(); float factor = 0.7978845608; // np.sqrt(2 / np.pi) diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 6d1010c4..12bfad6d 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -262,11 +262,12 @@ class MMHelper { else if constexpr (std::is_same_v && std::is_same_v) { #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { + int group_offset = groupsize == -1 ? 0 : i / groupsize; for (uint64_t j = 0; j < colSize; j++) { const int8_t src = weight[(rowOffset + i) * cols + colOffset + j]; bfloat16_t *dst = convertedWeight.Data() + i * convertedWeight.Stride() + j; - float scale = scales[colOffset + j]; - float zero = zeros[colOffset + j]; + float scale = scales[group_offset * colSize + colOffset + j]; + float zero = zeros[group_offset * colSize + colOffset + j]; *dst = static_cast(scale * src + zero); } } @@ -276,13 +277,14 @@ class MMHelper { else if constexpr (std::is_same_v && std::is_same_v) { #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { + int group_offset = groupsize == -1 ? 
0 : i / groupsize;
             for (uint64_t j = 0; j < colSize; j+=2) {
                 const uint4x2_t *src = weight + (rowOffset + i) * cols / 2 + colOffset / 2 + j / 2;
                 bfloat16_t *dst = convertedWeight.Data() + i * convertedWeight.Stride() + j;
-                float scale1 = scales[colOffset + j];
-                float scale2 = scales[colOffset + j + 1];
-                float zero1 = zeros[colOffset + j];
-                float zero2 = zeros[colOffset + j + 1];
+                float scale1 = scales[group_offset * colSize + colOffset + j];
+                float scale2 = scales[group_offset * colSize + colOffset + j + 1];
+                float zero1 = zeros[group_offset * colSize + colOffset + j];
+                float zero2 = zeros[group_offset * colSize + colOffset + j + 1];
                 dst[0] = static_cast<bfloat16_t>(scale1 * src->get_v1() + zero1);
                 dst[1] = static_cast<bfloat16_t>(scale2 * src->get_v2() + zero2);
             }
@@ -475,7 +477,7 @@ class MMHelper {
 
     template <typename InT, typename WeiT, typename OutT>
     void compute(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
-            const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc) {
+            const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE(
@@ -549,7 +551,7 @@ class MMHelper {
                     xdnn_sgemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc));
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute",
-                    xdnn_hgemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc));
+                    xdnn_hgemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -560,7 +562,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, nullptr, nullptr, 0, 0.0f, matmul_kinds::Basic));
+                            zeroB, sumB, beta, C, ldc, nullptr, nullptr, 0, 0.0f, matmul_kinds::Basic, groupsize));
         }
 
         // INT4
@@ -572,7 +574,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute",
                     xdnn_hgemm_f32u4f32_compute(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, scaleB,
-                            zeroB, beta, C, ldc));
+                            zeroB, beta, C, ldc, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -599,7 +601,7 @@ class MMHelper {
     template <typename InT, typename WeiT, typename OutT>
     void compute_bias(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
             const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc,
-            const float *bias) {
+            const float *bias, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_biasadd",
@@ -678,7 +680,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_biasadd",
                     xdnn_hgemm_f32s8f32_compute_biasadd(
-                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias));
+                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -689,7 +691,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_biasadd",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, bias, nullptr, 0, 0.0f, matmul_kinds::BiasAdd));
+                            zeroB, sumB, beta, C, ldc, bias, nullptr, 0, 0.0f, matmul_kinds::BiasAdd, groupsize));
         }
 
         // INT4
@@ -701,7 +703,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_biasadd",
                     xdnn_hgemm_f32u4f32_compute_biasadd(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB,
-                            scaleB, zeroB, beta, C, ldc, bias));
+                            scaleB, zeroB, beta, C, ldc, bias, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -728,7 +730,7 @@ class MMHelper {
     template <typename InT, typename WeiT, typename OutT>
     void compute_biasadd_relu(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
             const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc,
-            const float *bias) {
+            const float *bias, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_biasadd_relu",
@@ -800,7 +802,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_biasadd_relu",
                     xdnn_hgemm_f32s8f32_compute_biasadd_relu(
-                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias));
+                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -811,7 +813,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_biasadd_relu",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, bias, nullptr, 0, 0.0f, matmul_kinds::BiasAdd_Relu));
+                            zeroB, sumB, beta, C, ldc, bias, nullptr, 0, 0.0f, matmul_kinds::BiasAdd_Relu, groupsize));
         }
 
         // INT4
@@ -823,7 +825,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_biasadd_relu",
                     xdnn_hgemm_f32u4f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda,
-                            (const XDNN_UINT4x2 *)packedB, scaleB, zeroB, beta, C, ldc, bias));
+                            (const XDNN_UINT4x2 *)packedB, scaleB, zeroB, beta, C, ldc, bias, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -849,7 +851,7 @@ class MMHelper {
 
     template <typename InT, typename WeiT, typename OutT>
     void compute_silu(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
-            const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc) {
+            const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_silu",
@@ -927,7 +929,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_silu",
                     xdnn_hgemm_f32s8f32_compute_silu(
-                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc));
+                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -938,7 +940,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_silu",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, nullptr, nullptr, 0, 0.0f, matmul_kinds::Silu));
+                            zeroB, sumB, beta, C, ldc, nullptr, nullptr, 0, 0.0f, matmul_kinds::Silu, groupsize));
         }
 
         // INT4
@@ -946,11 +948,11 @@ class MMHelper {
 #ifdef AVX512_FP32_WEIGHT_ONLY_INT4
             GEMMVERBOSE("xdnn_sgemm_f32u4f32_compute_silu",
                     xdnn_sgemm_f32u4f32_compute_silu(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB,
-                            scaleB, zeroB, beta, C, ldc));
+                            scaleB, zeroB, beta, C, ldc, groupsize));
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_silu",
                     xdnn_hgemm_f32u4f32_compute_silu(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB,
-                            scaleB, zeroB, beta, C, ldc));
+                            scaleB, zeroB, beta, C, ldc, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -976,7 +978,7 @@ class MMHelper {
 
     template <typename InT, typename WeiT, typename OutT>
     void compute_gelu(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
-            const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc) {
+            const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_gelu",
@@ -1055,7 +1057,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_gelu",
                     xdnn_hgemm_f32s8f32_compute_gelu(
-                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc));
+                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1066,7 +1068,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_gelu",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, nullptr, nullptr, 0, 0.0f, matmul_kinds::Gelu));
+                            zeroB, sumB, beta, C, ldc, nullptr, nullptr, 0, 0.0f, matmul_kinds::Gelu, groupsize));
         }
 
         // INT4
@@ -1078,7 +1080,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_gelu",
                     xdnn_hgemm_f32u4f32_compute_gelu(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB,
-                            scaleB, zeroB, beta, C, ldc));
+                            scaleB, zeroB, beta, C, ldc, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1105,7 +1107,7 @@ class MMHelper {
     template <typename InT, typename WeiT, typename OutT>
     void compute_resmul(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
             const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc, const InT *res,
-            int ldres) {
+            int ldres, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_resmul",
@@ -1184,7 +1186,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_resmul",
                     xdnn_hgemm_f32s8f32_compute_resmul(
-                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, res, ldres));
+                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, res, ldres, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1195,7 +1197,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_resmul",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, nullptr, res, ldres, 0.0f, matmul_kinds::Resmul));
+                            zeroB, sumB, beta, C, ldc, nullptr, res, ldres, 0.0f, matmul_kinds::Resmul, groupsize));
         }
 
         // INT4
@@ -1207,7 +1209,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_resmul",
                     xdnn_hgemm_f32u4f32_compute_resmul(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB,
-                            scaleB, zeroB, beta, C, ldc, res, ldres));
+                            scaleB, zeroB, beta, C, ldc, res, ldres, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1234,7 +1236,7 @@ class MMHelper {
     template <typename InT, typename WeiT, typename OutT>
     void compute_residential(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
             const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc, const float *bias,
-            const InT *res, int ldres) {
+            const InT *res, int ldres, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_residential",
@@ -1315,7 +1317,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_residential",
                     xdnn_hgemm_f32s8f32_compute_residential(
-                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, res, ldres));
+                            transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, res, ldres, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1326,7 +1328,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_residential",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, bias, res, ldres, 0.0f, matmul_kinds::Residential));
+                            zeroB, sumB, beta, C, ldc, bias, res, ldres, 0.0f, matmul_kinds::Residential, groupsize));
         }
 
         // INT4
@@ -1338,7 +1340,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_residential",
                     xdnn_hgemm_f32u4f32_compute_residential(transA, M, N, K, alpha, A, lda,
-                            (const XDNN_UINT4x2 *)packedB, scaleB, zeroB, beta, C, ldc, bias, res, ldres));
+                            (const XDNN_UINT4x2 *)packedB, scaleB, zeroB, beta, C, ldc, bias, res, ldres, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1365,7 +1367,7 @@ class MMHelper {
     template <typename InT, typename WeiT, typename OutT>
     void compute_resext(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB,
             const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc, const float *bias,
-            float gamma, InT *res, int ldres) {
+            float gamma, InT *res, int ldres, int groupsize) {
         // FP32
         if constexpr (std::is_same_v<WeiT, float>) {
             GEMMVERBOSE("xdnn_sgemm_compute_resext",
@@ -1473,7 +1475,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8)
             GEMMVERBOSE("xdnn_hgemm_f32s8f32_compute_resext",
                     xdnn_hgemm_f32s8f32_compute_resext(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C,
-                            ldc, bias, gamma, res, ldres));
+                            ldc, bias, gamma, res, ldres, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -1484,7 +1486,7 @@ class MMHelper {
         else if constexpr (std::is_same_v<WeiT, w8a8_t>) {
             GEMMVERBOSE("onednn_amx_gemm_f32s8f32_compute_resext",
                     onednn_amx_gemm_f32s8f32_compute(transA, M, N, K, alpha, A, lda, (const int8_t *)packedB, scaleB,
-                            zeroB, sumB, beta, C, ldc, bias, res, ldres, gamma, matmul_kinds::Resext));
+                            zeroB, sumB, beta, C, ldc, bias, res, ldres, gamma, matmul_kinds::Resext, groupsize));
         }
 
         // INT4
@@ -1496,7 +1498,7 @@ class MMHelper {
 #elif defined(AVX512_FP16_WEIGHT_ONLY_INT4)
             GEMMVERBOSE("xdnn_hgemm_f32u4f32_compute_resext",
                     xdnn_hgemm_f32u4f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB,
-                            scaleB, zeroB, beta, C, ldc, bias, gamma, res, ldres));
+                            scaleB, zeroB, beta, C, ldc, bias, gamma, res, ldres, groupsize));
 #else
             printf("%s:%d: Need to define WEIGHT_ONLY_INT4 kernel data type.\n", __FILE__, __LINE__);
             exit(-1);
@@ -2382,12 +2384,17 @@ class MMHelper {
 
     void onednn_amx_gemm_f32s8f32_compute(bool transA, int M, int N, int K, float alpha, const float *A, int lda,
             const int8_t *B, const float *scaleB, const float *zeroB, const float *sumB, float beta, float *C, int ldc,
-            const float *bias, const float *res, int ldres, float gamma, matmul_kinds kind) {
+            const float *bias, const float *res, int ldres, float gamma, matmul_kinds kind, int groupsize) {
         if (transA || (N % 16) != 0 || alpha != 1.0f || beta != 0.0f) {
             printf("%s:%d: Not implemented.\n", __FILE__, __LINE__);
            exit(-1);
         }
 
+        if (groupsize != -1) {
+            printf("%s:%d: W8A8 with groupsize not implemented.\n", __FILE__, __LINE__);
+            exit(-1);
+        }
+
         // split M dimension if M*N is too big
         const int max_MN = 4 * 1024 * 1024;
         int numSplit = M * N / max_MN + 1;