Skip to content

Parameter auto-tuning for the marginal-hits strategy #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cachelib/allocator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ add_library (cachelib_allocator
LruTailAgeStrategy.cpp
MarginalHitsOptimizeStrategy.cpp
MarginalHitsStrategy.cpp
MarginalHitsStrategyNew.cpp
memory/AllocationClass.cpp
memory/MemoryAllocator.cpp
memory/MemoryPool.cpp
Expand Down
5 changes: 4 additions & 1 deletion cachelib/allocator/CacheAllocatorConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,9 @@ class CacheAllocatorConfig {

// whether to allow tracking tail hits in MM2Q
bool trackTailHits{false};

// When tracking tail hits for MM2Q: if true, count only cold-tail hits;
// if false, count both cold-tail and warm-tail hits.
bool countColdTailHitsOnly{false};

// Memory monitoring config
MemoryMonitor::Config memMonitorConfig;
Expand Down Expand Up @@ -1170,7 +1173,7 @@ bool CacheAllocatorConfig<T>::validateStrategy(

auto type = strategy->getType();
return type != RebalanceStrategy::NumTypes &&
(type != RebalanceStrategy::MarginalHits || trackTailHits);
((type != RebalanceStrategy::MarginalHits && type != RebalanceStrategy::MarginalHitsNew) || trackTailHits);
}

template <typename T>
Expand Down
6 changes: 4 additions & 2 deletions cachelib/allocator/MM2Q.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class MM2Q {
}

// adding extra config after generating the config: tailSize
void addExtraConfig(size_t tSize) { tailSize = tSize; }
void addExtraConfig(size_t tSize, bool coldOnly=false) { tailSize = tSize; coldTailOnly = coldOnly;}

// threshold value in seconds to compare with a node's update time to
// determine if we need to update the position of the node in the linked
Expand Down Expand Up @@ -346,6 +346,8 @@ class MM2Q {
// should not set this manually.
size_t tailSize{0};

bool coldTailOnly{false};

// Minimum interval between reconfigurations. If 0, reconfigure is never
// called.
std::chrono::seconds mmReconfigureIntervalSecs{};
Expand Down Expand Up @@ -1100,7 +1102,7 @@ MMContainerStat MM2Q::Container<T, HookPtr>::getStats() const noexcept {
numHotAccesses_,
numColdAccesses_,
numWarmAccesses_,
computeWeightedAccesses(numWarmTailAccesses_, numColdTailAccesses_)};
config_.coldTailOnly? numColdTailAccesses_ : computeWeightedAccesses(numWarmTailAccesses_, numColdTailAccesses_)};
});
}

Expand Down
223 changes: 223 additions & 0 deletions cachelib/allocator/MarginalHitsStrategyNew.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cachelib/allocator/MarginalHitsStrategyNew.h"

#include <folly/logging/xlog.h>

#include <algorithm>
#include <functional>

namespace facebook::cachelib {

// Constructs the strategy from the given configuration.
//
// The strategy registers itself under its own type tag (MarginalHitsNew)
// rather than the tag of the original marginal-hits strategy, so that
// type-based checks such as CacheAllocatorConfig::validateStrategy can tell
// the two apart.
//
// @param config  strategy parameters; taken by value and moved into place.
MarginalHitsStrategyNew::MarginalHitsStrategyNew(Config config)
    : RebalanceStrategy(MarginalHitsNew), config_(std::move(config)) {}

// Regular rebalancing entry point. Delegates to
// pickVictimAndReceiverCandidates() in non-forced mode, i.e. with the
// request-count and improvement-threshold filters enabled.
RebalanceContext MarginalHitsStrategyNew::pickVictimAndReceiverImpl(
    const CacheBase& cache, PoolId pid, const PoolStats& poolStats) {
  constexpr bool kForce = false;
  return pickVictimAndReceiverCandidates(cache, pid, poolStats, kForce);
}

// Picks a (victim, receiver) allocation-class pair for pool `pid`, refreshes
// the per-class bookkeeping, and adapts the minDiff threshold for the next
// round.
//
// Steps, in order:
//  1. Bail out with kNoOpContext if the pool has not allocated all its slabs.
//  2. Score every class, mark which classes may act as victim / receiver,
//     update the rankings and pick a candidate pair from them.
//  3. Unless `force` is set, drop the candidate when too few requests were
//     observed, or when the score improvement is below config.minDiff /
//     config.minDiffRatio.
//  4. Update the per-class tail-hit and request snapshots.
//  5. If a pair was selected, record the event and raise or lower the
//     minDiff threshold based on the effective move rate.
//
// @param cache      cache used to look up the pool.
// @param pid        pool to rebalance.
// @param poolStats  current stats snapshot for the pool.
// @param force      when true, skip the filters of step 3.
// @return the selected context, or kNoOpContext when nothing should move.
RebalanceContext MarginalHitsStrategyNew::pickVictimAndReceiverCandidates(
    const CacheBase& cache, PoolId pid, const PoolStats& poolStats, bool force) {
  const auto config = getConfigCopy();
  if (!cache.getPool(pid).allSlabsAllocated()) {
    XLOGF(DBG,
          "Pool Id: {} does not have all its slabs allocated"
          " and does not need rebalancing.",
          static_cast<int>(pid));
    return kNoOpContext;
  }

  auto scores =
      computeClassMarginalHits(pid, poolStats, config.movingAverageParam);
  auto classesSet = poolStats.getClassIds();
  std::vector<ClassId> classes(classesSet.begin(), classesSet.end());
  std::unordered_map<ClassId, bool> validVictim;
  std::unordered_map<ClassId, bool> validReceiver;
  // Bind the per-class stats once by reference; copying the whole map for
  // every class is unnecessary work.
  const auto& acStats = poolStats.mpStats.acStats;
  for (const auto cid : classes) {
    const auto& classStats = acStats.at(cid);
    // a class can be a victim only if it has more than config.minSlabs slabs
    validVictim[cid] = classStats.totalSlabs() > config.minSlabs;
    // a class can be a receiver only if its free memory (free allocs, free
    // slabs, etc) is small
    validReceiver[cid] = classStats.getTotalFreeMemory() <
                         config.maxFreeMemSlabs * Slab::kSize;
  }
  if (classStates_[pid].entities.empty()) {
    // initialization
    classStates_[pid].entities = classes;
    for (auto cid : classes) {
      classStates_[pid].smoothedRanks[cid] = 0;
    }
  }
  // we don't rely on this decay anymore
  classStates_[pid].updateRankings(scores, 0.0);
  RebalanceContext ctx =
      pickVictimAndReceiverFromRankings(pid, validVictim, validReceiver);

  auto numRequestObserved = computeNumRequests(pid, poolStats);
  if (!force && numRequestObserved < config.minRequestsObserved) {
    XLOGF(DBG, "haven't observed enough requests: {}/{}, wait until next round", numRequestObserved, config.minRequestsObserved);
    ctx = kNoOpContext;
  }
  if (!force && ctx.isEffective()) {
    // extra filterings: require a minimum absolute and/or relative score
    // improvement before moving a slab.
    auto receiverScore = scores.at(ctx.receiverClassId);
    auto victimScore = scores.at(ctx.victimClassId);
    auto improvement = receiverScore - victimScore;
    // A zero victim score is replaced by 1 to avoid dividing by zero.
    auto improvementRatio = improvement / (victimScore == 0 ? 1 : victimScore);
    ctx.diffValue = improvement;
    if ((config.minDiff > 0 && improvement < config.minDiff) ||
        (config.minDiffRatio > 0 && improvementRatio < config.minDiffRatio)) {
      XLOGF(DBG, "Not enough to trigger rebalancing, receiver id: {}, victim id: {}, receiver score: {}, victim score: {}, improvement: {}, improvement ratio: {}, thresh1: {}, thresh2: {}",
            ctx.receiverClassId, ctx.victimClassId, receiverScore, victimScore, improvement, improvementRatio, config.minDiff, config.minDiffRatio);
      ctx = kNoOpContext;
    } else {
      XLOGF(DBG, "rebalancing, receiver id: {}, victim id: {}, receiver score: {}, victim score: {}, improvement: {}, improvement ratio: {}",
            ctx.receiverClassId, ctx.victimClassId, receiverScore, victimScore, improvement, improvementRatio);
    }
  }

  // Normalize any non-effective context to the canonical no-op value.
  if (!ctx.isEffective()) {
    ctx = kNoOpContext;
  }
  auto& poolState = getPoolState(pid);
  auto deltaRequestsSinceLastDecay =
      computeRequestsSinceLastDecay(pid, poolStats);
  // Refresh the tail-hit snapshots when a move was chosen, when the config
  // asks for unconditional updates, or when enough requests have accumulated
  // since the last decay.
  if ((ctx.isEffective() || !config.onlyUpdateHitIfRebalance) ||
      deltaRequestsSinceLastDecay >= config.minDecayInterval) {
    for (const auto cid : classes) {
      poolState[cid].updateTailHits(poolStats, config.movingAverageParam);
    }
  }

  // The request snapshots only advance once a full observation window has
  // been seen; otherwise requests keep accumulating for the next round.
  if (numRequestObserved >= config.minRequestsObserved) {
    for (const auto cid : classes) {
      poolState[cid].updateRequests(poolStats);
    }
  }

  // self-tuning threshold for the next round.
  // NOTE(review): on the forced path ctx.diffValue is never assigned above,
  // yet the event is still recorded and can influence the threshold; confirm
  // this is intended.
  if (ctx.isEffective()) {
    // max window size: 2 * number of classes that received hits
    size_t classWithHits = 0;
    for (const auto& cid : classes) {
      if (poolState.at(cid).deltaHits(poolStats) > 0) {
        ++classWithHits;
      }
    }
    recordRebalanceEvent(pid, ctx, classWithHits * 2);
    auto effectiveMoveRate = queryEffectiveMoveRate(pid);
    auto windowSize = getRebalanceEventQueueSize(pid);
    XLOGF(DBG,
          "Rebalancing: effective move rate = {}, window size = {}, diff = {}, threshold = {}, ({}->{})",
          effectiveMoveRate,
          windowSize, ctx.diffValue, config.minDiff, static_cast<int>(ctx.victimClassId), static_cast<int>(ctx.receiverClassId));

    if (effectiveMoveRate <= config.emrLow &&
        windowSize >= config.thresholdIncMinWindowSize) {
      // Low effective move rate: raise the threshold, additively or
      // multiplicatively depending on the config.
      if (config.thresholdAI) {
        auto currentMin = getMinDiffValueFromRebalanceEvents(pid);
        if (updateMinDff(currentMin + config.thresholdAIADStep)) {
          clearPoolRebalanceEvent(pid);
        }
      } else if (config.thresholdMI) {
        if (updateMinDff(config.minDiff * config.thresholdMIMDFactor)) {
          clearPoolRebalanceEvent(pid);
        }
      }
    } else if (effectiveMoveRate >= config.emrHigh &&
               windowSize >= classWithHits) {
      // High effective move rate: lower the threshold, but never below 2.0.
      if (config.thresholdAD) {
        if (updateMinDff(
                std::max(2.0, config.minDiff - config.thresholdAIADStep))) {
          clearPoolRebalanceEvent(pid);
        }
      } else if (config.thresholdMD) {
        if (updateMinDff(
                std::max(2.0, config.minDiff / config.thresholdMIMDFactor))) {
          clearPoolRebalanceEvent(pid);
        }
      }
    }
  }

  return ctx;
}

// Victim-only entry point. Runs the candidate selection in forced mode (the
// request-count and improvement filters are skipped) and returns just the
// victim class id of the resulting context.
ClassId MarginalHitsStrategyNew::pickVictimImpl(const CacheBase& cache,
                                                PoolId pid,
                                                const PoolStats& stats) {
  constexpr bool kForce = true;
  const RebalanceContext forced =
      pickVictimAndReceiverCandidates(cache, pid, stats, kForce);
  return forced.victimClassId;
}

// Builds a map from class id to its decayed marginal-hits score for pool
// `pid`. Entries of the pool state whose id is Slab::kInvalidClassId are
// skipped.
//
// @param pid          pool whose per-class state is read.
// @param poolStats    current stats snapshot for the pool.
// @param decayFactor  forwarded unchanged to getDecayedMarginalHits().
// @return one score per class with a valid id.
std::unordered_map<ClassId, double>
MarginalHitsStrategyNew::computeClassMarginalHits(PoolId pid,
                                                  const PoolStats& poolStats,
                                                  double decayFactor) {
  std::unordered_map<ClassId, double> marginalHits;
  const auto& classInfos = getPoolState(pid);
  for (auto classInfo : classInfos) {
    if (classInfo.id == Slab::kInvalidClassId) {
      continue;
    }
    // this score is the latest delta.
    marginalHits[classInfo.id] =
        classInfo.getDecayedMarginalHits(poolStats, decayFactor);
  }
  return marginalHits;
}

// Sums deltaRequests() over every class id reported by `poolStats` for pool
// `pid` and returns the total.
size_t MarginalHitsStrategyNew::computeNumRequests(
    PoolId pid, const PoolStats& poolStats) const {
  const auto& perClassState = getPoolState(pid);
  const auto classIds = poolStats.getClassIds();
  size_t requestSum = 0;
  for (auto idIt = classIds.begin(); idIt != classIds.end(); ++idIt) {
    requestSum += perClassState.at(*idIt).deltaRequests(poolStats);
  }
  return requestSum;
}

// Sums deltaRequestsSinceLastDecay() over every class id reported by
// `poolStats` for pool `pid` and returns the total.
size_t MarginalHitsStrategyNew::computeRequestsSinceLastDecay(
    PoolId pid, const PoolStats& poolStats) const {
  const auto& perClassState = getPoolState(pid);
  const auto classIds = poolStats.getClassIds();
  size_t requestSum = 0;
  for (auto idIt = classIds.begin(); idIt != classIds.end(); ++idIt) {
    requestSum +=
        perClassState.at(*idIt).deltaRequestsSinceLastDecay(poolStats);
  }
  return requestSum;
}

// Asks the ranking state of pool `pid` for a victim/receiver pair restricted
// to the classes flagged in `validVictim` / `validReceiver`.
//
// Returns kNoOpContext when either side is Slab::kInvalidClassId or when both
// sides name the same class; otherwise logs the pick and returns it.
RebalanceContext MarginalHitsStrategyNew::pickVictimAndReceiverFromRankings(
    PoolId pid,
    const std::unordered_map<ClassId, bool>& validVictim,
    const std::unordered_map<ClassId, bool>& validReceiver) {
  auto& rankingState = classStates_[pid];
  const auto picked = rankingState.pickVictimAndReceiverFromRankings(
      validVictim, validReceiver, Slab::kInvalidClassId);
  RebalanceContext ctx{picked.first, picked.second};

  const bool missingVictim = ctx.victimClassId == Slab::kInvalidClassId;
  const bool missingReceiver = ctx.receiverClassId == Slab::kInvalidClassId;
  const bool sameClass = ctx.victimClassId == ctx.receiverClassId;
  if (missingVictim || missingReceiver || sameClass) {
    return kNoOpContext;
  }

  XLOGF(DBG,
        "Rebalancing: receiver = {}, smoothed rank = {}, victim = {}, smoothed "
        "rank = {}",
        static_cast<int>(ctx.receiverClassId),
        rankingState.smoothedRanks[ctx.receiverClassId],
        static_cast<int>(ctx.victimClassId),
        rankingState.smoothedRanks[ctx.victimClassId]);
  return ctx;
}
} // namespace facebook::cachelib
Loading
Loading