@@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
15611561
15621562 const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
15631563
1564+ slot_info sinfo;
1565+
15641566 bool res = true ;
1565- res = res && state_read_meta (io, strm, cell_count, seq_id);
1566- res = res && state_read_data (io, strm, cell_count);
1567+ res = res && state_read_meta (io, strm, cell_count, sinfo, seq_id);
1568+ res = res && state_read_data (io, strm, cell_count, sinfo );
15671569
15681570 if (!res) {
15691571 if (seq_id == -1 ) {
@@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
17021704 }
17031705}
17041706
1705- bool llama_kv_cache::state_read_meta (llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
1707+ bool llama_kv_cache::state_read_meta (llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
17061708 auto & cells = v_cells[strm];
17071709 auto & head = v_heads[strm];
17081710
@@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
17391741 ubatch.seq_id [i] = &dest_seq_id;
17401742 }
17411743
1742- const auto sinfo = find_slot (ubatch, true );
1744+ sinfo = find_slot (ubatch, false );
17431745 if (sinfo.empty ()) {
17441746 LLAMA_LOG_ERROR (" %s: failed to find available cells in kv cache\n " , __func__);
17451747 return false ;
@@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
17491751 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
17501752 apply_ubatch (sinfo, ubatch);
17511753
1752- const auto head_cur = sinfo.head ();
1753-
1754- // keep the head at the old position because we will read the KV data into it in state_read_data()
1755- head = head_cur;
1756-
1757- LLAMA_LOG_DEBUG (" %s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n " , __func__, head_cur, head, cell_count, dest_seq_id);
1754+ LLAMA_LOG_DEBUG (" %s: cell_count = %d, dest_seq_id = %d\n " , __func__, cell_count, dest_seq_id);
17581755
1759- // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1760- // Assume that this is one contiguous block of cells
1761- GGML_ASSERT (head_cur + cell_count <= cells.size ());
1762- GGML_ASSERT (cells.pos_get (head_cur) == ubatch.pos [0 ]);
1763- GGML_ASSERT (cells.pos_get (head_cur + cell_count - 1 ) == ubatch.pos [cell_count - 1 ]);
1764- GGML_ASSERT (cells.seq_has (head_cur, dest_seq_id));
1765- GGML_ASSERT (cells.seq_has (head_cur + cell_count - 1 , dest_seq_id));
1756+ // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
1757+ GGML_ASSERT (sinfo.n_stream () == 1 );
1758+ GGML_ASSERT (sinfo.idxs [0 ].size () == cell_count);
1759+ for (uint32_t i = 0 ; i < cell_count; ++i) {
1760+ const uint32_t idx = sinfo.idxs [0 ][i];
1761+ GGML_ASSERT (cells.pos_get (idx) == ubatch.pos [i]);
1762+ GGML_ASSERT (cells.seq_has (idx, dest_seq_id));
1763+ }
17661764 } else {
17671765 // whole KV cache restore
17681766
@@ -1795,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
17951793 }
17961794 }
17971795
1796+ // Create contiguous slot_info for whole cache restore
1797+ sinfo.s0 = strm;
1798+ sinfo.s1 = strm;
1799+ sinfo.resize (1 );
1800+ sinfo.strm [0 ] = strm;
1801+ sinfo.idxs [0 ].resize (cell_count);
1802+ for (uint32_t i = 0 ; i < cell_count; ++i) {
1803+ sinfo.idxs [0 ][i] = i;
1804+ }
1805+
17981806 head = 0 ;
17991807 }
18001808
18011809 return true ;
18021810}
18031811
1804- bool llama_kv_cache::state_read_data (llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
1812+ bool llama_kv_cache::state_read_data (llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo ) {
18051813 auto & cells = v_cells[strm];
1806- auto & head = v_heads[strm];
18071814
18081815 uint32_t v_trans;
18091816 uint32_t n_layer;
@@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
18531860 }
18541861
18551862 if (cell_count) {
1856- // Read and set the keys for the whole cell range
1857- ggml_backend_tensor_set (k, io.read (cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1863+ if (sinfo.is_contiguous ()) {
1864+ // Fast path: contiguous cells, single memcpy
1865+ ggml_backend_tensor_set (k, io.read (cell_count * k_size_row), sinfo.head () * k_size_row, cell_count * k_size_row);
1866+ } else {
1867+ // Slow path: scatter to non-contiguous positions
1868+ const void * src = io.read (cell_count * k_size_row);
1869+ for (uint32_t i = 0 ; i < cell_count; ++i) {
1870+ const size_t dst_offset = sinfo.idxs [0 ][i] * k_size_row;
1871+ ggml_backend_tensor_set (k, (const char *)src + i * k_size_row, dst_offset, k_size_row);
1872+ }
1873+ }
18581874 }
18591875 }
18601876
@@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
18851901 }
18861902
18871903 if (cell_count) {
1888- // Read and set the values for the whole cell range
1889- ggml_backend_tensor_set (v, io.read (cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1904+ if (sinfo.is_contiguous ()) {
1905+ // Fast path: contiguous cells, single memcpy
1906+ ggml_backend_tensor_set (v, io.read (cell_count * v_size_row), sinfo.head () * v_size_row, cell_count * v_size_row);
1907+ } else {
1908+ // Slow path: scatter to non-contiguous positions
1909+ const void * src = io.read (cell_count * v_size_row);
1910+ for (uint32_t i = 0 ; i < cell_count; ++i) {
1911+ const size_t dst_offset = sinfo.idxs [0 ][i] * v_size_row;
1912+ ggml_backend_tensor_set (v, (const char *)src + i * v_size_row, dst_offset, v_size_row);
1913+ }
1914+ }
18901915 }
18911916 }
18921917 } else {
@@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
19251950 }
19261951
19271952 if (cell_count) {
1928- // For each row in the transposed matrix, read the values for the whole cell range
1929- for (uint32_t j = 0 ; j < n_embd_v_gqa; ++j) {
1930- const size_t dst_offset = (head + j * cells.size ()) * v_size_el;
1931- ggml_backend_tensor_set (v, io.read (cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1953+ if (sinfo.is_contiguous ()) {
1954+ // Fast path: contiguous cells
1955+ const uint32_t h = sinfo.head ();
1956+ for (uint32_t j = 0 ; j < n_embd_v_gqa; ++j) {
1957+ const size_t dst_offset = (h + j * cells.size ()) * v_size_el;
1958+ ggml_backend_tensor_set (v, io.read (cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1959+ }
1960+ } else {
1961+ // Slow path: scatter to non-contiguous positions
1962+ for (uint32_t j = 0 ; j < n_embd_v_gqa; ++j) {
1963+ const void * src = io.read (cell_count * v_size_el);
1964+ for (uint32_t i = 0 ; i < cell_count; ++i) {
1965+ const size_t dst_offset = (sinfo.idxs [0 ][i] + j * cells.size ()) * v_size_el;
1966+ ggml_backend_tensor_set (v, (const char *)src + i * v_size_el, dst_offset, v_size_el);
1967+ }
1968+ }
19321969 }
19331970 }
19341971 }
0 commit comments