diff --git a/cachelib/allocator/CMakeLists.txt b/cachelib/allocator/CMakeLists.txt
index aad79b0a5..603cfeeee 100644
--- a/cachelib/allocator/CMakeLists.txt
+++ b/cachelib/allocator/CMakeLists.txt
@@ -49,6 +49,7 @@ add_library (cachelib_allocator
   LruTailAgeStrategy.cpp
   MarginalHitsOptimizeStrategy.cpp
   MarginalHitsStrategy.cpp
+  MarginalHitsStrategyNew.cpp
   memory/AllocationClass.cpp
   memory/MemoryAllocator.cpp
   memory/MemoryPool.cpp
diff --git a/cachelib/allocator/CacheAllocatorConfig.h b/cachelib/allocator/CacheAllocatorConfig.h
index da02939bb..2cb880b1a 100644
--- a/cachelib/allocator/CacheAllocatorConfig.h
+++ b/cachelib/allocator/CacheAllocatorConfig.h
@@ -536,6 +536,9 @@ class CacheAllocatorConfig {
   // whether to allow tracking tail hits in MM2Q
   bool trackTailHits{false};
+
+  // when tracking tail hits for MM2Q, whether to count only cold tail hits
+  // or both cold and warm tail hits
+  bool countColdTailHitsOnly{false};
 
   // Memory monitoring config
   MemoryMonitor::Config memMonitorConfig;
@@ -1170,7 +1173,7 @@ bool CacheAllocatorConfig<T>::validateStrategy(
   auto type = strategy->getType();
   return type != RebalanceStrategy::NumTypes &&
-         (type != RebalanceStrategy::MarginalHits || trackTailHits);
+         ((type != RebalanceStrategy::MarginalHits &&
+           type != RebalanceStrategy::MarginalHitsNew) ||
+          trackTailHits);
 }
 
 template <typename T>
diff --git a/cachelib/allocator/MM2Q.h b/cachelib/allocator/MM2Q.h
index cece17e0e..6cc9a00cd 100644
--- a/cachelib/allocator/MM2Q.h
+++ b/cachelib/allocator/MM2Q.h
@@ -305,7 +305,7 @@ class MM2Q {
   }
 
   // adding extra config after generating the config: tailSize
-  void addExtraConfig(size_t tSize) { tailSize = tSize; }
+  void addExtraConfig(size_t tSize, bool coldOnly = false) {
+    tailSize = tSize;
+    coldTailOnly = coldOnly;
+  }
 
   // threshold value in seconds to compare with a node's update time to
   // determine if we need to update the position of the node in the linked
@@ -346,6 +346,8 @@ class MM2Q {
   // should not set this manually.
   size_t tailSize{0};
 
+  bool coldTailOnly{false};
+
   // Minimum interval between reconfigurations. If 0, reconfigure is never
   // called.
   std::chrono::seconds mmReconfigureIntervalSecs{};
@@ -1100,7 +1102,7 @@ MMContainerStat MM2Q::Container<T, HookPtr>::getStats() const noexcept {
         numHotAccesses_,
         numColdAccesses_,
         numWarmAccesses_,
-        computeWeightedAccesses(numWarmTailAccesses_, numColdTailAccesses_)};
+        config_.coldTailOnly
+            ? numColdTailAccesses_
+            : computeWeightedAccesses(numWarmTailAccesses_, numColdTailAccesses_)};
   });
 }
diff --git a/cachelib/allocator/MarginalHitsStrategyNew.cpp b/cachelib/allocator/MarginalHitsStrategyNew.cpp
new file mode 100644
index 000000000..036b5cf49
--- /dev/null
+++ b/cachelib/allocator/MarginalHitsStrategyNew.cpp
@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cachelib/allocator/MarginalHitsStrategyNew.h"
+
+#include <folly/logging/xlog.h>
+
+#include <algorithm>
+#include <functional>
+
+namespace facebook::cachelib {
+
+MarginalHitsStrategyNew::MarginalHitsStrategyNew(Config config)
+    : RebalanceStrategy(MarginalHitsNew), config_(std::move(config)) {}
+
+RebalanceContext MarginalHitsStrategyNew::pickVictimAndReceiverImpl(
+    const CacheBase& cache, PoolId pid, const PoolStats& poolStats) {
+  return pickVictimAndReceiverCandidates(cache, pid, poolStats, false);
+}
+
+RebalanceContext MarginalHitsStrategyNew::pickVictimAndReceiverCandidates(
+    const CacheBase& cache, PoolId pid, const PoolStats& poolStats, bool force) {
+  const auto config = getConfigCopy();
+  if (!cache.getPool(pid).allSlabsAllocated()) {
+    XLOGF(DBG,
+          "Pool Id: {} does not have all its slabs allocated"
+          " and does not need rebalancing.",
+          static_cast<int>(pid));
+    return kNoOpContext;
+  }
+
+  auto scores = computeClassMarginalHits(pid, poolStats, config.movingAverageParam);
+  auto classesSet = poolStats.getClassIds();
+  std::vector<ClassId> classes(classesSet.begin(), classesSet.end());
+  std::unordered_map<ClassId, bool> validVictim;
+  std::unordered_map<ClassId, bool> validReceiver;
+  for (auto it : classes) {
+    auto acStats = poolStats.mpStats.acStats;
+    // a class can be a victim only if it has more than config.minSlabs slabs
+    validVictim[it] = acStats.at(it).totalSlabs() > config.minSlabs;
+    // a class can be a receiver only if its free memory (free allocs, free
+    // slabs, etc) is small
+    validReceiver[it] = acStats.at(it).getTotalFreeMemory() <
+                        config.maxFreeMemSlabs * Slab::kSize;
+  }
+  if (classStates_[pid].entities.empty()) {
+    // initialization
+    classStates_[pid].entities = classes;
+    for (auto cid : classes) {
+      classStates_[pid].smoothedRanks[cid] = 0;
+    }
+  }
+  // we don't rely on this decay anymore
+  classStates_[pid].updateRankings(scores, 0.0);
+  RebalanceContext ctx =
+      pickVictimAndReceiverFromRankings(pid, validVictim, validReceiver);
+
+  auto numRequestObserved = computeNumRequests(pid, poolStats);
+  if (!force && numRequestObserved < config.minRequestsObserved) {
+    XLOGF(DBG,
+          "haven't observed enough requests: {}/{}, wait until next round",
+          numRequestObserved,
+          config.minRequestsObserved);
+    ctx = kNoOpContext;
+  }
+  if (!force && ctx.isEffective()) {
+    // extra filtering
+    auto receiverScore = scores.at(ctx.receiverClassId);
+    auto victimScore = scores.at(ctx.victimClassId);
+    auto improvement = receiverScore - victimScore;
+    auto improvementRatio = improvement / (victimScore == 0 ?
1 : victimScore); + ctx.diffValue = improvement; + if ((config.minDiff > 0 && improvement < config.minDiff) || + (config.minDiffRatio > 0 && improvementRatio < config.minDiffRatio)){ + XLOGF(DBG, "Not enough to trigger rebalancing, receiver id: {}, victim id: {}, receiver score: {}, victim score: {}, improvement: {}, improvement ratio: {}, thresh1: {}, thresh2: {}", + ctx.receiverClassId, ctx.victimClassId, receiverScore, victimScore, improvement, improvementRatio, config.minDiff, config.minDiffRatio); + ctx = kNoOpContext; + } else { + XLOGF(DBG, "rebalancing, receiver id: {}, victim id: {}, receiver score: {}, victim score: {}, improvement: {}, improvement ratio: {}", + ctx.receiverClassId, ctx.victimClassId, receiverScore, victimScore, improvement, improvementRatio); + + } + } + + if(!ctx.isEffective()){ + ctx = kNoOpContext; + } + auto& poolState = getPoolState(pid); + auto deltaRequestsSinceLastDecay = computeRequestsSinceLastDecay(pid, poolStats); + if((ctx.isEffective() || !config.onlyUpdateHitIfRebalance) || deltaRequestsSinceLastDecay >= config.minDecayInterval) { + for (const auto i : poolStats.getClassIds()) { + poolState[i].updateTailHits(poolStats, config.movingAverageParam); + } + } + + if(numRequestObserved >= config.minRequestsObserved) { + for (const auto i : poolStats.getClassIds()) { + poolState[i].updateRequests(poolStats); + } + } + + // self-tuning threshold for the next round. + if(ctx.isEffective()){ + // max window size: 2 * n_classes + + size_t classWithHits = 0; + for (const auto& cid : classes) { + if (poolState.at(cid).deltaHits(poolStats) > 0) { + ++classWithHits; + } + } + recordRebalanceEvent(pid, ctx, classWithHits * 2); + auto effectiveMoveRate = queryEffectiveMoveRate(pid); + auto windowSize = getRebalanceEventQueueSize(pid); + XLOGF(DBG, + "Rebalancing: effective move rate = {}, window size = {}, diff = {}, threshold = {}, ({}->{})", + effectiveMoveRate, + windowSize, ctx.diffValue, config.minDiff, static_cast(ctx.victimClassId), static_cast(ctx.receiverClassId)); + + if(effectiveMoveRate <= config.emrLow && windowSize >= config.thresholdIncMinWindowSize) { + if(config.thresholdAI) { + auto currentMin = getMinDiffValueFromRebalanceEvents(pid); + if(updateMinDff(currentMin + config.thresholdAIADStep)) { + clearPoolRebalanceEvent(pid); + } + } else if (config.thresholdMI){ + if(updateMinDff(config.minDiff * config.thresholdMIMDFactor)) { + clearPoolRebalanceEvent(pid); + } + } + + } else if (effectiveMoveRate >= config.emrHigh && windowSize >= classWithHits) { + if(config.thresholdAD) { + if(updateMinDff(std::max(2.0, config.minDiff - config.thresholdAIADStep))) { + clearPoolRebalanceEvent(pid); + } + } else if (config.thresholdMD){ + if(updateMinDff(std::max(2.0, config.minDiff / config.thresholdMIMDFactor))) { + clearPoolRebalanceEvent(pid); + } + } + } + } + + return ctx; +} + +ClassId MarginalHitsStrategyNew::pickVictimImpl(const CacheBase& cache, + PoolId pid, + const PoolStats& stats) { + return pickVictimAndReceiverCandidates(cache, pid, stats, true).victimClassId; +} + +std::unordered_map +MarginalHitsStrategyNew::computeClassMarginalHits(PoolId pid, + const PoolStats& poolStats, + double decayFactor) { + const auto& poolState = getPoolState(pid); + std::unordered_map scores; + for (auto info : poolState) { + if (info.id != Slab::kInvalidClassId) { + // this score is the latest delta. 
+ scores[info.id] = info.getDecayedMarginalHits(poolStats, decayFactor); + } + } + return scores; +} + +size_t MarginalHitsStrategyNew::computeNumRequests( + PoolId pid, const PoolStats& poolStats) const { + const auto& poolState = getPoolState(pid); + size_t totalRequests = 0; + auto classesSet = poolStats.getClassIds(); + for (const auto& cid : classesSet) { + totalRequests += poolState.at(cid).deltaRequests(poolStats); + } + return totalRequests; +} + +size_t MarginalHitsStrategyNew::computeRequestsSinceLastDecay( + PoolId pid, const PoolStats& poolStats) const { + const auto& poolState = getPoolState(pid); + size_t totalRequests = 0; + auto classesSet = poolStats.getClassIds(); + for (const auto& cid : classesSet) { + totalRequests += poolState.at(cid).deltaRequestsSinceLastDecay(poolStats); + } + return totalRequests; +} + +RebalanceContext MarginalHitsStrategyNew::pickVictimAndReceiverFromRankings( + PoolId pid, + const std::unordered_map& validVictim, + const std::unordered_map& validReceiver) { + auto victimAndReceiver = classStates_[pid].pickVictimAndReceiverFromRankings( + validVictim, validReceiver, Slab::kInvalidClassId); + RebalanceContext ctx{victimAndReceiver.first, victimAndReceiver.second}; + if (ctx.victimClassId == Slab::kInvalidClassId || + ctx.receiverClassId == Slab::kInvalidClassId || + ctx.victimClassId == ctx.receiverClassId) { + return kNoOpContext; + } + + XLOGF(DBG, + "Rebalancing: receiver = {}, smoothed rank = {}, victim = {}, smoothed " + "rank = {}", + static_cast(ctx.receiverClassId), + classStates_[pid].smoothedRanks[ctx.receiverClassId], + static_cast(ctx.victimClassId), + classStates_[pid].smoothedRanks[ctx.victimClassId]); + return ctx; +} +} // namespace facebook::cachelib \ No newline at end of file diff --git a/cachelib/allocator/MarginalHitsStrategyNew.h b/cachelib/allocator/MarginalHitsStrategyNew.h new file mode 100644 index 000000000..5b87b9ee6 --- /dev/null +++ b/cachelib/allocator/MarginalHitsStrategyNew.h @@ -0,0 +1,142 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cachelib/allocator/MarginalHitsState.h" +#include "cachelib/allocator/RebalanceStrategy.h" + +namespace facebook { +namespace cachelib { + +// This strategy computes number of hits in the tail slab of LRU to estimate +// the potential (given one more slab, how many more hits can this LRU serve). +// And use a smoothed ranking of those potentials to decide victim and receiver. +class MarginalHitsStrategyNew : public RebalanceStrategy { + public: + // Config class for marginal hits strategy + struct Config : public BaseConfig { + // parameter for moving average, to smooth the ranking + double movingAverageParam{0.3}; + + // minimum number of slabs to retain in every allocation class. 
+ unsigned int minSlabs{1}; + + // maximum free memory (equivalent to this many slabs) in every allocation + // class + unsigned int maxFreeMemSlabs{1}; + + bool onlyUpdateHitIfRebalance{true}; + + // enforcing thresholds between the victim and receiver class + double minDiff{2.0}; + double minDiffRatio{0.00}; + + ////// these parameters are for controlling the threshold auto-tuning + unsigned int thresholdIncMinWindowSize{5}; + bool thresholdAI{true}; + bool thresholdMI{false}; + bool thresholdAD{false}; + bool thresholdMD{true}; + + double emrLow{0.5}; + double emrHigh{0.95}; + double thresholdAIADStep{2.0}; + double thresholdMIMDFactor{2.0}; + /////////////////////////// + + uint64_t minRequestsObserved{50000}; + uint64_t minDecayInterval{50000}; + + Config() noexcept {} + explicit Config(double param) noexcept : Config(param, 1, 1) {} + Config(double param, unsigned int minSlab, unsigned int maxFree) noexcept + : movingAverageParam(param), + minSlabs(minSlab), + maxFreeMemSlabs(maxFree) {} + }; + + // Update the config. This will not affect the current rebalancing, but + // will take effect in the next round + void updateConfig(const BaseConfig& baseConfig) override final { + std::lock_guard l(configLock_); + config_ = static_cast(baseConfig); + } + + bool updateMinDff(double newValue) { + if(config_.minDiff == newValue){ + return false; + } + std::lock_guard l(configLock_); + XLOGF(DBG, "marginal-hits, threshold auto-tuning, updating from {} to {}", config_.minDiff, newValue); + config_.minDiff = newValue; + return true; + } + + explicit MarginalHitsStrategyNew(Config config = {}); + + protected: + // This returns a copy of the current config. + // This ensures that we're always looking at the same config even though + // someone else may have updated the config during rebalancing + Config getConfigCopy() const { + std::lock_guard l(configLock_); + return config_; + } + + // pick victim and receiver classes from a pool + RebalanceContext pickVictimAndReceiverImpl( + const CacheBase& cache, + PoolId pid, + const PoolStats& poolStats) override final; + + // pick victim class from a pool to shrink + ClassId pickVictimImpl(const CacheBase& cache, + PoolId pid, + const PoolStats& poolStats) override final; + + size_t computeNumRequests(PoolId pid, const PoolStats& poolStats) const; + + size_t computeRequestsSinceLastDecay(PoolId pid, const PoolStats& poolStats) const; + + private: + // compute delta of tail hits for every class in this pool + std::unordered_map computeClassMarginalHits( + PoolId pid, const PoolStats& poolStats, double decayFactor); + + // pick victim and receiver according to smoothed rankings + RebalanceContext pickVictimAndReceiverFromRankings( + PoolId pid, + const std::unordered_map& validVictim, + const std::unordered_map& validReceiver); + + RebalanceContext pickVictimAndReceiverCandidates( + const CacheBase& cache, + PoolId pid, + const PoolStats& poolStats, + bool force); + + // marginal hits states for classes in each pools + std::unordered_map> classStates_; + + // Config for this strategy, this can be updated anytime. 
+  // Do not access this directly, always use `getConfigCopy()` to
+  // obtain a copy first
+  Config config_;
+  mutable std::mutex configLock_;
+};
+} // namespace cachelib
+} // namespace facebook
\ No newline at end of file
diff --git a/cachelib/allocator/RebalanceInfo.h b/cachelib/allocator/RebalanceInfo.h
index a6288839f..e4384cba6 100644
--- a/cachelib/allocator/RebalanceInfo.h
+++ b/cachelib/allocator/RebalanceInfo.h
@@ -47,6 +47,13 @@ struct Info {
   // accumulative number of hits in the tail slab of this allocation class
   uint64_t accuTailHits{0};
 
+  double decayedAccuTailHits{0.0};
+
+  // accumulative number of requests seen for this allocation class
+  uint64_t numRequests{0};
+
+  uint64_t numRequestsAtLastDecay{0};
+
   // TODO(sugak) this is changed to unblock the LLVM upgrade The fix is not
   // completely understood, but it's a safe change T16521551 - Info() noexcept
   // = default;
@@ -55,8 +62,11 @@ struct Info {
        unsigned long long slabs,
        unsigned long long evicts,
        uint64_t h,
-       uint64_t th) noexcept
-      : id(_id), nSlabs(slabs), evictions(evicts), hits(h), accuTailHits(th) {}
+       uint64_t th,
+       double dath,
+       uint64_t nr,
+       uint64_t nrld) noexcept
+      : id(_id), nSlabs(slabs), evictions(evicts), hits(h), accuTailHits(th),
+        decayedAccuTailHits(dath), numRequests(nr), numRequestsAtLastDecay(nrld) {}
 
   // number of rounds we hold off for when we acquire a slab.
   static constexpr unsigned int kNumHoldOffRounds = 10;
@@ -101,6 +111,22 @@ struct Info {
     return poolStats.numHitsForClass(id) - hits;
   }
 
+  uint64_t deltaRequests(const PoolStats& poolStats) const {
+    const auto& cacheStats = poolStats.cacheStats.at(id);
+    auto totalRequests = poolStats.numHitsForClass(id) + cacheStats.allocAttempts;
+    return totalRequests > numRequests
+               ? totalRequests - numRequests
+               : 0;
+  }
+
+  uint64_t deltaRequestsSinceLastDecay(const PoolStats& poolStats) const {
+    const auto& cacheStats = poolStats.cacheStats.at(id);
+    auto totalRequests = poolStats.numHitsForClass(id) + cacheStats.allocAttempts;
+    return totalRequests > numRequestsAtLastDecay
+               ? totalRequests - numRequestsAtLastDecay
+               : 0;
+  }
+
   // return the delta of alloc failures for this alloc class from the current
   // state.
   //
@@ -144,6 +170,11 @@ struct Info {
            accuTailHits;
   }
 
+  double getDecayedMarginalHits(const PoolStats& poolStats,
+                                double decayFactor = 0.0) const {
+    // decayed past + now
+    return decayedAccuTailHits + getMarginalHits(poolStats) * (1 - decayFactor);
+  }
+
   // returns true if the hold off is active for this alloc class.
   bool isOnHoldOff() const noexcept { return holdOffRemaining > 0; }
 
@@ -162,6 +193,19 @@ struct Info {
     hits = poolStats.numHitsForClass(id);
   }
 
+  void updateRequests(const PoolStats& poolStats) noexcept {
+    const auto& cacheStats = poolStats.cacheStats.at(id);
+    numRequests = poolStats.numHitsForClass(id) + cacheStats.allocAttempts;
+  }
+
+  void updateTailHits(const PoolStats& poolStats,
+                      double decayFactor = 0.0) noexcept {
+    const auto& cacheStats = poolStats.cacheStats.at(id);
+    decayedAccuTailHits = (decayedAccuTailHits + getMarginalHits(poolStats)) * decayFactor;
+    accuTailHits = cacheStats.containerStat.numTailAccesses;
+    numRequestsAtLastDecay = poolStats.numHitsForClass(id) + cacheStats.allocAttempts;
+  }
+
   // updates the current record to store the current state of slabs and the
   // evictions we see.
void updateRecord(const PoolStats& poolStats) { @@ -175,7 +219,8 @@ struct Info { evictions = cacheStats.numEvictions(); // update tail hits - accuTailHits = cacheStats.containerStat.numTailAccesses; + // we'll update this separately + //accuTailHits = cacheStats.containerStat.numTailAccesses; allocFailures = cacheStats.allocFailures; } diff --git a/cachelib/allocator/RebalanceStrategy.cpp b/cachelib/allocator/RebalanceStrategy.cpp index 2d02931ec..9ea313f31 100644 --- a/cachelib/allocator/RebalanceStrategy.cpp +++ b/cachelib/allocator/RebalanceStrategy.cpp @@ -59,7 +59,10 @@ void RebalanceStrategy::initPoolState(PoolId pid, const PoolStats& stats) { curr[id] = Info{id, stats.mpStats.acStats.at(id).totalSlabs(), stats.cacheStats.at(id).numEvictions(), stats.numHitsForClass(id), - stats.cacheStats.at(id).containerStat.numTailAccesses}; + stats.cacheStats.at(id).containerStat.numTailAccesses, + 0, + stats.numHitsForClass(id) + stats.cacheStats.at(id).allocAttempts, + stats.numHitsForClass(id) + stats.cacheStats.at(id).allocAttempts}; // hits + allocs => nr of requests } } @@ -321,4 +324,67 @@ T RebalanceStrategy::executeAndRecordCurrentState( return rv; } +///// for keeping rebalance decision histories +void RebalanceStrategy::recordRebalanceEvent(PoolId pid, RebalanceContext ctx, size_t maxQueueSize) { + if(ctx.isEffective()) { + auto& eventQueue = recentRebalanceEvents_[pid]; + eventQueue.emplace_back(ctx); + if (eventQueue.size() > maxQueueSize) { + eventQueue.pop_front(); + } + } +} + +unsigned int RebalanceStrategy::getRebalanceEventQueueSize(PoolId pid) const { + const auto it = recentRebalanceEvents_.find(pid); + if (it == recentRebalanceEvents_.end()) { + return 0; + } + return it->second.size(); +} + +void RebalanceStrategy::clearPoolRebalanceEvent(PoolId pid) { + recentRebalanceEvents_.erase(pid); +} + +double RebalanceStrategy::queryEffectiveMoveRate(PoolId pid) const{ + const auto it = recentRebalanceEvents_.find(pid); + if (it == recentRebalanceEvents_.end() || it->second.empty()) { + return 1.0; + } + + const auto& events = it->second; + std::unordered_map netChanges; + + for (const auto& ctx : events) { + netChanges[ctx.victimClassId]--; + netChanges[ctx.receiverClassId]++; + } + + int currentAbsNet = 0; + for (const auto& [classId, net] : netChanges) { + currentAbsNet += std::abs(net); + } + int totalEffectiveMoves = currentAbsNet / 2; + + return static_cast(totalEffectiveMoves) / events.size(); +} + +double RebalanceStrategy::getMinDiffValueFromRebalanceEvents(PoolId pid) const { + const auto it = recentRebalanceEvents_.find(pid); + if (it == recentRebalanceEvents_.end() || it->second.empty()) { + return 0.0; // Return 0.0 if the queue is empty or the pool ID is not found + } + + const auto& events = it->second; + + // Find the minimum diffValue in the queue + double minDiffValue = std::numeric_limits::max(); + for (const auto& ctx : events) { + minDiffValue = std::min(minDiffValue, ctx.diffValue); + } + + return minDiffValue; +} + } // namespace facebook::cachelib diff --git a/cachelib/allocator/RebalanceStrategy.h b/cachelib/allocator/RebalanceStrategy.h index 07f090235..15ddb8239 100644 --- a/cachelib/allocator/RebalanceStrategy.h +++ b/cachelib/allocator/RebalanceStrategy.h @@ -28,9 +28,35 @@ struct RebalanceContext { ClassId victimClassId{Slab::kInvalidClassId}; ClassId receiverClassId{Slab::kInvalidClassId}; + // to support multiple pairs at a time + std::vector> victimReceiverPairs{}; + // tracking the diff between victim and receiver + double diffValue{0.0}; + 
RebalanceContext() = default; RebalanceContext(ClassId victim, ClassId receiver) : victimClassId(victim), receiverClassId(receiver) {} + + explicit RebalanceContext(const std::vector>& pairs) + : victimReceiverPairs(pairs) {} + + bool isEffective() const { + auto isPairEffective = [](ClassId victim, ClassId receiver) { + return victim != Slab::kInvalidClassId && + receiver != Slab::kInvalidClassId && + victim != receiver; + }; + + bool singleEffective = isPairEffective(victimClassId, receiverClassId); + + bool pairEffective = std::any_of( + victimReceiverPairs.begin(), victimReceiverPairs.end(), + [&](const std::pair& p) { + return isPairEffective(p.first, p.second); + }); + + return singleEffective || pairEffective; + } }; // Base class for rebalance strategy. @@ -48,6 +74,7 @@ class RebalanceStrategy { PickNothingOrTest = 0, Random, MarginalHits, + MarginalHitsNew, FreeMem, HitsPerSlab, LruTailAge, @@ -81,6 +108,17 @@ class RebalanceStrategy { virtual void updateConfig(const BaseConfig&) {} + + void recordRebalanceEvent(PoolId pid, RebalanceContext ctx, size_t maxQueueSize); + + double getMinDiffValueFromRebalanceEvents(PoolId pid) const; + + unsigned int getRebalanceEventQueueSize(PoolId pid) const; + + void clearPoolRebalanceEvent(PoolId pid); + + double queryEffectiveMoveRate(PoolId pid) const; + Type getType() const { return type_; } std::string getTypeString() const { @@ -91,6 +129,8 @@ class RebalanceStrategy { return "Random"; case MarginalHits: return "MarginalHits"; + case MarginalHitsNew: + return "MarginalHitsNew"; case FreeMem: return "FreeMem"; case HitsPerSlab: @@ -192,6 +232,8 @@ class RebalanceStrategy { size_t threshold, const PoolState& prevState); + std::unordered_map> recentRebalanceEvents_; + private: // picks any of the class id ordered by the total slabs. ClassId pickAnyClassIdForResizing(const CacheBase& cache, diff --git a/cachelib/cachebench/cache/Cache.h b/cachelib/cachebench/cache/Cache.h index 71404fd52..7ac116adf 100644 --- a/cachelib/cachebench/cache/Cache.h +++ b/cachelib/cachebench/cache/Cache.h @@ -513,10 +513,16 @@ Cache::Cache(const CacheConfig& config, nandBytesBegin_{fetchNandWrites()}, itemRecords_(config_.enableItemDestructorCheck) { constexpr size_t MB = 1024ULL * 1024ULL; - + + if (config_.rebalanceStrategy == "marginal-hits" || + config_.rebalanceStrategy == "marginal-hits-new") { + allocatorConfig_.enableTailHitsTracking(); + } allocatorConfig_.enablePoolRebalancing( config_.getRebalanceStrategy(), std::chrono::seconds(config_.poolRebalanceIntervalSec)); + + allocatorConfig_.countColdTailHitsOnly = config_.countColdTailHitsOnly; if (config_.moveOnSlabRelease && movingSync != nullptr) { allocatorConfig_.enableMovingOnSlabRelease( diff --git a/cachelib/cachebench/main.cpp b/cachelib/cachebench/main.cpp index ddc056839..30bd11930 100644 --- a/cachelib/cachebench/main.cpp +++ b/cachelib/cachebench/main.cpp @@ -55,6 +55,7 @@ DEFINE_string(progress_stats_file, DEFINE_int32(timeout_seconds, 0, "Maximum allowed seconds for running test. 
0 means no timeout"); +DEFINE_bool(enable_debug_log, false, "Enable debug logging"); struct sigaction act; std::unique_ptr runnerInstance; @@ -146,6 +147,10 @@ int main(int argc, char** argv) { return 1; } + if(FLAGS_enable_debug_log) { + folly::LoggerDB::get().setLevel("", folly::LogLevel::DBG); + } + CacheBenchConfig config(FLAGS_json_test_config); std::cout << "Welcome to OSS version of cachebench" << std::endl; #endif diff --git a/cachelib/cachebench/runner/Stressor.cpp b/cachelib/cachebench/runner/Stressor.cpp index 0bca4438b..8f398b18f 100644 --- a/cachelib/cachebench/runner/Stressor.cpp +++ b/cachelib/cachebench/runner/Stressor.cpp @@ -24,6 +24,7 @@ #include "cachelib/cachebench/workload/BinaryKVReplayGenerator.h" #include "cachelib/cachebench/workload/BlockChunkReplayGenerator.h" #include "cachelib/cachebench/workload/KVReplayGenerator.h" +#include "cachelib/cachebench/workload/OGBinaryReplayGenerator.h" #include "cachelib/cachebench/workload/OnlineGenerator.h" #include "cachelib/cachebench/workload/PieceWiseReplayGenerator.h" #include "cachelib/cachebench/workload/SimpleFlashBenchmarkGenerator.h" @@ -147,6 +148,8 @@ std::unique_ptr makeGenerator(const StressorConfig& config) { return std::make_unique(config); } else if (config.generator == "block-replay") { return std::make_unique(config); + } else if (config.generator == "oracle-general-replay") { + return std::make_unique(config); } else if (config.generator == "binary-replay") { return std::make_unique(config); } else if (config.generator.empty() || config.generator == "workload") { diff --git a/cachelib/cachebench/util/CacheConfig.cpp b/cachelib/cachebench/util/CacheConfig.cpp index 6d8f40874..e13d61747 100644 --- a/cachelib/cachebench/util/CacheConfig.cpp +++ b/cachelib/cachebench/util/CacheConfig.cpp @@ -17,6 +17,8 @@ #include "cachelib/cachebench/util/CacheConfig.h" #include "cachelib/allocator/HitsPerSlabStrategy.h" +#include "cachelib/allocator/MarginalHitsStrategyNew.h" +#include "cachelib/allocator/MarginalHitsStrategy.h" #include "cachelib/allocator/LruTailAgeStrategy.h" #include "cachelib/allocator/RandomStrategy.h" @@ -33,6 +35,8 @@ CacheConfig::CacheConfig(const folly::dynamic& configJson) { JSONSetVal(configJson, rebalanceMinSlabs); JSONSetVal(configJson, rebalanceDiffRatio); + JSONSetVal(configJson, countColdTailHitsOnly); + JSONSetVal(configJson, htBucketPower); JSONSetVal(configJson, htLockPower); @@ -135,6 +139,12 @@ std::shared_ptr CacheConfig::getRebalanceStrategy() const { auto config = HitsPerSlabStrategy::Config{ rebalanceDiffRatio, static_cast(rebalanceMinSlabs)}; return std::make_shared(config); + } else if (rebalanceStrategy == "marginal-hits") { + return std::make_shared( + MarginalHitsStrategy::Config{}); + } else if (rebalanceStrategy == "marginal-hits-new") { + return std::make_shared( + MarginalHitsStrategyNew::Config{}); } else { // use random strategy to just trigger some slab release. 
return std::make_shared( diff --git a/cachelib/cachebench/util/CacheConfig.h b/cachelib/cachebench/util/CacheConfig.h index d1db6cbe1..df52357f0 100644 --- a/cachelib/cachebench/util/CacheConfig.h +++ b/cachelib/cachebench/util/CacheConfig.h @@ -78,6 +78,8 @@ struct CacheConfig : public JSONConfig { double rebalanceDiffRatio{0.25}; bool moveOnSlabRelease{false}; + bool countColdTailHitsOnly{false}; + uint64_t htBucketPower{22}; // buckets in hash table uint64_t htLockPower{20}; // locks in hash table diff --git a/cachelib/cachebench/util/Config.cpp b/cachelib/cachebench/util/Config.cpp index 133074e50..b276b00ee 100644 --- a/cachelib/cachebench/util/Config.cpp +++ b/cachelib/cachebench/util/Config.cpp @@ -57,6 +57,10 @@ StressorConfig::StressorConfig(const folly::dynamic& configJson) { JSONSetVal(configJson, traceFileNames); JSONSetVal(configJson, configPath); + JSONSetVal(configJson, zstdTrace); + JSONSetVal(configJson, compressed); + JSONSetVal(configJson, ignoreLargeReq); + JSONSetVal(configJson, cachePieceSize); JSONSetVal(configJson, maxCachePieces); @@ -95,7 +99,7 @@ StressorConfig::StressorConfig(const folly::dynamic& configJson) { // If you added new fields to the configuration, update the JSONSetVal // to make them available for the json configs and increment the size // below - checkCorrectSize(); + checkCorrectSize(); } bool StressorConfig::usesChainedItems() const { diff --git a/cachelib/cachebench/util/Config.h b/cachelib/cachebench/util/Config.h index dcb2ea3b6..d5f35c7a9 100644 --- a/cachelib/cachebench/util/Config.h +++ b/cachelib/cachebench/util/Config.h @@ -303,6 +303,10 @@ struct StressorConfig : public JSONConfig { std::string traceFileName{}; std::vector traceFileNames{}; + bool zstdTrace{false}; + bool compressed{true}; + bool ignoreLargeReq{false}; + // location of the path for the files referenced inside the json. If not // specified, it defaults to the path of the json file being parsed. std::string configPath{}; diff --git a/cachelib/cachebench/workload/OGBinaryReplayGenerator.h b/cachelib/cachebench/workload/OGBinaryReplayGenerator.h new file mode 100644 index 000000000..822db20b8 --- /dev/null +++ b/cachelib/cachebench/workload/OGBinaryReplayGenerator.h @@ -0,0 +1,460 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "cachelib/cachebench/cache/Cache.h" +#include "cachelib/cachebench/util/Exceptions.h" +#include "cachelib/cachebench/util/Parallel.h" +#include "cachelib/cachebench/util/Request.h" +#include "cachelib/cachebench/workload/ReplayGeneratorBase.h" +#include "cachelib/cachebench/workload/ZstdReader.h" +#include "cachelib/allocator/memory/Slab.h" + +namespace facebook { +namespace cachelib { +namespace cachebench { + +struct OGReqWrapper { + OGReqWrapper() = default; + + OGReqWrapper(const OGReqWrapper& other) + : key_(other.key_), + sizes_(other.sizes_), + req_(key_, + sizes_.begin(), + sizes_.end(), + reinterpret_cast(this), + other.req_), + repeats_(other.repeats_) {} + + void updateKey(const std::string& key) { + key_ = key; + // Request's key is now std::string_view + req_.key = key_; + } + + // current outstanding key + std::string key_; + std::vector sizes_{1}; + // current outstanding req object + // Use 'this' as the request ID, so that this object can be + // identified on completion (i.e., notifyResult call) + Request req_{key_, sizes_.begin(), sizes_.end(), OpType::kGet, + reinterpret_cast(this)}; + + // number of times to issue the current req object + // before fetching a new line from the trace + uint32_t repeats_{0}; +}; + +// OGBinaryReplayGenerator generates the cachelib requests based the trace data +// read from the given trace file(s). +// OGBinaryReplayGenerator supports amplifying the key population by appending +// suffixes (i.e., stream ID) to each key read from the trace file. +// In order to minimize the contentions for the request submission queues +// which might need to be dispatched by multiple stressor threads, +// the requests are sharded to each stressor by doing hashing over the key. +class OGBinaryReplayGenerator : public ReplayGeneratorBase { + public: + // Default order is clock_time,object_id,object_size,next_access_vtime + enum SampleFields : uint8_t { + CLOCK_TIME = 0, + OBJECT_ID, + OBJECT_SIZE, + NEXT_ACCESS_VTIME, + END + }; + + const ColumnTable columnTable_ = { + {SampleFields::CLOCK_TIME, true, {"clock_time"}}, /* required */ + {SampleFields::OBJECT_ID, true, {"object_id"}}, /* required */ + {SampleFields::OBJECT_SIZE, true, {"object_size"}}, /* required */ + {SampleFields::NEXT_ACCESS_VTIME, true, {"next_access_vtime"}}, /* required + */ + }; + + explicit OGBinaryReplayGenerator(const StressorConfig& config) + : ReplayGeneratorBase(config), traceStream_(config, 0, columnTable_), zstdReader_() { + if(config.zstdTrace){ + zstdReader_.open(config.traceFileName, config.compressed); + XLOGF(INFO, "Reading zstd trace file"); + } + for (uint32_t i = 0; i < numShards_; ++i) { + stressorCtxs_.emplace_back(std::make_unique(i)); + } + + genWorker_ = std::thread([this] { + folly::setThreadName("cb_replay_gen"); + genRequests(); + }); + + XLOGF(INFO, + "Started OGBinaryReplayGenerator (amp factor {}, # of stressor " + "threads {})", + ampFactor_, numShards_); + } + + virtual ~OGBinaryReplayGenerator() { + XCHECK(shouldShutdown()); + if (genWorker_.joinable()) { + genWorker_.join(); + } + } + + // getReq generates the next request from the trace file. 
+ const Request& getReq( + uint8_t, + std::mt19937_64&, + std::optional lastRequestId = std::nullopt) override; + + void renderStats(uint64_t, std::ostream& out) const override { + out << std::endl << "== OGBinaryReplayGenerator Stats ==" << std::endl; + + out << folly::sformat("{}: {:.2f} million (parse error: {})", + "Total Processed Samples", + (double)parseSuccess.load() / 1e6, parseError.load()) + << std::endl; + } + + void notifyResult(uint64_t requestId, OpResultType result) override; + + void markFinish() override { getStressorCtx().markFinish(); } + + // Parse the request from the trace line and set the OGReqWrapper + bool parseRequest(const std::string& line, std::unique_ptr& req); + + // for unit test + bool setHeaderRow(const std::string& header) { + return config_.zstdTrace ? true : traceStream_.setHeaderRow(header); + } + + private: + // Interval at which the submission queue is polled when it is either + // full (producer) or empty (consumer). + // We use polling with the delay since the ProducerConsumerQueue does not + // support blocking read or writes with a timeout + static constexpr uint64_t checkIntervalUs_ = 100; + static constexpr size_t kMaxRequests = 10000; + static constexpr size_t kMinKeySize = 16; + static constexpr size_t maxSlabSize = 1ULL << facebook::cachelib::Slab::kNumSlabBits; + + using ReqQueue = folly::ProducerConsumerQueue>; + + // StressorCtx keeps track of the state including the submission queues + // per stressor thread. Since there is only one request generator thread, + // lock-free ProducerConsumerQueue is used for performance reason. + // Also, separate queue which is dispatched ahead of any requests in the + // submission queue is used for keeping track of the requests which need to be + // resubmitted (i.e., a request having remaining repeat count); there could + // be more than one requests outstanding for async stressor while only one + // can be outstanding for sync stressor + struct StressorCtx { + explicit StressorCtx(uint32_t id) + : id_(id), reqQueue_(std::in_place_t{}, kMaxRequests) {} + + bool isFinished() { return finished_.load(std::memory_order_relaxed); } + void markFinish() { finished_.store(true, std::memory_order_relaxed); } + + uint32_t id_{0}; + std::queue> resubmitQueue_; + folly::cacheline_aligned reqQueue_; + // Thread that finish its operations mark it here, so we will skip + // further request on its shard + std::atomic finished_{false}; + }; + + // Read next trace line from TraceFileStream and fill OGReqWrapper + std::unique_ptr getReqInternal(); + + std::unique_ptr getReqInternalZstd(); + + // Used to assign stressorIdx_ + std::atomic incrementalIdx_{0}; + + // A sticky index assigned to each stressor threads that calls into + // the generator. + folly::ThreadLocalPtr stressorIdx_; + + // Vector size is equal to the # of stressor threads; + // stressorIdx_ is used to index. 
+ std::vector> stressorCtxs_; + + TraceFileStream traceStream_; + + ZstdReader zstdReader_; + + std::thread genWorker_; + + // Used to signal end of file as EndOfTrace exception + std::atomic eof{false}; + + // Stats + std::atomic parseError = 0; + std::atomic parseSuccess = 0; + + void genRequests(); + + void setEOF() { eof.store(true, std::memory_order_relaxed); } + bool isEOF() { return eof.load(std::memory_order_relaxed); } + + inline StressorCtx& getStressorCtx(size_t shardId) { + XCHECK_LT(shardId, numShards_); + return *stressorCtxs_[shardId]; + } + + inline StressorCtx& getStressorCtx() { + if (!stressorIdx_.get()) { + stressorIdx_.reset(new uint32_t(incrementalIdx_++)); + } + + return getStressorCtx(*stressorIdx_); + } +}; + +inline bool OGBinaryReplayGenerator::parseRequest( + const std::string& line, std::unique_ptr& req) { + if (!traceStream_.setNextLine(line)) { + return false; + } + + auto sizeField = + traceStream_.template getField(SampleFields::OBJECT_SIZE); + if (!sizeField.hasValue()) { + return false; + } + + // Set key + auto parsedKey = traceStream_.template getField<>(SampleFields::OBJECT_ID).value(); + req->updateKey(std::string{parsedKey}); + + // Convert timestamp to seconds. + // todo: clarify time precision + auto timestampField = + traceStream_.template getField(SampleFields::CLOCK_TIME); + if (timestampField.hasValue()) { + uint64_t timestampRaw = timestampField.value(); + uint64_t timestampSeconds = timestampRaw / timestampFactor_; + req->req_.timestamp = timestampSeconds; + } + + size_t chunkSize = 1 * 1024 * 1024; // 1 MB + size_t objSize = sizeField.value(); + + if (objSize > chunkSize) { + size_t numChunks = (objSize + chunkSize - 1) / chunkSize; + req->sizes_.clear(); + req->sizes_.reserve(numChunks); + + for (size_t i = 0; i < numChunks; ++i) { + size_t currentChunkSize = (i == numChunks - 1) + ? (objSize % chunkSize == 0 ? 
chunkSize : objSize % chunkSize) + : chunkSize; + req->sizes_.push_back(currentChunkSize); + } + req->req_.setOp(OpType::kAddChained); + } else { + req->sizes_.clear(); + req->sizes_.resize(1); + req->sizes_[0] = objSize; + req->req_.setOp(OpType::kGet); + } + req->req_.sizeBegin = req->sizes_.begin(); + req->req_.sizeEnd = req->sizes_.end(); + + req->repeats_ = 1; + if (!req->repeats_) { + return false; + } + if (config_.ignoreOpCount) { + req->repeats_ = 1; + } + + return true; +} + +inline std::unique_ptr OGBinaryReplayGenerator::getReqInternal() { + auto reqWrapper = std::make_unique(); + + do { + std::string line; + traceStream_.getline(line); // can throw + + if (!parseRequest(line, reqWrapper)) { + parseError++; + XLOG_N_PER_MS(ERR, 10, 1000) << folly::sformat( + "Parsing error (total {}): {}", parseError.load(), line); + } else { + size_t totalSize = std::accumulate( + reqWrapper->sizes_.begin(), reqWrapper->sizes_.end(), size_t(0)); + totalSize += reqWrapper->key_.length() + 32; + + if (config_.ignoreLargeReq && totalSize >= maxSlabSize) { + return getReqInternal(); + } + parseSuccess++; + } + } while (reqWrapper->repeats_ == 0); + + return reqWrapper; +} + +inline std::unique_ptr OGBinaryReplayGenerator::getReqInternalZstd() { + auto reqWrapper = std::make_unique(); + size_t chunkSize = 1 * 1024 * 1024; // 1 MB + + do { + OracleGeneralBinRequest req; + if (!zstdReader_.read_one_req(&req)) { + throw EndOfTrace("EOF reached"); + } + + reqWrapper->key_ = std::to_string(req.objId); + if (config_.ignoreLargeReq && + (req.objSize + reqWrapper->key_.length() + 32) >= maxSlabSize) { + return getReqInternalZstd(); + } + + if (req.objSize > chunkSize) { + size_t numChunks = (req.objSize + chunkSize - 1) / chunkSize; + reqWrapper->sizes_.clear(); + reqWrapper->sizes_.reserve(numChunks); + for (size_t i = 0; i < numChunks; ++i) { + size_t currentChunkSize = (i == numChunks - 1) + ? (req.objSize % chunkSize == 0 ? chunkSize : req.objSize % chunkSize) + : chunkSize; + reqWrapper->sizes_.push_back(currentChunkSize); + } + + reqWrapper->req_.setOp(OpType::kAddChained); + } else { + reqWrapper->sizes_.clear(); + reqWrapper->sizes_.resize(1); + reqWrapper->sizes_[0] = req.objSize; + reqWrapper->req_.setOp(OpType::kGet); + } + + reqWrapper->req_.sizeBegin = reqWrapper->sizes_.begin(); + reqWrapper->req_.sizeEnd = reqWrapper->sizes_.end(); + + reqWrapper->req_.timestamp = req.clockTime; + reqWrapper->repeats_ = 1; + parseSuccess++; + } while (reqWrapper->repeats_ == 0); + + return reqWrapper; +} + +inline void OGBinaryReplayGenerator::genRequests() { + while (!shouldShutdown()) { + std::unique_ptr reqWrapper; + try { + if(config_.zstdTrace){ + reqWrapper = getReqInternalZstd(); + } else { + reqWrapper = getReqInternal(); + } + } catch (const EndOfTrace&) { + break; + } + + for (size_t keySuffix = 0; keySuffix < ampFactor_; keySuffix++) { + std::unique_ptr req; + // Use a copy of ReqWrapper except for the last one + if (keySuffix == ampFactor_ - 1) { + req.swap(reqWrapper); + } else { + req = std::make_unique(*reqWrapper); + } + + if (ampFactor_ > 1) { + // Replace the last 4 bytes with thread Id of 4 decimal chars. 
In doing + // so, keep at least 10B from the key for uniqueness; 10B is the max + // number of decimal digits for uint32_t which is used to encode the key + if (req->key_.size() > kMinKeySize) { + // trunkcate the key + size_t newSize = std::max(req->key_.size() - 4, kMinKeySize); + req->key_.resize(newSize, '0'); + } + req->key_.append(folly::sformat("{:04d}", keySuffix)); + } + + auto shardId = getShard(req->req_.key); + auto& stressorCtx = getStressorCtx(shardId); + auto& reqQ = *stressorCtx.reqQueue_; + + while (!reqQ.write(std::move(req)) && !stressorCtx.isFinished() && + !shouldShutdown()) { + // ProducerConsumerQueue does not support blocking, so use sleep + std::this_thread::sleep_for( + std::chrono::microseconds{checkIntervalUs_}); + } + } + } + + setEOF(); +} + +const Request& OGBinaryReplayGenerator::getReq(uint8_t, + std::mt19937_64&, + std::optional) { + std::unique_ptr reqWrapper; + + auto& stressorCtx = getStressorCtx(); + auto& reqQ = *stressorCtx.reqQueue_; + auto& resubmitQueue = stressorCtx.resubmitQueue_; + + while (resubmitQueue.empty() && !reqQ.read(reqWrapper)) { + if (resubmitQueue.empty() && isEOF()) { + throw cachelib::cachebench::EndOfTrace("Test stopped or EOF reached"); + } + // ProducerConsumerQueue does not support blocking, so use sleep + std::this_thread::sleep_for(std::chrono::microseconds{checkIntervalUs_}); + } + + if (!reqWrapper) { + XCHECK(!resubmitQueue.empty()); + reqWrapper.swap(resubmitQueue.front()); + resubmitQueue.pop(); + } + + OGReqWrapper* reqPtr = reqWrapper.release(); + return reqPtr->req_; +} + +void OGBinaryReplayGenerator::notifyResult(uint64_t requestId, OpResultType) { + // requestId should point to the OGReqWrapper object. The ownership is taken + // here to do the clean-up properly if not resubmitted + std::unique_ptr reqWrapper( + reinterpret_cast(requestId)); + XCHECK_GT(reqWrapper->repeats_, 0u); + if (--reqWrapper->repeats_ == 0) { + return; + } + // need to insert into the queue again + getStressorCtx().resubmitQueue_.emplace(std::move(reqWrapper)); +} + +} // namespace cachebench +} // namespace cachelib +} // namespace facebook diff --git a/cachelib/cachebench/workload/ZstdReader.h b/cachelib/cachebench/workload/ZstdReader.h new file mode 100644 index 000000000..4c5ab6c6b --- /dev/null +++ b/cachelib/cachebench/workload/ZstdReader.h @@ -0,0 +1,246 @@ +#pragma once + +#include +#include +#include +#include +#include +#include // For std::memmove +#include // Include folly logging + +namespace facebook { +namespace cachelib { +namespace cachebench { + +struct OracleGeneralBinRequest { + uint32_t clockTime; + uint64_t objId; + uint32_t objSize; + int64_t nextAccessVtime; +}; + +class ZstdReader { +public: + ZstdReader(bool compressed = true); + ~ZstdReader(); + + ZstdReader(ZstdReader&& other) noexcept; + ZstdReader& operator=(ZstdReader&& other) noexcept; + + void open(const std::string& trace_path, bool compressed = true); + size_t read_bytes(size_t n_byte, char** data_start); + bool read_one_req(OracleGeneralBinRequest* req); + void close(); + bool is_open() const; + +private: + std::ifstream ifile; + bool compressed_{true}; // <--- NEW FIELD + + // Only used if compressed_ + std::unique_ptr zds; + std::vector buff_in; + std::vector buff_out; + size_t buff_out_read_pos; + ZSTD_inBuffer input; + ZSTD_outBuffer output; + + enum class Status { OK, ERR, MY_EOF } status; + + size_t read_from_file(); + Status decompress_from_buff(); +}; + +// --- Implementation --- + +ZstdReader::ZstdReader(bool compressed) + : 
compressed_(compressed), + zds(compressed ? ZSTD_createDStream() : nullptr, ZSTD_freeDStream), + buff_in(compressed ? ZSTD_DStreamInSize() : 0), + buff_out(compressed ? ZSTD_DStreamOutSize() * 2 : 0), + buff_out_read_pos(0), + status(Status::OK) { + if (compressed_) { + input.src = buff_in.data(); + input.size = 0; + input.pos = 0; + output.dst = buff_out.data(); + output.size = buff_out.size(); + output.pos = 0; + ZSTD_initDStream(zds.get()); + } +} + +ZstdReader::~ZstdReader() { + close(); +} + +ZstdReader::ZstdReader(ZstdReader&& other) noexcept + : ifile(std::move(other.ifile)), + compressed_(other.compressed_), + zds(std::move(other.zds)), + buff_in(std::move(other.buff_in)), + buff_out(std::move(other.buff_out)), + buff_out_read_pos(other.buff_out_read_pos), + input(other.input), + output(other.output), + status(other.status) { + other.buff_out_read_pos = 0; + other.input = {nullptr, 0, 0}; + other.output = {nullptr, 0, 0}; + other.status = Status::ERR; +} + +ZstdReader& ZstdReader::operator=(ZstdReader&& other) noexcept { + if (this != &other) { + close(); + ifile = std::move(other.ifile); + compressed_ = other.compressed_; + zds = std::move(other.zds); + buff_in = std::move(other.buff_in); + buff_out = std::move(other.buff_out); + buff_out_read_pos = other.buff_out_read_pos; + input = other.input; + output = other.output; + status = other.status; + other.buff_out_read_pos = 0; + other.input = {nullptr, 0, 0}; + other.output = {nullptr, 0, 0}; + other.status = Status::ERR; + } + return *this; +} + +void ZstdReader::open(const std::string& trace_path, bool compressed) { + compressed_ = compressed; + ifile.open(trace_path, std::ios::binary); + if (!ifile.is_open()) { + throw std::runtime_error("Cannot open file: " + trace_path); + } + if (compressed_) { + if (!zds) { + zds.reset(ZSTD_createDStream()); + } + if (buff_in.empty()) buff_in.resize(ZSTD_DStreamInSize()); + if (buff_out.empty()) buff_out.resize(ZSTD_DStreamOutSize() * 2); + buff_out_read_pos = 0; + input.src = buff_in.data(); + input.size = 0; + input.pos = 0; + output.dst = buff_out.data(); + output.size = buff_out.size(); + output.pos = 0; + ZSTD_initDStream(zds.get()); + } +} + +void ZstdReader::close() { + if (ifile.is_open()) { + ifile.close(); + } +} + +bool ZstdReader::is_open() const { + return ifile.is_open(); +} + +size_t ZstdReader::read_from_file() { + ifile.read(buff_in.data(), buff_in.size()); + size_t read_sz = ifile.gcount(); + if (read_sz < buff_in.size()) { + if (ifile.eof()) { + status = Status::MY_EOF; + } else { + status = Status::ERR; + return 0; + } + } + input.size = read_sz; + input.pos = 0; + return read_sz; +} + +ZstdReader::Status ZstdReader::decompress_from_buff() { + std::memmove(buff_out.data(), buff_out.data() + buff_out_read_pos, output.pos - buff_out_read_pos); + output.pos -= buff_out_read_pos; + buff_out_read_pos = 0; + + if (input.pos >= input.size) { + size_t read_sz = read_from_file(); + if (read_sz == 0) { + if (status == Status::MY_EOF) { + return Status::MY_EOF; + } else { + XLOG(ERR) << "Read from file error"; + return Status::ERR; + } + } + } + + size_t const ret = ZSTD_decompressStream(zds.get(), &output, &input); + if (ret != 0 && ZSTD_isError(ret)) { + XLOG(ERR) << "Zstd decompression error: " << ZSTD_getErrorName(ret); + } + + return Status::OK; +} + +size_t ZstdReader::read_bytes(size_t n_byte, char** data_start) { + if (!compressed_) { + // Uncompressed: read directly from file + static std::vector plain_buff; + if (plain_buff.size() < n_byte) plain_buff.resize(n_byte); + 
ifile.read(plain_buff.data(), n_byte); + size_t bytes_read = ifile.gcount(); + if (bytes_read != n_byte) { + return 0; + } + *data_start = plain_buff.data(); + return bytes_read; + } + + // Compressed: use existing logic + size_t sz = 0; + while (buff_out_read_pos + n_byte > output.pos) { + Status status = decompress_from_buff(); + if (status != Status::OK) { + if (status != Status::MY_EOF) { + XLOG(ERR) << "Error decompressing file"; + } else { + return 0; + } + break; + } + } + if (buff_out_read_pos + n_byte <= output.pos) { + sz = n_byte; + *data_start = buff_out.data() + buff_out_read_pos; + buff_out_read_pos += n_byte; + } else { + XLOG(ERR) << "Do not have enough bytes " << output.pos - buff_out_read_pos << " < " << n_byte; + } + return sz; +} + +bool ZstdReader::read_one_req(OracleGeneralBinRequest* req) { + char* record; + size_t bytes_read = read_bytes(24, &record); + if (bytes_read != 24) { + return false; + } + req->clockTime = *(uint32_t*)record; + req->objId = *(uint64_t*)(record + 4); + req->objSize = *(uint32_t*)(record + 12); + req->nextAccessVtime = *(int64_t*)(record + 16); + if (req->nextAccessVtime == -1 || req->nextAccessVtime == INT64_MAX) { + req->nextAccessVtime = INT64_MAX; + } + if (req->objSize == 0) { + return read_one_req(req); + } + return true; +} + +} // namespace cachebench +} // namespace cachelib +} // namespace facebook \ No newline at end of file
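
Usage note (editorial addition, not part of the patch): the sketch below shows one plausible way to wire the new knobs together directly against CacheAllocatorConfig, similar to what cachebench's Cache.h does in the hunk above. The LruAllocator alias, cache size, threshold values, and rebalance interval are illustrative assumptions; only enableTailHitsTracking(), countColdTailHitsOnly, MarginalHitsStrategyNew::Config, and enablePoolRebalancing() come from this diff or pre-existing CacheLib APIs.

// Minimal sketch, assuming the LruAllocator trait and illustrative sizes.
#include <chrono>
#include <memory>

#include "cachelib/allocator/CacheAllocator.h"
#include "cachelib/allocator/MarginalHitsStrategyNew.h"

using Cache = facebook::cachelib::LruAllocator; // assumed trait

Cache::Config makeConfig() {
  Cache::Config config;
  config.setCacheSize(1024UL * 1024UL * 1024UL); // 1 GB, illustrative

  // Tail-hit tracking must be on for MarginalHits/MarginalHitsNew to pass
  // validateStrategy().
  config.enableTailHitsTracking();
  // New flag from this diff: count only cold-tail accesses as tail hits.
  config.countColdTailHitsOnly = true;

  facebook::cachelib::MarginalHitsStrategyNew::Config strategyConfig;
  strategyConfig.minDiff = 2.0;              // initial rebalance threshold
  strategyConfig.minRequestsObserved = 50000; // defaults from the new header

  config.enablePoolRebalancing(
      std::make_shared<facebook::cachelib::MarginalHitsStrategyNew>(
          strategyConfig),
      std::chrono::seconds(1)); // interval is illustrative
  return config;
}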