Skip to content

Commit

Permalink
Token rotation (#987)
Browse files Browse the repository at this point in the history
Ticket: 153791

To be merged after:
openvinotoolkit/openvino#27088
  • Loading branch information
vshampor authored Jan 24, 2025
1 parent ec7888c commit 9caf53c
Show file tree
Hide file tree
Showing 25 changed files with 1,148 additions and 195 deletions.
12 changes: 9 additions & 3 deletions src/cpp/include/openvino/genai/cache_eviction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace ov::genai {
class CacheEvictionConfig {
public:
CacheEvictionConfig() {};
CacheEvictionConfig(size_t start_size, size_t recent_size, size_t max_cache_size, AggregationMode aggregation_mode_) : aggregation_mode(aggregation_mode_), m_start_size(start_size), m_recent_size(recent_size), m_max_cache_size(max_cache_size) {
CacheEvictionConfig(size_t start_size, size_t recent_size, size_t max_cache_size, AggregationMode aggregation_mode_, bool apply_rotation_ = false) : aggregation_mode(aggregation_mode_), apply_rotation(apply_rotation_), m_start_size(start_size), m_recent_size(recent_size), m_max_cache_size(max_cache_size) {
OPENVINO_ASSERT(start_size, "CacheEvictionConfig.start_size must be non-zero");
OPENVINO_ASSERT(recent_size, "CacheEvictionConfig.recent_size must be non-zero");
OPENVINO_ASSERT(max_cache_size, "CacheEvictionConfig.max_cache_size must be non-zero");
Expand Down Expand Up @@ -60,6 +60,12 @@ namespace ov::genai {

/** The mode used to compute the importance of tokens for eviction */
AggregationMode aggregation_mode = AggregationMode::NORM_SUM;

/** Whether to apply cache rotation (RoPE-based) after each eviction.
* Set this to false if your model has different RoPE scheme from the one used in the
* original llama model and you experience accuracy issues with cache eviction enabled
* and apply_rotation=true.**/
bool apply_rotation = false;
private:
/** Number of tokens in the *beginning* of KV cache that should be retained
* in the KV cache for this sequence during generation. Must be non-zero and a multiple of the KV cache block size for
Expand All @@ -72,12 +78,12 @@ namespace ov::genai {
std::size_t m_recent_size = 128;

/**
* @brief Maximum cache size (in tokens) that can be occupied by a sequence with cache eviction enabled.
* Maximum cache size (in tokens) that can be occupied by a sequence with cache eviction enabled.
* Actual occupied size may differ from this by no larger than (block_size) tokens.
* Eviction area is computed from this size and the "start"/"recent" area sizes.
* @return Total cache size (in tokens) allowed to be occupied by a sequence.
*/
std::size_t m_max_cache_size = 672;
std::size_t m_evictable_size = 512;

};
}
95 changes: 95 additions & 0 deletions src/cpp/src/cache_eviction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,99 @@ namespace ov::genai {
m_scores[decoder_layer_idx] = new_scores;
m_cache_counter[decoder_layer_idx] = new_counter;
}

CacheRotationCalculator::CacheRotationCalculator(size_t block_size,
size_t max_context_length_in_blocks,
size_t kv_head_size,
double rope_theta)
: m_block_size(block_size),
m_head_size(kv_head_size) {
// Frequencies follow the original recipe from RoFormer:
// https://arxiv.org/pdf/2104.09864v5
//
// However, the way the rotation coefficients are ultimately applied in Llama and related models from
// huggingface is very different from the RoFormer - the embedding-dimension coefficients are not treated as
// consecutive x-y coordinate pairs, but are rather divided into contiguous x-like and y-like halves - see
// `rotate_half` function in HF transformers. It can be shown that this form still preserves the relative
// positioning property from the RoFormer article.
OPENVINO_ASSERT(rope_theta > 0, "rope_theta must be positive");
size_t num_freqs = kv_head_size / 2;
m_rope_sin_lut.resize(max_context_length_in_blocks);
m_rope_cos_lut.resize(max_context_length_in_blocks);

for (size_t i = 0; i < max_context_length_in_blocks; i++) {
m_rope_sin_lut[i].reserve(num_freqs);
m_rope_cos_lut[i].reserve(num_freqs);
for (size_t j = 0; j < num_freqs; j++) {
double exponent = -static_cast<double>(2 * j) / kv_head_size;
double base_angle = std::pow(rope_theta, exponent);
m_rope_sin_lut[i].push_back(
-std::sin(i * block_size * base_angle)); // minus since we will be rotating by an inverse angle
m_rope_cos_lut[i].push_back(std::cos(i * block_size * base_angle));
}
}
}

const std::vector<std::vector<float>>& CacheRotationCalculator::get_sin_lut() const {
return m_rope_sin_lut;
}

const std::vector<std::vector<float>>& CacheRotationCalculator::get_cos_lut() const {
return m_rope_cos_lut;
}

std::vector<CacheRotationCalculator::BlockRotationData> CacheRotationCalculator::get_rotation_data(
const std::set<size_t>& evicted_block_logical_indices,
size_t num_logical_blocks_before_eviction,
bool deltas_only) {


std::vector<BlockRotationData> retval;
if (evicted_block_logical_indices.empty()) {
return retval;
}

for (auto idx : evicted_block_logical_indices) {
OPENVINO_ASSERT(idx < num_logical_blocks_before_eviction);
}

// num_logical_blocks_before_eviction > evicted_block_logical_indices.size() is automatically guaranteed by the
// set property and the previous assertion
retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size());

ptrdiff_t current_rotation_delta_in_blocks = 0;
std::vector<size_t> logical_block_space(num_logical_blocks_before_eviction);
std::iota(logical_block_space.begin(), logical_block_space.end(), 0);

for (size_t logical_block_idx : logical_block_space) {
if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) {
current_rotation_delta_in_blocks += 1;
} else {
if (current_rotation_delta_in_blocks != 0) {
BlockRotationData block_rotation_data;
block_rotation_data.logical_block_idx = logical_block_idx - current_rotation_delta_in_blocks;

// rotation delta is in tokens, but LUT is in blocks right now since we evict per-block
// delta recomputation to a valid LUT index is done at a later stage
block_rotation_data.rotation_delta = current_rotation_delta_in_blocks * m_block_size;
OPENVINO_ASSERT(block_rotation_data.rotation_delta / m_block_size <= m_rope_cos_lut.size(), "rotation delta larger than LUT size");

if (!deltas_only) {
block_rotation_data.cosines.reserve(m_block_size);
block_rotation_data.sines.reserve(m_block_size);
for (size_t i = 0; i < m_block_size; i++) {
block_rotation_data.cosines.push_back(
m_rope_cos_lut[current_rotation_delta_in_blocks]);
block_rotation_data.sines.push_back(
m_rope_sin_lut[current_rotation_delta_in_blocks]);
}
}

retval.push_back(block_rotation_data);
}
}
}

return retval;
}
}
84 changes: 84 additions & 0 deletions src/cpp/src/cache_eviction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,88 @@ class CacheEvictionAlgorithm {
std::vector<std::vector<size_t>> m_cache_counter;
};

/**
* @brief Computes, based on the logical indices of the blocks to be evicted, the rotation coefficients for the
* remaining cache blocks.
*
* The rotation assumes that the executed model applies rotary positional embedding (RoPE) during the execution of
* the attention operation. Each cache block therefore has the RoPE values already "baked in", with positions equivalent
* to the point in time when the cache block values were originally computed in one of the previous attention
* operations. When blocks are evicted, the logical index space of the remaining blocks is in general no longer
* contiguous with respect to the effective positions of tokens in the blocks. Cache rotation allows to remedy this by
* effectively adjusting the RoPE positions of certain blocks in the cache after eviction, by additionally "rotating"
* them (in the same sense as in RoPE) by such angles that the cache blocks in the logical index space are again
* contiguous in terms of the RoPE positions. This is supposed to make the eviction process less impactful on the
* accuracy of the generation.
*
* Currently only the basic RoPE method is supported (as applied in the Llama original models). Each model in general
* may have its own RoPE method (e.g. non-linear/NTK frequency scaling), and ideally the cache rotation calculator
* should be adjusted based on the specifics of the RoPE defined by the LLM.
*/
class CacheRotationCalculator {
public:
/**
* Constructs a CacheRotationCalculator.
* @param block_size Block size of the KV cache to evict from.
* @param max_context_length Maximum length possible for a sequence in the current pipeline.
* @param kv_head_size The size (in elements) of the embedding dimension in the attention operation.
* @param rope_theta The base RoPE angle used in the original LLM.
*/
CacheRotationCalculator(size_t block_size,
size_t max_context_length,
size_t kv_head_size,
double rope_theta = 10000.0f);

using RotationCoefficientsPerToken = std::vector<std::vector<float>>; // dimensions: [BLOCK_SIZE, head_size / 2]

/**
* Basic output structure for the calculator.
*/
struct BlockRotationData {
bool operator==(const BlockRotationData& rhs) const {
return (logical_block_idx == rhs.logical_block_idx) && (sines == rhs.sines) && (cosines == rhs.cosines);
}
size_t logical_block_idx; /** Logical index of the block AFTER eviction to which the rotation
should be applied */
size_t rotation_delta; /** Delta, in token positions, that should be applied to block contents
via rotation **/

// Fields below are currently only used for testing purposes
RotationCoefficientsPerToken sines; /** The sine coefficients to be applied to this block's contents for
rotation, in order of the block's elements */
RotationCoefficientsPerToken cosines; /** The cosine coefficients to be applied to this block's contents for
rotation, in order of the block's elements */
};

/**
* Computes the rotation coefficients for the given state of the logical block space when eviction is about to take
* place.
* @param evicted_block_logical_indices The logical block indices that the prior cache eviction algorithm step
* determined to be necessary to evict.
* @param num_logical_blocks_before_eviction Number of logical blocks that the evicted-from sequence occupied before
* the eviction step.
* @param deltas_only If true, the sines and cosines fields in each returned BlockRotationData will be left empty.
* @return A vector of per-block rotation data, including the indices of blocks after eviction that should be
* rotated, and the pre-computed trigonometric coefficients necessary for rotation.
*/
std::vector<BlockRotationData> get_rotation_data(const std::set<size_t>& evicted_block_logical_indices,
size_t num_logical_blocks_before_eviction,
bool deltas_only = true);

/**
* @return The size of the embedding dimension that this CacheRotationCalculator was initialized with.
*/
size_t get_head_size() const {
return m_head_size;
}

const std::vector<std::vector<float>>& get_sin_lut() const;
const std::vector<std::vector<float>>& get_cos_lut() const;

private:
size_t m_block_size;
size_t m_head_size;
std::vector<std::vector<float>> m_rope_sin_lut; // dimensions: [ max_context_length, head_size / 2]
std::vector<std::vector<float>> m_rope_cos_lut; // dimensions: [ max_context_length, head_size / 2]
};
}
2 changes: 1 addition & 1 deletion src/cpp/src/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class CacheManager {
std::vector<ov::Tensor> m_key_cache;
std::vector<ov::Tensor> m_value_cache;
size_t m_num_allocated_kv_blocks = 0;
ov::Core m_core;
ov::InferRequest m_request;
ov::Core m_core;

ov::Shape set_first_dim_and_make_static(const ov::PartialShape& shape, size_t dim) {
ov::PartialShape res_shape = shape;
Expand Down
Loading

0 comments on commit 9caf53c

Please sign in to comment.