Commit 4529c66

ssweens and ggerganov authored
kv-cache: Fix state restore fragmented cache (#17982)
* kv-cache : fix state restore with fragmented cache (#17527)

  Change find_slot to allow non-contiguous allocation during state restore.
  Fixes 'failed to find available cells in kv cache' error when restoring
  state to fragmented cache.

* tests : update logic

* cleanup: tightened state_read_meta sig, added is_contiguous case

* fix: state_read_meta arg reorder loose ends

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 0f4f35e commit 4529c66

File tree: 4 files changed (+213, -29 lines changed)

src/llama-kv-cache.cpp

Lines changed: 64 additions & 27 deletions
@@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
 
         const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
 
+        slot_info sinfo;
+
         bool res = true;
-        res = res && state_read_meta(io, strm, cell_count, seq_id);
-        res = res && state_read_data(io, strm, cell_count);
+        res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
+        res = res && state_read_data(io, strm, cell_count, sinfo);
 
         if (!res) {
             if (seq_id == -1) {
@@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
     }
 }
 
-bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
     auto & cells = v_cells[strm];
     auto & head = v_heads[strm];
 
@@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             ubatch.seq_id[i] = &dest_seq_id;
         }
 
-        const auto sinfo = find_slot(ubatch, true);
+        sinfo = find_slot(ubatch, false);
         if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
@@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
         // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
-        const auto head_cur = sinfo.head();
-
-        // keep the head at the old position because we will read the KV data into it in state_read_data()
-        head = head_cur;
-
-        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+        LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
 
-        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+        // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
+        GGML_ASSERT(sinfo.n_stream() == 1);
+        GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            const uint32_t idx = sinfo.idxs[0][i];
+            GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
+            GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
+        }
     } else {
         // whole KV cache restore
 
@@ -1795,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             }
         }
 
+        // Create contiguous slot_info for whole cache restore
+        sinfo.s0 = strm;
+        sinfo.s1 = strm;
+        sinfo.resize(1);
+        sinfo.strm[0] = strm;
+        sinfo.idxs[0].resize(cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            sinfo.idxs[0][i] = i;
+        }
+
         head = 0;
     }
 
     return true;
 }
 
-bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
     auto & cells = v_cells[strm];
-    auto & head = v_heads[strm];
 
     uint32_t v_trans;
     uint32_t n_layer;
@@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
         }
 
         if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            if (sinfo.is_contiguous()) {
+                // Fast path: contiguous cells, single memcpy
+                ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
+            } else {
+                // Slow path: scatter to non-contiguous positions
+                const void * src = io.read(cell_count * k_size_row);
+                for (uint32_t i = 0; i < cell_count; ++i) {
+                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
+                    ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
+                }
+            }
         }
     }
 
@@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             }
 
             if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells, single memcpy
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    const void * src = io.read(cell_count * v_size_row);
+                    for (uint32_t i = 0; i < cell_count; ++i) {
+                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
+                        ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
+                    }
+                }
             }
         }
     } else {
@@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             }
 
             if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells
+                    const uint32_t h = sinfo.head();
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (h + j * cells.size()) * v_size_el;
+                        ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const void * src = io.read(cell_count * v_size_el);
+                        for (uint32_t i = 0; i < cell_count; ++i) {
+                            const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
+                            ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
+                        }
+                    }
                 }
             }
         }
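
Note on the change above: the copy logic now branches on whether the destination cells form one contiguous run. Below is a minimal standalone sketch of that pattern, with plain byte buffers standing in for the ggml tensors; the names restore_rows, is_contiguous_run, row_size, cache and saved are made up for illustration and are not code from this commit.

// Minimal sketch (not part of the commit): plain byte buffers stand in for ggml
// tensors, row_size stands in for k_size_row / v_size_row.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// same idea as slot_info::is_contiguous(): the indices form one run starting at idxs[0]
static bool is_contiguous_run(const std::vector<uint32_t> & idxs) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        if (idxs[i] != idxs[0] + i) {
            return false;
        }
    }
    return true;
}

// copy the saved rows into the cache: one block copy when the destination cells
// are contiguous, otherwise one row at a time into the cells named by idxs
static void restore_rows(std::vector<uint8_t> & cache, const std::vector<uint8_t> & saved,
                         const std::vector<uint32_t> & idxs, size_t row_size) {
    if (idxs.empty()) {
        return;
    }
    if (is_contiguous_run(idxs)) {
        // fast path, analogous to the sinfo.head() * k_size_row offset in the diff
        std::memcpy(cache.data() + idxs[0] * row_size, saved.data(), idxs.size() * row_size);
    } else {
        // slow path, analogous to the per-cell ggml_backend_tensor_set() scatter
        for (size_t i = 0; i < idxs.size(); ++i) {
            std::memcpy(cache.data() + idxs[i] * row_size, saved.data() + i * row_size, row_size);
        }
    }
}

int main() {
    const size_t row_size = 4;
    std::vector<uint8_t> cache(10 * row_size, 0x00); // 10 cells
    std::vector<uint8_t> saved(3 * row_size, 0xAB);  // 3 saved rows

    restore_rows(cache, saved, {1, 4, 7}, row_size); // fragmented destination -> scatter
    assert(cache[1 * row_size] == 0xAB && cache[4 * row_size] == 0xAB && cache[7 * row_size] == 0xAB);

    restore_rows(cache, saved, {2, 3, 4}, row_size); // contiguous destination -> block copy
    assert(cache[2 * row_size] == 0xAB && cache[3 * row_size] == 0xAB && cache[4 * row_size] == 0xAB);

    return 0;
}

The fast path reproduces the old single-block behavior, so a fully contiguous restore costs the same as before; only fragmented restores pay for the per-row copies.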

src/llama-kv-cache.h

Lines changed: 19 additions & 2 deletions
@@ -72,6 +72,23 @@ class llama_kv_cache : public llama_memory_i {
         void clear() {
             idxs.clear();
         }
+
+        // check if indices are contiguous starting from head()
+        bool is_contiguous() const {
+            if (idxs.empty() || idxs[0].empty()) {
+                return true;
+            }
+            if (idxs.size() > 1) {
+                return false;
+            }
+            const uint32_t h = idxs[0][0];
+            for (size_t i = 0; i < idxs[0].size(); ++i) {
+                if (idxs[0][i] != h + i) {
+                    return false;
+                }
+            }
+            return true;
+        }
     };
 
     using slot_info_vec_t = std::vector<slot_info>;
@@ -264,8 +281,8 @@ class llama_kv_cache : public llama_memory_i {
     void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
 
-    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
 };
 
 class llama_kv_cache_context : public llama_memory_context_i {
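
As an illustration of what the new predicate accepts, here is a standalone restatement of slot_info::is_contiguous() with idxs modeled as one index vector per stream; the free function name contiguous() is made up and the snippet is not part of the commit.

// Illustrative restatement of the predicate added above.
#include <cassert>
#include <cstdint>
#include <vector>

static bool contiguous(const std::vector<std::vector<uint32_t>> & idxs) {
    if (idxs.empty() || idxs[0].empty()) {
        return true;                   // nothing to copy
    }
    if (idxs.size() > 1) {
        return false;                  // more than one stream always scatters
    }
    const uint32_t h = idxs[0][0];
    for (size_t i = 0; i < idxs[0].size(); ++i) {
        if (idxs[0][i] != h + i) {
            return false;              // a hole breaks the run
        }
    }
    return true;
}

int main() {
    assert(contiguous({}));                // empty -> contiguous by convention
    assert(contiguous({{7, 8, 9}}));       // one run starting at head() == 7 -> fast path
    assert(!contiguous({{2, 5, 6}}));      // fragmented -> scatter path
    assert(!contiguous({{0, 1}, {2, 3}})); // multiple streams -> scatter path
    return 0;
}

Returning true for an empty index list is a safe default here, since state_read_data() only performs a copy when cell_count is non-zero.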

tests/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -222,6 +222,14 @@ llama_build_and_test(test-backend-ops.cpp)
 llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
 llama_build_and_test(test-autorelease.cpp LABEL "model")
 
+# Test for state restore with fragmented KV cache
+# Requires a model, uses same args pattern as test-thread-safety
+if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+    llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf)
+else()
+    llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf)
+endif()
+
 if (NOT GGML_BACKEND_DL)
     # these tests use the backends directly and cannot be built with dynamic loading
     llama_build_and_test(test-barrier.cpp)
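
Assuming the usual llama.cpp ctest workflow (the model is fetched via the -hf/-hff arguments above), the new test should be runnable on its own with something like: ctest -R test-state-restore-fragmented --output-on-failure, run from the build directory; the exact test name comes from llama_build_and_test, so adjust the regex if it differs.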
tests/test-state-restore-fragmented.cpp (new file)

Lines changed: 122 additions & 0 deletions
// Test for state restore with fragmented KV cache
// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
// The issue was that state restore required contiguous KV cache slots,
// which fails when the cache is fragmented.
//
// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
// in state_read_meta(), allowing non-contiguous slot allocation.

#include "arg.h"
#include "common.h"
#include "llama.h"

#include <vector>
#include <cstdio>
#include <cstring>

int main(int argc, char ** argv) {
    common_params params;

    params.sampling.seed = 1234;
    params.kv_unified = true;
    params.n_parallel = 3;
    params.n_ctx = 256;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    common_init();

    // init
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model * model = llama_init->model();
    llama_context * ctx = llama_init->context();

    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "%s : failed to init\n", __func__);
        return 1;
    }

    GGML_UNUSED(model);

    // tokenize prompt
    std::vector<llama_token> tokens(70, 1);

    // interleave the 3 sequences:
    // 01201230123...
    llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        for (int s = 0; s < params.n_parallel; ++s) {
            common_batch_add(batch, tokens[i], i, {s}, false);
        }
    }
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to decode seq 0\n", __func__);
        return 1;
    }

    fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());

    // Save state of seq 1
    std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
    const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
    if (ncopy != seq_state.size()) {
        fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
        return 1;
    }
    fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);

    // clear seq 1 to create a "hole" in the KV cache (fragmentation)
    // 0.20.20.20.2....
    llama_memory_t mem = llama_get_memory(ctx);
    llama_memory_seq_rm(mem, 1, -1, -1);
    fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);

    // Now the cache has holes where seq 1 was
    // This creates fragmentation - there's no contiguous block large enough
    // for the seq 1 state if we only look for contiguous slots

    // Restore seq 1 state into seq 1 (should work with non-contiguous allocation)
    // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1)
    // Before the fix, this would fail with "failed to find available cells in kv cache"
    const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
    if (nset != seq_state.size()) {
        fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
                __func__, nset, seq_state.size());
        fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__);
        llama_batch_free(batch);
        return 1;
    }
    fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);

    // Verify we can decode with the restored state
    // Generate one token to verify the restored state is usable
    auto sparams = llama_sampler_chain_default_params();
    llama_sampler * smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));

    auto next_token = llama_sampler_sample(smpl, ctx, -1);
    auto next_token_str = common_token_to_piece(ctx, next_token);

    common_batch_clear(batch);
    common_batch_add(batch, next_token, (int)tokens.size(), {1}, true);

    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
        llama_sampler_free(smpl);
        llama_batch_free(batch);
        return 1;
    }

    fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
    fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);

    llama_sampler_free(smpl);
    llama_batch_free(batch);

    return 0;
}
