diff --git a/src/common/kvcache_mgr.h b/src/common/kvcache_mgr.h index c2fa271b..ace2919d 100644 --- a/src/common/kvcache_mgr.h +++ b/src/common/kvcache_mgr.h @@ -13,10 +13,11 @@ // limitations under the License. // ============================================================================ #pragma once - +#include #include + +#include "environment.h" #include "kvcache_tensor.h" -#include namespace xft { @@ -41,6 +42,8 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { this->headNum_ = headNum; this->headSize_ = headSize; this->layers_ = layers; + // The KV Cache location configured in "KV_CACHE_LOCATION" + this->allocNode = Env::getInstance().getKVCacheLocation(); } ~KVCacheMgrImpl() { @@ -89,7 +92,7 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { // User specified maxSeqLen needs to be <= model's configured maxSeqLen auto maxLen = maxSeqLen > 0 ? std::min(maxSeqLen, maxSeqLen_) : maxSeqLen_; for (int i = 0; i < 2 * layers_; ++i) { - cache[i].resize(maxLen, 1, headNum_, headSize_); + cache[i].resize(maxLen, 1, headNum_, headSize_, this->allocNode); } sequenceCaches.insert({seqID, cache}); @@ -186,6 +189,7 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { int headNum_; int headSize_; int layers_; + int allocNode; }; class KVCacheMgr { diff --git a/src/common/kvcache_tensor.h b/src/common/kvcache_tensor.h index 438bd1f9..a37d7ade 100644 --- a/src/common/kvcache_tensor.h +++ b/src/common/kvcache_tensor.h @@ -24,6 +24,7 @@ #include "allocator.h" #include "bfloat16.h" #include "float16.h" +#include "numa_allocator.h" extern bool kvTrans(); @@ -67,23 +68,23 @@ template class KVCacheTensor { public: KVCacheTensor() - : maxSeqLen(0), batchSize(0), headNum(0), headSize(0), data(nullptr), allocSize(0), scales(nullptr) {} + : maxSeqLen(0), batchSize(0), headNum(0), headSize(0), data(nullptr), allocSize(0), scales(nullptr), scalesAllocSize(0) {} ~KVCacheTensor() { - if (this->data) { free(this->data); } - if (this->scales) { free(this->scales); } + if (this->data) { xft_numa_free(this->data, allocSize); } + if (this->scales) { xft_numa_free(this->scales, scalesAllocSize); } } - void resize(int maxSeqLen, int batchSize, int headNum, int headSize) { + void resize(int maxSeqLen, int batchSize, int headNum, int headSize, int allocNode) { this->maxSeqLen = maxSeqLen; this->batchSize = batchSize; this->headNum = headNum; this->headSize = headSize; - uint64_t requiredSize = (uint64_t)maxSeqLen * batchSize * headNum * headSize; + uint64_t requiredSize = (uint64_t)maxSeqLen * batchSize * headNum * headSize * sizeof(T); if (requiredSize > allocSize) { - if (this->data) { free(this->data); } - this->data = (T *)xft::alloc(requiredSize * sizeof(T)); + if (this->data) { xft_numa_free(this->data, allocSize); } + this->data = (T *)xft_numa_alloc_onnode(requiredSize, allocNode); if (!this->data) { printf("Failed to alloc mem for KV Cache [%d][%d][%d][%d].\n", maxSeqLen, batchSize, headNum, headSize); exit(-1); @@ -91,8 +92,16 @@ class KVCacheTensor { allocSize = requiredSize; } - if (this->scales) { free(this->scales); } - this->scales = (float *)xft::alloc((uint64_t)maxSeqLen * batchSize * headNum * sizeof(float)); + requiredSize = (uint64_t)maxSeqLen * batchSize * headNum * sizeof(float); + if (requiredSize > scalesAllocSize) { + if (this->scales) { xft_numa_free(this->scales, scalesAllocSize); } + this->scales = (float *)xft_numa_alloc_onnode(requiredSize, allocNode); + if (!this->scales) { + printf("Failed to alloc mem for KV Cache scales [%d][%d][%d][%d].\n", maxSeqLen, batchSize, headNum, headSize); + exit(-1); + } + scalesAllocSize = requiredSize; + } } int getBatchSize() const { return batchSize; } @@ -188,15 +197,15 @@ class KVCacheTensor { * initSeqLen: initial sequence length, which is the prompt token size * accSeqLen: accumulated sequence length */ - void reorder(int *idx, int size, int initSeqLen, int accSeqLen) { + void reorder(int *idx, int size, int initSeqLen, int accSeqLen, int allocNode) { const int cols = this->getHeadNum() * this->getHeadSize(); const int batchSize = this->getBatchSize(); T *pdata = this->data + initSeqLen * batchSize * cols; // Temporary buffer used for reorder - T *extraKeyBuf = (T *)xft::alloc((batchSize - 1) * cols * sizeof(T)); - + uint64_t requiredSize = (uint64_t)(batchSize - 1) * cols * sizeof(T); + T *extraKeyBuf = (T *)xft_numa_alloc_onnode(requiredSize, allocNode); for (int seq = initSeqLen; seq < accSeqLen; ++seq) { // Reorder is not needed for the first few lines int extraBufIdx = 0; int remapped[batchSize]; @@ -260,7 +269,7 @@ class KVCacheTensor { pdata += batchSize * cols; } - free(extraKeyBuf); + xft_numa_free(extraKeyBuf, requiredSize); } private: @@ -327,4 +336,5 @@ class KVCacheTensor { // The scale factor for each head (if T is int8) float *scales; + uint64_t scalesAllocSize; }; diff --git a/src/models/kvcache_manager.cpp b/src/models/kvcache_manager.cpp index 13ccec92..c1ad1ced 100644 --- a/src/models/kvcache_manager.cpp +++ b/src/models/kvcache_manager.cpp @@ -12,28 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================ -#include "kvcache_manager.h" #include #include #include #include #include "allocator.h" #include "bfloat16.h" +#include "environment.h" #include "float16.h" +#include "kvcache_manager.h" template void KVCacheManager::resize(int maxSeqLen, int batchSize, int headsPerSplit, int headSize, bool prefix) { + // The KV Cache location configured in "KV_CACHE_LOCATION" + this->allocNode = Env::getInstance().getKVCacheLocation(); if (prefix && this->cachedPrefixKeys == nullptr) { this->cachedPrefixKeys = new KVCacheTensor[layers]; this->cachedPrefixValues = new KVCacheTensor[layers]; } for (int i = 0; i < this->layers; ++i) { if (prefix) { - this->cachedPrefixKeys[i].resize(maxSeqLen, 1, headsPerSplit, headSize); - this->cachedPrefixValues[i].resize(maxSeqLen, 1, headsPerSplit, headSize); + this->cachedPrefixKeys[i].resize(maxSeqLen, 1, headsPerSplit, headSize, this->allocNode); + this->cachedPrefixValues[i].resize(maxSeqLen, 1, headsPerSplit, headSize, this->allocNode); } else { - this->cachedKeys[i].resize(maxSeqLen, batchSize, headsPerSplit, headSize); - this->cachedValues[i].resize(maxSeqLen, batchSize, headsPerSplit, headSize); + this->cachedKeys[i].resize(maxSeqLen, batchSize, headsPerSplit, headSize, this->allocNode); + this->cachedValues[i].resize(maxSeqLen, batchSize, headsPerSplit, headSize, this->allocNode); } } } @@ -100,10 +103,10 @@ void KVCacheManager::reorderCache(int *idx, int size, int initSeqLen, int layer = i / 2; if (i % 2 == 0) { KVCacheTensor &keyTensor = this->getKey(layer); - keyTensor.reorder(idx, size, initSeqLen, accSeqLen); + keyTensor.reorder(idx, size, initSeqLen, accSeqLen, this->allocNode); } else { KVCacheTensor &valueTensor = this->getValue(layer); - valueTensor.reorder(idx, size, initSeqLen, accSeqLen); + valueTensor.reorder(idx, size, initSeqLen, accSeqLen, this->allocNode); } } } diff --git a/src/models/kvcache_manager.h b/src/models/kvcache_manager.h index 6f593029..430e88a5 100644 --- a/src/models/kvcache_manager.h +++ b/src/models/kvcache_manager.h @@ -69,6 +69,7 @@ class KVCacheManager { void reorderCache(int *idx, int size, int initSeqLen, int accSeqLen); private: + int allocNode; int layers; // how many layers KVCacheTensor *cachedKeys; // all accumulated keys KVCacheTensor *cachedValues; // all accumulated values diff --git a/src/utils/environment.h b/src/utils/environment.h index ddcb2df4..483d791a 100644 --- a/src/utils/environment.h +++ b/src/utils/environment.h @@ -71,6 +71,9 @@ class Env { // get Primitive Cache M int getPrimitiveCacheM() { return primitiveCacheM; } + // get KV Cache Location + int getKVCacheLocation() { return kvCacheLocation; } + private: Env() { // init Verbose @@ -111,6 +114,9 @@ class Env { // init Primitive Cache M initPrimitiveCacheM(); + + // init KV Cache Location + initKVCacheLocation(); } // Verbose @@ -281,4 +287,16 @@ class Env { primitiveCacheM = 256; } } + + // KV_CACHE_LOCATION + int kvCacheLocation = -1; + void initKVCacheLocation() { + // The KV Cache location configured in "KV_CACHE_LOCATION" + char *xft_kvcache_location_value = getenv("KV_CACHE_LOCATION"); + if (xft_kvcache_location_value != NULL) { + int value = atoi(xft_kvcache_location_value); + if (value >= 0) + kvCacheLocation = value; + } + } }; \ No newline at end of file