Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions tensorrt_llm/runtime/kv_cache_manager_v2/_block_radix_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import hashlib
from array import array
from typing import TYPE_CHECKING, Iterable, Iterator, NamedTuple, Sequence, TypeVar, cast

from . import rawref
Expand Down Expand Up @@ -90,11 +91,21 @@ def update(self, data: int | bytes | Sequence[int | bytes]) -> "Hasher":
elif type(data) is bytes:
self._hasher.update(data)
else:
for item in data: # type: ignore
assert (
NDEBUG or (type(item) is int and (0 <= item < (1 << 64))) or type(item) is bytes
)
self._hasher.update(item.to_bytes(8, "little") if (type(item) is int) else item) # type: ignore
# Hash the whole token block in one C call instead of one per token.
# array("Q", data).tobytes() packs each int as 8 native-endian bytes;
# all NVIDIA GPU host platforms (x86_64, aarch64/Grace) are little-endian
# so this is byte-identical to the per-token to_bytes(8, "little") loop.
# Falls back to that loop for multimodal blocks (which contain bytes items).
try:
self._hasher.update(array("Q", data).tobytes()) # type: ignore
except (TypeError, OverflowError):
for item in data: # type: ignore
assert (
NDEBUG
or (type(item) is int and (0 <= item < (1 << 64)))
or type(item) is bytes
)
self._hasher.update(item.to_bytes(8, "little") if (type(item) is int) else item) # type: ignore
return self

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import array
import functools
import gc
import hashlib
import itertools
import os
import random
Expand Down Expand Up @@ -52,7 +53,7 @@
TokenIdExt,
_KVCache,
)
from kv_cache_manager_v2._block_radix_tree import traverse_post_order
from kv_cache_manager_v2._block_radix_tree import Hasher, traverse_post_order
from kv_cache_manager_v2._common import (
BAD_PAGE_INDEX,
GPU_LEVEL,
Expand Down Expand Up @@ -105,7 +106,10 @@
TokenIdExt,
_KVCache,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import traverse_post_order
from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import (
Hasher,
traverse_post_order,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._common import (
BAD_PAGE_INDEX,
GPU_LEVEL,
Expand Down Expand Up @@ -2556,5 +2560,33 @@ def test_shrink_touched_pool(self) -> None:
allocator.release(s)


class TestBlockKeyHashing(unittest.TestCase):
"""Verify Hasher.update produces bit-identical digests to the per-token reference (no GPU needed)."""

@staticmethod
def _ref_update(seed: bytes, block: "list[int | bytes]") -> bytes:
h = hashlib.sha256()
h.update(seed)
for item in block:
h.update(item.to_bytes(8, "little") if type(item) is int else item)
return h.digest()

def test_update_int_block_matches_reference(self) -> None:
rng = random.Random(123)
seed = b"\xaa\xbb\xcc"
for n in (0, 1, 7, 32, 33, 257):
block = [rng.randint(0, (1 << 60)) for _ in range(n)]
self.assertEqual(
Hasher(seed).update(block).digest,
self._ref_update(seed, block),
f"int block of length {n}",
)

def test_update_mixed_multimodal_block(self) -> None:
block = [randbytes(32), 5, 6, randbytes(32)] + list(range(20))
seed = b"\x01"
self.assertEqual(Hasher(seed).update(block).digest, self._ref_update(seed, block))


if __name__ == "__main__":
unittest.main()
Loading