360 changes: 358 additions & 2 deletions tensorflow/core/framework/embedding/cache.h
@@ -5,6 +5,7 @@
#include <unordered_map>
#include <set>
#include <list>
#include <deque>
#include <limits>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"
@@ -128,10 +129,12 @@ class LRUCache : public BatchCache<K> {
mutex mu_;
};

/* Keeps exact access frequencies, from 1 up to the maximum value of size_t.
   Stand-alone implementation, kept separate from the BaseLFUCache hierarchy
   below. */
template <class K>
class LFUCache : public BatchCache<K> {
class OriginLFUCache : public BatchCache<K> {
public:
LFUCache() {
OriginLFUCache() {
min_freq = std::numeric_limits<size_t>::max();
max_freq = 0;
freq_table.emplace_back(std::pair<std::list<LFUNode>*, int64>(
@@ -272,6 +275,359 @@ class LFUCache : public BatchCache<K> {
mutex mu_;
};


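// Node types used by BaseLFUCache below. The node decides which frequency
// bucket an id lives in: LFUNode simply uses the raw access count as the
// bucket index, while AgingNode uses a time-decayed, probabilistically
// incremented counter.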
template <class K>
class LFUNode {
public:
LFUNode(K key, unsigned now) : key(key), freq(0) {}
~LFUNode() {}
K GetKey() { return key; }
size_t GetIndex() { return freq; }
size_t UpdateAndReturnIndex(unsigned now, bool lru_mode) { return ++freq; }
private:
K key;
size_t freq;
};

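// Frequency counter with aging, roughly an LFU counter with periodic decay:
// the count starts at INIT_CNT, decays by one per UNIT_STEP elapsed steps,
// and is incremented probabilistically so hot ids saturate slowly toward
// MAX_CNT. In LRU mode the bucket index is pinned to MAX_CNT so that recency
// ordering inside the top bucket drives eviction.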
template <class K>
class AgingNode : public LFUNode<K> {
public:
AgingNode(K key, unsigned now)
: LFUNode<K>(key, now), count(INIT_CNT), index(INIT_CNT), last(now) {}
  // Copy an existing node out of a bucket list, preserving its aged count,
  // bucket index and last-access step.
  AgingNode(typename std::list<AgingNode>::iterator that)
      : LFUNode<K>(that->GetKey(), that->last),
        count(that->count),
        index(that->index),
        last(that->last) {}
~AgingNode() {}
size_t GetIndex() { return index; }
  // Age the counter for the elapsed steps, give it a probabilistic bump, then
  // report the new bucket index. In LRU mode the node is pinned to the top
  // bucket so recency ordering inside that bucket decides eviction.
  size_t UpdateAndReturnIndex(unsigned now, bool lru_mode) {
    Decrease(now);
    IncrByProb();
    index = lru_mode ? MAX_CNT : count;
    return (size_t)index;
  }
static const uint8_t INIT_CNT = 5;

private:
  // Alternative probabilistic decay (currently unused): decrement the count
  // with probability proportional to the elapsed period.
  void DecrByProb(unsigned period) {
    size_t ret1 = rand();
    size_t ret2 = RAND_MAX;
    ret1 *= UNIT_STEP * DECR_FACTOR;
    ret2 *= (period - IGNORE_STEP);
    if (count > 0 && ret1 < ret2) count--;
  }
  // Deterministic decay: lose one count per UNIT_STEP elapsed steps, never
  // dropping below INIT_CNT.
  void DecrByPeriod(unsigned period) {
    unsigned decrease_value = period / UNIT_STEP;
    if (decrease_value + INIT_CNT > count)
      count = INIT_CNT;
    else
      count -= decrease_value;
  }
  void Decrease(unsigned now) {
    // `now` and `last` are unsigned, so use the unsigned range when the
    // logical clock has wrapped around.
    unsigned period = now >= last
                          ? now - last
                          : std::numeric_limits<unsigned>::max() - last + now;
    DecrByPeriod(period);
    last = now;  // Only decay for steps elapsed since the previous update.
  }
  // Increment with probability roughly 1 / ((count - INIT_CNT) * INCR_FACTOR):
  // the hotter an id already is, the harder it is to climb further.
  void IncrByProb() {
    size_t ret = rand();
    ret *= (count - INIT_CNT) * INCR_FACTOR;
    if (count < MAX_CNT && ret < RAND_MAX) count++;
  }

private:
static const uint8_t MIN_CNT = 1;
static const uint8_t MAX_CNT = 255;
static const unsigned INCR_FACTOR = 7;
static const unsigned DECR_FACTOR = 10;
// 32640 = 1 + 2 + ... + 255
static const unsigned UNIT_STEP = 32640 * INCR_FACTOR;
static const unsigned IGNORE_STEP = 0;
uint8_t count;
uint8_t index;
unsigned last;
};

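// Shared implementation of the frequency-bucketed caches. freq_table is a
// vector of (bucket list, element count) pairs indexed by the node's index;
// key_table maps an id to its node's position inside a bucket, so an id can
// be looked up and promoted to a new bucket in O(1). Eviction pops from the
// back of the lowest non-empty bucket.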
template <class K, class Node>
class BaseLFUCache : public BatchCache<K> {
public:
using map_iter = typename std::unordered_map<
K, typename std::list<Node>::iterator>::iterator;
BaseLFUCache() {
min_freq = std::numeric_limits<size_t>::max();
max_freq = 0;
freq_table.emplace_back(
std::pair<std::list<Node>*, int64>(new std::list<Node>, 0));
BatchCache<K>::num_hit = 0;
BatchCache<K>::num_miss = 0;
}

size_t size() {
mutex_lock l(mu_);
return key_table.size();
}

  size_t get_evic_ids(K* evic_ids, size_t k_size) {
    mutex_lock l(mu_);
    size_t true_size = 0;
    size_t st_freq = min_freq;
    for (size_t i = 0; i < k_size && key_table.size() > 0; ++i) {
      // Advance to the lowest non-empty bucket; buckets may have been left
      // empty by earlier evictions.
      while (st_freq <= max_freq && freq_table[st_freq].second == 0) {
        ++st_freq;
      }
      if (st_freq > max_freq) break;
      // Evict from the back of the bucket, i.e. the least recently touched
      // node among the coldest ones.
      auto rm_node = freq_table[st_freq].first->back();
      key_table.erase(rm_node.GetKey());
      evic_ids[i] = rm_node.GetKey();
      ++true_size;
      freq_table[st_freq].first->pop_back();
      freq_table[st_freq].second--;
    }
    min_freq = st_freq;
    return true_size;
  }

  // Insert a new node into the bucket selected by its initial index.
  void AddNode(K id, unsigned now) {
    Node node(id, now);
    size_t index = node.GetIndex();
    freq_table[index].first->emplace_front(node);
    freq_table[index].second++;
    key_table[id] = freq_table[index].first->begin();
    min_freq = std::min(min_freq, index);
    max_freq = std::max(max_freq, index);
  }

  // Move an existing node to the bucket for its refreshed index.
  void UpdateNode(K id, map_iter it, unsigned now, bool lru_mode) {
    Node node = (Node)(*(it->second));
    size_t index = node.GetIndex();
    freq_table[index].first->erase(it->second);
    freq_table[index].second--;
    if (freq_table[index].second == 0) {
      if (min_freq == index) min_freq += 1;
    }
    index = node.UpdateAndReturnIndex(now, lru_mode);
    // The new index may jump by more than one (e.g. AgingNode in LRU mode),
    // so grow the bucket vector until it covers the new index.
    while (index >= freq_table.size()) {
      freq_table.emplace_back(
          std::pair<std::list<Node>*, int64>(new std::list<Node>, 0));
    }
    max_freq = std::max(max_freq, index);
    min_freq = std::min(min_freq, index);
    freq_table[index].first->emplace_front(node);
    freq_table[index].second++;
    key_table[id] = freq_table[index].first->begin();
  }

void add_to_rank(const K* batch_ids, size_t batch_size) {
mutex_lock l(mu_);
for (size_t i = 0; i < batch_size; ++i) {
K id = batch_ids[i];
auto it = key_table.find(id);
if (it == key_table.end()) {
AddNode(id, 0);
BatchCache<K>::num_miss++;
} else {
UpdateNode(id, it, 0, false);
BatchCache<K>::num_hit++;
}
}
}

void add_to_rank(const K* batch_ids, size_t batch_size,
const int64* batch_version, const int64* batch_freqs) {
    // TODO: add to rank according to the version of the ids
add_to_rank(batch_ids, batch_size);
}

protected:
size_t min_freq;
size_t max_freq;
std::vector<std::pair<std::list<Node>*, int64>> freq_table;
std::unordered_map<K, typename std::list<Node>::iterator> key_table;
mutex mu_;
};

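// Plain LFU: the bucket index is the raw access count (starting at 0), so
// behaviour matches a classic least-frequently-used cache.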
template <class K>
class LFUCache : public BaseLFUCache<K, LFUNode<K>> {
public:
LFUCache() : BaseLFUCache<K, LFUNode<K>>() {}
};

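// LFU with aging: global_step acts as a logical clock that advances once per
// looked-up id (wrapping back to zero at the size_t maximum) and is passed to
// the nodes so their counters can decay. The extra INIT_CNT buckets cover
// indices 1..INIT_CNT, where newly added AgingNodes start out.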
template <class K>
class AgingLFUCache : public BaseLFUCache<K, AgingNode<K>> {
public:
AgingLFUCache() : BaseLFUCache<K, AgingNode<K>>() {
global_step = 0;
for (size_t i = 0; i < AgingNode<K>::INIT_CNT; ++i) {
this->freq_table.emplace_back(std::pair<std::list<AgingNode<K>>*, int64>(
new std::list<AgingNode<K>>, 0));
}
}

void add_to_rank(const K* batch_ids, size_t batch_size) {
mutex_lock l(this->mu_);
for (size_t i = 0; i < batch_size; ++i) {
if (global_step == std::numeric_limits<size_t>::max()) global_step = 0;
global_step++;
K id = batch_ids[i];
auto it = this->key_table.find(id);
if (it == this->key_table.end()) {
this->AddNode(id, global_step);
BatchCache<K>::num_miss++;
} else {
this->UpdateNode(id, it, global_step, false);
BatchCache<K>::num_hit++;
}
}
}

void add_to_rank(const K* batch_ids, size_t batch_size,
const int64* batch_version, const int64* batch_freqs) {
    // TODO: add to rank according to the version of the ids
add_to_rank(batch_ids, batch_size);
}

protected:
size_t global_step;
};

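// Adaptive cache that switches between LFU-with-aging and LRU behaviour based
// on the observed hit rate over a sliding window of recent lookups
// (hit_recent). A small state machine (F0/S1/S2/S3, see auto_switch) decides
// when to try the other mode and when to fall back; switching back to LFU
// rebuilds the frequency buckets from the aged counters.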
template <class K>
class AutoLRFUCache : public AgingLFUCache<K> {
public:
  AutoLRFUCache(int64 cache_capacity)
      : AgingLFUCache<K>(),
        cache_capacity_(cache_capacity),
        counter_replacement(0),
        factor_replacement(1),
        prev_hit_rate(-100),
        lru_mode(false),
        state(F0) {}

void add_to_rank(const K* batch_ids, size_t batch_size) {
mutex_lock l(this->mu_);
for (size_t i = 0; i < batch_size; ++i) {
auto_switch();
if (this->global_step == std::numeric_limits<size_t>::max())
this->global_step = 0;
this->global_step++;
if (hit_recent.size() > NORMAL_CHECK_SPAN) {
hit_recent.pop_back();
}
K id = batch_ids[i];
auto it = this->key_table.find(id);
      if (it == this->key_table.end()) {
        // Record a miss in the sliding window.
        hit_recent.push_front(false);
        this->AddNode(id, this->global_step);
        BatchCache<K>::num_miss++;
      } else {
        // Record a hit in the sliding window.
        hit_recent.push_front(true);
        counter_replacement++;
        this->UpdateNode(id, it, this->global_step, lru_mode);
        BatchCache<K>::num_hit++;
      }
}
}

void add_to_rank(const K* batch_ids, size_t batch_size,
const int64* batch_version, const int64* batch_freqs) {
    // TODO: add to rank according to the version of the ids
add_to_rank(batch_ids, batch_size);
}

private:
  // Rebuild the frequency buckets when falling back from LRU to LFU mode:
  // refresh every node's aged counter and re-insert it at its new index.
  void rebuild() {
    if (lru_mode) return;
    std::unordered_map<K, AgingNode<K>> new_table;
    for (auto it = this->key_table.begin(); it != this->key_table.end(); it++) {
      AgingNode<K> node(it->second);
      node.UpdateAndReturnIndex(this->global_step, lru_mode);
      new_table.emplace(node.GetKey(), node);
    }
    // Empty the existing buckets in place (clearing the vector itself would
    // invalidate the indices used below), then re-insert every node.
    for (auto& bucket : this->freq_table) {
      bucket.first->clear();
      bucket.second = 0;
    }
    this->key_table.clear();
    for (auto it = new_table.begin(); it != new_table.end(); it++) {
      AgingNode<K>& node = it->second;
      size_t index = node.GetIndex();
      while (index >= this->freq_table.size()) {
        this->freq_table.emplace_back(std::pair<std::list<AgingNode<K>>*, int64>(
            new std::list<AgingNode<K>>, 0));
      }
      this->freq_table[index].first->emplace_front(node);
      this->freq_table[index].second++;
      this->key_table[node.GetKey()] =
          this->freq_table[index].first->begin();
      this->min_freq = std::min(this->min_freq, index);
      this->max_freq = std::max(this->max_freq, index);
    }
  }

  // Hit rate over the most recent `len` window entries (all of them when
  // len <= 0), scaled to an integer in [0, 100000].
  int get_hit_rate100000(int len = 0) {
    int total = hit_recent.size();
    if (total <= 0) return 0;
    if (len <= 0 || len > total) len = total;
    int hit = 0;
    int examined = 0;
    for (auto it = hit_recent.begin(); examined < len && it != hit_recent.end();
         ++it, ++examined) {
      if (*it) ++hit;
    }
    return int(hit * 100000 / len);
  }

  // Toggle between LRU and LFU behaviour. When falling back to LFU, the
  // buckets are rebuilt from the aged counters.
  void mode_switch() {
    if (lru_mode) {
      lru_mode = false;
      rebuild();
    } else {
      lru_mode = true;
    }
  }

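  // State machine driving the mode switches, evaluated roughly every
  // factor_replacement * cache_capacity_ recorded hits (and at least every
  // FAST_CHECK_SPAN hits while in S3):
  //   F0: stable mode; a hit-rate drop of more than 15% triggers a switch to
  //       the other mode and a move to S1.
  //   S1: evaluate the trial; revert to F0 if the hit rate did not improve by
  //       more than 15%, otherwise keep the new mode and advance to S2.
  //   S2: switch modes again and enter the fast re-check state S3.
  //   S3: re-check over the last FAST_CHECK_SPAN lookups; on improvement
  //       settle back to F0 and reset the interval, otherwise switch again,
  //       return to S2, and double factor_replacement to back off.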
void auto_switch() {
if ((state == S3 && counter_replacement < FAST_CHECK_SPAN) ||
(counter_replacement < factor_replacement * cache_capacity_))
return;
counter_replacement = 0;
int curr_hit_rate = get_hit_rate100000(0);
if (state == F0) {
      // Try the other mode if the hit rate dropped significantly (by more than 15%).
if (prev_hit_rate >= 0 && curr_hit_rate < prev_hit_rate * 0.85) {
mode_switch();
state = S1;
}
} else if (state == S1) {
      // Switch back if the hit rate did not improve significantly (by at least 15%).
if (curr_hit_rate <= 1.15 * prev_hit_rate) {
mode_switch();
state = F0;
} else {
state = S2;
}
} else if (state == S2) {
mode_switch();
state = S3;
} else if (state == S3) {
curr_hit_rate = get_hit_rate100000(FAST_CHECK_SPAN);
if (curr_hit_rate > prev_hit_rate) {
state = F0;
factor_replacement = 1;
} else {
mode_switch();
state = S2;
factor_replacement = factor_replacement * 2;
}
}
if (state != S3) prev_hit_rate = curr_hit_rate;
}

private:
enum State { F0, S1, S2, S3 };
int64 cache_capacity_;
std::deque<bool> hit_recent;
static const unsigned FAST_CHECK_SPAN = 1000;
static const unsigned NORMAL_CHECK_SPAN = 5000;
size_t counter_replacement;
unsigned factor_replacement;
int prev_hit_rate;
bool lru_mode;
State state;
};
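
// A minimal usage sketch (illustrative only; `capacity`, `batch_ids`,
// `batch_size` and `kEvictBatch` are hypothetical caller-side values, not
// part of this header):
//
//   AutoLRFUCache<int64> cache(capacity);
//   cache.add_to_rank(batch_ids, batch_size);
//   if (cache.size() > static_cast<size_t>(capacity)) {
//     int64 evic_ids[kEvictBatch];
//     size_t n = cache.get_evic_ids(evic_ids, kEvictBatch);
//     // The n returned ids have been removed from the cache and can now be
//     // evicted from the backing embedding storage.
//   }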
} // embedding
} // tensorflow
