diff --git a/tensorflow/core/framework/embedding/cache.h b/tensorflow/core/framework/embedding/cache.h
index 45960bcf8ea..61ba21a5f86 100644
--- a/tensorflow/core/framework/embedding/cache.h
+++ b/tensorflow/core/framework/embedding/cache.h
@@ -5,6 +5,7 @@
 #include <list>
 #include <map>
 #include <unordered_map>
+#include <deque>
 #include <vector>
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/types.h"
@@ -128,10 +129,12 @@ class LRUCache : public BatchCache<K> {
   mutex mu_;
 };
 
+/* Uses real frequencies, from 1 up to the maximum value of size_t.
+   Independent implementation. */
 template <class K>
-class LFUCache : public BatchCache<K> {
+class OriginLFUCache : public BatchCache<K> {
  public:
-  LFUCache() {
+  OriginLFUCache() {
     min_freq = std::numeric_limits<size_t>::max();
     max_freq = 0;
     freq_table.emplace_back(std::pair<std::list<K>*, int64>(
@@ -272,6 +275,359 @@
   mutex mu_;
 };
 
+
+template <class K>
+class LFUNode {
+ public:
+  // `now` is unused by the plain LFU node; it is kept so that LFUNode and
+  // AgingNode share one construction interface.
+  LFUNode(K key, unsigned now) : key(key), freq(0) {}
+  ~LFUNode() {}
+  K GetKey() { return key; }
+  size_t GetIndex() { return freq; }
+  size_t UpdateAndReturnIndex(unsigned now, bool lru_mode) { return ++freq; }
+
+ private:
+  K key;
+  size_t freq;
+};
+
+template <class K>
+class AgingNode : public LFUNode<K> {
+ public:
+  AgingNode(K key, unsigned now)
+      : LFUNode<K>(key, now), count(INIT_CNT), index(INIT_CNT), last(now) {}
+  AgingNode(typename std::list<AgingNode<K>>::iterator that)
+      : LFUNode<K>(that->GetKey(), that->last),
+        count(that->count),
+        index(that->index),
+        last(that->last) {}
+  ~AgingNode() {}
+  size_t GetIndex() { return index; }
+  size_t UpdateAndReturnIndex(unsigned now, bool lru_mode) {
+    Decrease(now);
+    IncrByProb();
+    // Refresh the timestamp so the next Decrease() only covers the
+    // period since this access.
+    last = now;
+    index = lru_mode ? MAX_CNT : count;
+    return (size_t)index;
+  }
+  static const uint8_t INIT_CNT = 5;
+
+ private:
+  // Probabilistic decay; currently unused (Decrease() uses DecrByPeriod).
+  void DecrByProb(unsigned period) {
+    size_t ret1 = rand();
+    size_t ret2 = RAND_MAX;
+    ret1 *= UNIT_STEP * DECR_FACTOR;
+    ret2 *= (period - IGNORE_STEP);
+    if (count > 0 && ret1 < ret2) count--;
+  }
+  void DecrByPeriod(unsigned period) {
+    unsigned decrease_value = period / UNIT_STEP;
+    if (decrease_value + INIT_CNT > count)
+      count = INIT_CNT;
+    else
+      count -= decrease_value;
+  }
+  void Decrease(unsigned now) {
+    unsigned period = now >= last
+        ? now - last
+        : std::numeric_limits<unsigned>::max() - last + now;
+    DecrByPeriod(period);
+  }
+  // Increment with probability ~1 / ((count - INIT_CNT) * INCR_FACTOR),
+  // so climbing gets harder the hotter a key already is.
+  void IncrByProb() {
+    size_t ret = rand();
+    ret *= (count - INIT_CNT) * INCR_FACTOR;
+    if (count < MAX_CNT && ret < RAND_MAX) count++;
+  }
+
+ private:
+  static const uint8_t MIN_CNT = 1;
+  static const uint8_t MAX_CNT = 255;
+  static const unsigned INCR_FACTOR = 7;
+  static const unsigned DECR_FACTOR = 10;
+  // 32640 = 1 + 2 + ... + 255
+  static const unsigned UNIT_STEP = 32640 * INCR_FACTOR;
+  static const unsigned IGNORE_STEP = 0;
+  uint8_t count;
+  uint8_t index;
+  unsigned last;
+};
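+
+// Illustrative walk-through of the aging arithmetic above (assumed
+// example values, not part of the algorithm itself): with
+// UNIT_STEP = 32640 * 7 = 228480, a node whose count is 47 and whose
+// last update was 500000 steps ago loses 500000 / 228480 = 2 counts in
+// DecrByPeriod (47 -> 45). IncrByProb then increments with probability
+// about 1 / ((45 - 5) * 7) = 1/280, so idle keys decay toward INIT_CNT
+// while already-hot keys need many hits to climb toward MAX_CNT.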
+
+template <class K, class Node>
+class BaseLFUCache : public BatchCache<K> {
+ public:
+  using map_iter = typename std::unordered_map<
+      K, typename std::list<Node>::iterator>::iterator;
+  BaseLFUCache() {
+    min_freq = std::numeric_limits<size_t>::max();
+    max_freq = 0;
+    freq_table.emplace_back(
+        std::pair<std::list<Node>*, int64>(new std::list<Node>, 0));
+    BatchCache<K>::num_hit = 0;
+    BatchCache<K>::num_miss = 0;
+  }
+
+  size_t size() {
+    mutex_lock l(mu_);
+    return key_table.size();
+  }
+
+  size_t get_evic_ids(K* evic_ids, size_t k_size) {
+    mutex_lock l(mu_);
+    size_t true_size = 0;
+    size_t st_freq = min_freq;
+    // min_freq may point at a bucket that has since been emptied, so
+    // advance to the first non-empty bucket before evicting.
+    while (st_freq <= max_freq && freq_table[st_freq].second == 0) ++st_freq;
+    for (size_t i = 0; i < k_size && key_table.size() > 0; ++i) {
+      // The back of the lowest non-empty bucket is the oldest entry
+      // among the least frequently used.
+      Node rm_node = freq_table[st_freq].first->back();
+      key_table.erase(rm_node.GetKey());
+      evic_ids[i] = rm_node.GetKey();
+      ++true_size;
+      freq_table[st_freq].first->pop_back();
+      freq_table[st_freq].second--;
+      if (freq_table[st_freq].second == 0) {
+        ++st_freq;
+        while (st_freq <= max_freq && freq_table[st_freq].second == 0) {
+          ++st_freq;
+        }
+      }
+    }
+    min_freq = st_freq;
+    return true_size;
+  }
+
+  void AddNode(K id, unsigned now) {
+    Node node(id, now);
+    size_t index = node.GetIndex();
+    freq_table[index].first->emplace_front(node);
+    freq_table[index].second++;
+    key_table[id] = freq_table[index].first->begin();
+    min_freq = std::min(min_freq, index);
+    max_freq = std::max(max_freq, index);
+  }
+
+  void UpdateNode(K id, map_iter it, unsigned now, bool lru_mode) {
+    Node node = *(it->second);
+    size_t index = node.GetIndex();
+    freq_table[index].first->erase(it->second);
+    freq_table[index].second--;
+    if (freq_table[index].second == 0) {
+      if (min_freq == index) min_freq += 1;
+    }
+    index = node.UpdateAndReturnIndex(now, lru_mode);
+    // The new index can jump by more than one (AgingNode returns MAX_CNT
+    // in LRU mode), so grow the table as far as needed.
+    while (index >= freq_table.size()) {
+      freq_table.emplace_back(
+          std::pair<std::list<Node>*, int64>(new std::list<Node>, 0));
+    }
+    max_freq = std::max(max_freq, index);
+    min_freq = std::min(min_freq, index);
+    freq_table[index].first->emplace_front(node);
+    freq_table[index].second++;
+    key_table[id] = freq_table[index].first->begin();
+  }
+
+  void add_to_rank(const K* batch_ids, size_t batch_size) {
+    mutex_lock l(mu_);
+    for (size_t i = 0; i < batch_size; ++i) {
+      K id = batch_ids[i];
+      auto it = key_table.find(id);
+      if (it == key_table.end()) {
+        AddNode(id, 0);
+        BatchCache<K>::num_miss++;
+      } else {
+        UpdateNode(id, it, 0, false);
+        BatchCache<K>::num_hit++;
+      }
+    }
+  }
+
+  void add_to_rank(const K* batch_ids, size_t batch_size,
+                   const int64* batch_version, const int64* batch_freqs) {
+    // TODO: add to rank according to the version of ids
+    add_to_rank(batch_ids, batch_size);
+  }
+
+ protected:
+  size_t min_freq;
+  size_t max_freq;
+  std::vector<std::pair<std::list<Node>*, int64>> freq_table;
+  std::unordered_map<K, typename std::list<Node>::iterator> key_table;
+  mutex mu_;
+};
+
+template <class K>
+class LFUCache : public BaseLFUCache<K, LFUNode<K>> {
+ public:
+  LFUCache() : BaseLFUCache<K, LFUNode<K>>() {}
+};
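+
+// Illustrative note on the bucket layout (small assumed example): after
+// the accesses a, b, b, c, c, c through LFUCache, a sits in bucket 0,
+// b in bucket 1 and c in bucket 2, so get_evic_ids(evic, 2) yields
+// {a, b}: eviction drains the lowest non-empty frequency bucket first,
+// taking the back (oldest) node of each bucket's list.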
+
+template <class K>
+class AgingLFUCache : public BaseLFUCache<K, AgingNode<K>> {
+ public:
+  AgingLFUCache() : BaseLFUCache<K, AgingNode<K>>() {
+    global_step = 0;
+    // Nodes are created at index INIT_CNT, so buckets 1..INIT_CNT must
+    // exist in addition to bucket 0 built by the base constructor.
+    for (size_t i = 0; i < AgingNode<K>::INIT_CNT; ++i) {
+      this->freq_table.emplace_back(std::pair<std::list<AgingNode<K>>*, int64>(
+          new std::list<AgingNode<K>>, 0));
+    }
+  }
+
+  void add_to_rank(const K* batch_ids, size_t batch_size) {
+    mutex_lock l(this->mu_);
+    for (size_t i = 0; i < batch_size; ++i) {
+      if (global_step == std::numeric_limits<size_t>::max()) global_step = 0;
+      global_step++;
+      K id = batch_ids[i];
+      auto it = this->key_table.find(id);
+      if (it == this->key_table.end()) {
+        this->AddNode(id, global_step);
+        BatchCache<K>::num_miss++;
+      } else {
+        this->UpdateNode(id, it, global_step, false);
+        BatchCache<K>::num_hit++;
+      }
+    }
+  }
+
+  void add_to_rank(const K* batch_ids, size_t batch_size,
+                   const int64* batch_version, const int64* batch_freqs) {
+    // TODO: add to rank according to the version of ids
+    add_to_rank(batch_ids, batch_size);
+  }
+
+ protected:
+  size_t global_step;
+};
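+
+// Sketch of the mode controller implemented by auto_switch below, as
+// read from the code (informal summary, not a formal spec):
+//   F0: steady aging-LFU; if the hit rate falls below 85% of the last
+//       measurement, switch to LRU and enter S1.
+//   S1: if LRU did not lift the hit rate above 115% of the last
+//       measurement, switch back to LFU and return to F0; else S2.
+//   S2 -> S3: toggle the mode and re-measure over FAST_CHECK_SPAN
+//       replacements; on improvement return to F0, otherwise toggle
+//       back, double factor_replacement, and repeat.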
+
+template <class K>
+class AutoLRFUCache : public AgingLFUCache<K> {
+ public:
+  AutoLRFUCache(int64 cache_capacity)
+      : AgingLFUCache<K>(),
+        cache_capacity_(cache_capacity),
+        counter_replacement(0),
+        factor_replacement(1),
+        prev_hit_rate(-100),
+        lru_mode(false),
+        state(F0) {}
+
+  void add_to_rank(const K* batch_ids, size_t batch_size) {
+    mutex_lock l(this->mu_);
+    for (size_t i = 0; i < batch_size; ++i) {
+      auto_switch();
+      if (this->global_step == std::numeric_limits<size_t>::max())
+        this->global_step = 0;
+      this->global_step++;
+      if (hit_recent.size() > NORMAL_CHECK_SPAN) {
+        hit_recent.pop_back();
+      }
+      K id = batch_ids[i];
+      auto it = this->key_table.find(id);
+      if (it == this->key_table.end()) {
+        hit_recent.push_front(false);  // miss
+        this->AddNode(id, this->global_step);
+        BatchCache<K>::num_miss++;
+      } else {
+        hit_recent.push_front(true);  // hit
+        counter_replacement++;
+        this->UpdateNode(id, it, this->global_step, lru_mode);
+        BatchCache<K>::num_hit++;
+      }
+    }
+  }
+
+  void add_to_rank(const K* batch_ids, size_t batch_size,
+                   const int64* batch_version, const int64* batch_freqs) {
+    // TODO: add to rank according to the version of ids
+    add_to_rank(batch_ids, batch_size);
+  }
+
+ private:
+  // Recompute every node's LFU index after leaving LRU mode and
+  // redistribute the nodes across the frequency buckets. The nodes are
+  // copied out first because the buckets are rebuilt in place.
+  void rebuild() {
+    if (lru_mode) return;
+    std::list<AgingNode<K>> nodes;
+    for (auto it = this->key_table.begin(); it != this->key_table.end();
+         it++) {
+      AgingNode<K> node(it->second);
+      node.UpdateAndReturnIndex(this->global_step, lru_mode);
+      nodes.emplace_back(node);
+    }
+    for (auto& bucket : this->freq_table) {
+      bucket.first->clear();
+      bucket.second = 0;
+    }
+    this->key_table.clear();
+    this->min_freq = std::numeric_limits<size_t>::max();
+    this->max_freq = 0;
+    for (auto it = nodes.begin(); it != nodes.end(); it++) {
+      size_t index = it->GetIndex();
+      while (index >= this->freq_table.size()) {
+        this->freq_table.emplace_back(
+            std::pair<std::list<AgingNode<K>>*, int64>(
+                new std::list<AgingNode<K>>, 0));
+      }
+      this->freq_table[index].first->emplace_front(*it);
+      this->freq_table[index].second++;
+      this->key_table[it->GetKey()] = this->freq_table[index].first->begin();
+      this->min_freq = std::min(this->min_freq, index);
+      this->max_freq = std::max(this->max_freq, index);
+    }
+  }
+
+  // Hit rate over the most recent `len` accesses (over all of hit_recent
+  // when len <= 0), scaled so that 100000 means a 100% hit rate.
+  int get_hit_rate100000(int len = 0) {
+    int total = hit_recent.size();
+    if (total <= 0) return 0;
+    if (len <= 0 || len > total) len = total;
+    int hit = 0;
+    int seen = 0;
+    for (auto it = hit_recent.begin(); seen < len && it != hit_recent.end();
+         ++it, ++seen) {
+      if (*it) ++hit;
+    }
+    return static_cast<int>(hit * 100000LL / len);
+  }
+
+  void mode_switch() {
+    if (lru_mode) {
+      lru_mode = false;
+      rebuild();
+    } else {
+      lru_mode = true;
+    }
+  }
+
+  void auto_switch() {
+    // S3 re-checks after FAST_CHECK_SPAN replacements; the other states
+    // wait for factor_replacement * cache_capacity_ replacements.
+    if (state == S3
+            ? counter_replacement < FAST_CHECK_SPAN
+            : counter_replacement < factor_replacement * cache_capacity_)
+      return;
+    counter_replacement = 0;
+    int curr_hit_rate = get_hit_rate100000(0);
+    if (state == F0) {
+      // Switch to LRU mode if the hit rate decreased significantly
+      // (by more than 15%).
+      if (prev_hit_rate >= 0 && curr_hit_rate < prev_hit_rate * 0.85) {
+        mode_switch();
+        state = S1;
+      }
+    } else if (state == S1) {
+      // Switch back to LFU mode if the hit rate did not increase
+      // significantly (by more than 15%).
+      if (curr_hit_rate <= 1.15 * prev_hit_rate) {
+        mode_switch();
+        state = F0;
+      } else {
+        state = S2;
+      }
+    } else if (state == S2) {
+      mode_switch();
+      state = S3;
+    } else if (state == S3) {
+      curr_hit_rate = get_hit_rate100000(FAST_CHECK_SPAN);
+      if (curr_hit_rate > prev_hit_rate) {
+        state = F0;
+        factor_replacement = 1;
+      } else {
+        mode_switch();
+        state = S2;
+        factor_replacement = factor_replacement * 2;
+      }
+    }
+    if (state != S3) prev_hit_rate = curr_hit_rate;
+  }
+
+ private:
+  enum State { F0, S1, S2, S3 };
+  static const unsigned FAST_CHECK_SPAN = 1000;
+  static const unsigned NORMAL_CHECK_SPAN = 5000;
+  int64 cache_capacity_;
+  std::deque<bool> hit_recent;
+  size_t counter_replacement;
+  unsigned factor_replacement;
+  int prev_hit_rate;
+  bool lru_mode;
+  State state;
+};
 
 } // embedding
 } // tensorflow
diff --git a/tensorflow/core/kernels/embedding_variable_ops_test.cc b/tensorflow/core/kernels/embedding_variable_ops_test.cc
index 923aaf4c38c..da67f9dda94 100644
--- a/tensorflow/core/kernels/embedding_variable_ops_test.cc
+++ b/tensorflow/core/kernels/embedding_variable_ops_test.cc
@@ -1176,6 +1176,54 @@ TEST(EmbeddingVariableTest, TestLFUCache) {
   }
 }
 
+TEST(EmbeddingVariableTest, TestAgingLFUCache) {
+  BatchCache<int64>* cache = new AgingLFUCache<int64>();
+  int num_ids = 30;
+  int num_access = 100;
+  int num_evict = 50;
+  int64 ids[num_access] = {0};
+  int64 evict_ids[num_evict] = {0};
+  bool evict_ids_map[num_ids] = {false};
+  for (int i = 0; i < num_access; i++) {
+    ids[i] = i % num_ids;
+  }
+  cache->add_to_rank(ids, num_access);
+  int64 size = cache->get_evic_ids(evict_ids, num_evict);
+  ASSERT_EQ(size, num_ids);
+  ASSERT_EQ(cache->size(), 0);
+  for (int i = 0; i < num_ids; i++) {
+    ASSERT_EQ(evict_ids[i] < num_ids, true);
+    evict_ids_map[evict_ids[i]] = true;
+  }
+  for (int id = 0; id < num_ids; id++) {
+    ASSERT_EQ(evict_ids_map[id], true);
+  }
+}
+
+TEST(EmbeddingVariableTest, TestAutoLRFUCache) {
+  BatchCache<int64>* cache = new AutoLRFUCache<int64>(100);
+  int num_ids = 30;
+  int num_access = 100;
+  int num_evict = 50;
+  int64 ids[num_access] = {0};
+  int64 evict_ids[num_evict] = {0};
+  bool evict_ids_map[num_ids] = {false};
+  for (int i = 0; i < num_access; i++) {
+    ids[i] = i % num_ids;
+  }
+  cache->add_to_rank(ids, num_access);
+  int64 size = cache->get_evic_ids(evict_ids, num_evict);
+  ASSERT_EQ(size, num_ids);
+  ASSERT_EQ(cache->size(), 0);
+  for (int i = 0; i < num_ids; i++) {
+    ASSERT_EQ(evict_ids[i] < num_ids, true);
+    evict_ids_map[evict_ids[i]] = true;
+  }
+  for (int id = 0; id < num_ids; id++) {
+    ASSERT_EQ(evict_ids_map[id], true);
+  }
+}
+
 TEST(EmbeddingVariableTest, TestCacheRestore) {
   int64 value_size = 4;
   Tensor value(DT_FLOAT, TensorShape({value_size}));