Thursday, February 19, 2026

Building a Production-Ready Inference Cache with Redis for LLM KV Management

What You'll Build

By the end of this tutorial, you'll have a working KV (key-value) cache system using Redis to store and retrieve LLM inference results. This dramatically reduces latency for repeated inference requests—think chatbot conversations where context gets reused, or RAG systems hitting the same documents.

You'll build a Python service that intercepts LLM inference calls, checks Redis for cached results, and only hits your expensive GPU inference when there's a cache miss. This pattern can reduce repeated inference latency by orders of magnitude for conversational workloads, turning multi-second responses into millisecond lookups.

Why this matters: as inference workloads scale, you can't just throw more GPUs at the problem. Caching is how engineering teams manage inference costs without sacrificing response times.

Prerequisites

  • Python 3.10+ installed (3.10, 3.11, or 3.12 recommended)
  • Docker 20.x or later for running Redis
  • pip package manager
  • At least 8GB RAM (16GB recommended if running models locally)
  • Basic familiarity with Python and command line
  • Estimated time: 45-60 minutes

Install Python dependencies:

pip install torch transformers redis numpy

Verify installations:

python -c "import torch, transformers, redis; print('All packages installed')"

Step-by-Step Instructions

Step 1: Start Redis with Persistence

Run Redis in Docker with volume mounting so your cache survives restarts:

docker run -d \
  --name inference-cache \
  -p 6379:6379 \
  -v redis-data:/data \
  redis:7.2-alpine redis-server --appendonly yes

What this does:

  • -d: Runs container in detached mode (background)
  • --name inference-cache: Names the container for easy reference
  • -p 6379:6379: Maps Redis default port to your host
  • -v redis-data:/data: Creates persistent volume for cache data
  • --appendonly yes: Enables AOF persistence (writes survive restarts)

Verify Redis is running:

docker logs inference-cache | grep -i "ready to accept"

You should see output indicating Redis is ready to accept connections.
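
You can also confirm connectivity from Python using the redis-py client installed earlier (a quick sanity check; assumes Redis is on localhost:6379):

import redis

# PING returns True if the server is reachable and responding
client = redis.Redis(host='localhost', port=6379)
print(client.ping())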

Step 2: Create the KV Cache Manager

Create a file called kv_cache_manager.py. This handles serialization of inference results into Redis-friendly byte strings and manages cache keys with TTL (time-to-live).

import redis
import numpy as np
import hashlib
import pickle
from typing import Optional, Tuple

class KVCacheManager:
    def __init__(self, host='localhost', port=6379, ttl=3600):
        """
        Initialize Redis connection with TTL for cache entries.
        
        Args:
            host: Redis server hostname
            port: Redis server port
            ttl: Time-to-live in seconds (default: 3600 = 1 hour)
        """
        self.redis_client = redis.Redis(
            host=host, 
            port=port, 
            decode_responses=False  # Store binary data
        )
        self.ttl = ttl
        
    def _generate_key(self, prompt: str, layer_idx: int) -> str:
        """
        Generate cache key from prompt + layer index.
        Uses SHA256 hash to keep keys manageable length.
        
        Args:
            prompt: Input text prompt
            layer_idx: Layer index (-1 for final output)
            
        Returns:
            Redis key string like "kv:layer-1:a3f7c8b4e9d2c1f5"
        """
        prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()[:16]
        return f"kv:layer{layer_idx}:{prompt_hash}"
    
    def store_kv(self, prompt: str, layer_idx: int, 
                 key_cache: np.ndarray, value_cache: np.ndarray):
        """
        Store key and value tensors for a specific layer.
        
        Args:
            prompt: Input prompt used to generate cache key
            layer_idx: Layer index for this KV pair
            key_cache: Numpy array representing key tensor
            value_cache: Numpy array representing value tensor
        """
        cache_key = self._generate_key(prompt, layer_idx)
        
        # Serialize numpy arrays using pickle
        data = pickle.dumps({
            'key': key_cache,
            'value': value_cache
        })
        
        # Store with TTL to prevent unbounded memory growth
        self.redis_client.setex(cache_key, self.ttl, data)
        
    def retrieve_kv(self, prompt: str, layer_idx: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """
        Retrieve cached KV pairs.
        
        Args:
            prompt: Input prompt to look up
            layer_idx: Layer index to retrieve
            
        Returns:
            Tuple of (key_cache, value_cache) if found, None otherwise
        """
        cache_key = self._generate_key(prompt, layer_idx)
        data = self.redis_client.get(cache_key)
        
        if data is None:
            return None
            
        kv_pair = pickle.loads(data)
        return kv_pair['key'], kv_pair['value']
    
    def clear_cache(self):
        """Flush all KV cache entries matching our pattern."""
        for key in self.redis_client.scan_iter("kv:*"):
            self.redis_client.delete(key)

What this does: The cache manager creates unique keys by hashing prompts (to keep key lengths manageable), serializes numpy arrays using pickle, and stores them in Redis with automatic expiration via TTL. This prevents your cache from growing unbounded and consuming all available memory.
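
Before wiring it into a model, you can sanity-check the manager on its own with a quick round trip (run from the same directory as kv_cache_manager.py with Redis up; the tensor shapes here are arbitrary example values):

import numpy as np
from kv_cache_manager import KVCacheManager

cache = KVCacheManager()

# Store a small dummy KV pair for layer 0 of a prompt
key_tensor = np.random.rand(1, 12, 8, 64).astype(np.float32)
value_tensor = np.random.rand(1, 12, 8, 64).astype(np.float32)
cache.store_kv("hello world", 0, key_tensor, value_tensor)

# Retrieve it and confirm the round trip preserved the data
result = cache.retrieve_kv("hello world", 0)
assert result is not None
k, v = result
print(np.allclose(k, key_tensor), np.allclose(v, value_tensor))  # True True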

Note on pickle security: Unpickling untrusted data can execute arbitrary code. For production systems handling untrusted input, use a safer serialization format such as msgpack, protobuf, or numpy's own .npy format.
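
As one illustration of a pickle-free approach for plain numpy arrays, you could serialize to the .npy format in memory and refuse pickled payloads on load (a sketch only, not a drop-in replacement for the manager above):

import io
import numpy as np

def serialize_array(arr: np.ndarray) -> bytes:
    """Serialize a numpy array to .npy bytes without using pickle."""
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)
    return buf.getvalue()

def deserialize_array(data: bytes) -> np.ndarray:
    """Deserialize bytes produced by serialize_array."""
    return np.load(io.BytesIO(data), allow_pickle=False)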

Step 3: Create the Cached Inference Wrapper

Create cached_inference.py. This wraps a HuggingFace model and intercepts inference calls to check the cache first.

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from kv_cache_manager import KVCacheManager
import time

class CachedInferenceModel:
    def __init__(self, model_name: str, cache_manager: KVCacheManager):
        """
        Initialize model with cache support.
        
        Args:
            model_name: HuggingFace model identifier (e.g., 'gpt2')
            cache_manager: KVCacheManager instance for caching
        """
        print(f"Loading tokenizer for {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        print(f"Loading model {model_name}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use half precision to save memory
            device_map='auto'  # Automatically choose CPU/GPU
        )
        
        self.cache_manager = cache_manager
        self.cache_hits = 0
        self.cache_misses = 0
        
    def generate_with_cache(self, prompt: str, max_new_tokens: int = 50) -> str:
        """
        Generate text with caching support.
        
        This implementation caches final outputs based on exact prompt matching.
        For production, you'd cache intermediate KV tensors from attention layers.
        
        Args:
            prompt: Input text prompt
            max_new_tokens: Maximum tokens to generate
            
        Returns:
            Generated text (with [CACHED] prefix if from cache)
        """
        start_time = time.time()
        
        # Check cache first (using layer_idx=-1 to indicate final output)
        cached_output = self.cache_manager.retrieve_kv(prompt, layer_idx=-1)
        
        if cached_output is not None:
            self.cache_hits += 1
            elapsed = time.time() - start_time
            print(f"✓ Cache HIT! Retrieved in {elapsed:.4f}s")
            # Reconstruct output from cached data
            cached_text = cached_output[0].tobytes().decode('utf-8')
            return f"[CACHED] {cached_text}"
        
        # Cache miss - run full inference
        self.cache_misses += 1
        print(f"✗ Cache MISS. Running full inference...")
        
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        # Generate output
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # Deterministic output for caching
                use_cache=True,  # Enable model's internal KV cache
                pad_token_id=self.tokenizer.eos_token_id  # Prevent warnings
            )
        
        # Decode generated tokens
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Store in cache for future requests
        # We store the text as numpy array for consistency with the interface
        text_bytes = generated_text.encode('utf-8')
        self.cache_manager.store_kv(
            prompt, 
            layer_idx=-1, 
            key_cache=np.frombuffer(text_bytes, dtype=np.uint8), 
            value_cache=np.array([])  # Empty value cache for this simplified version
        )
        
        elapsed = time.time() - start_time
        print(f"Generated in {elapsed:.4f}s")
        
        return generated_text
    
    def print_stats(self):
        """Print cache performance statistics."""
        total = self.cache_hits + self.cache_misses
        hit_rate = (self.cache_hits / total * 100) if total > 0 else 0
        print(f"\n=== Cache Statistics ===")
        print(f"Hits: {self.cache_hits}")
        print(f"Misses: {self.cache_misses}")
        print(f"Hit Rate: {hit_rate:.1f}%")

What this does: This wrapper checks the cache before running inference. On cache miss, it runs the full model inference, stores the result, and returns it. On cache hit, it returns the cached result immediately—typically orders of magnitude faster than full inference.

Simplified approach: This implementation caches final text outputs rather than the intermediate KV tensors produced by the attention layers. Caching actual KV tensors requires hooking into the model's forward pass (possible, but beyond this tutorial's scope). The cache-aside pattern is the same either way, and for exact-match repeated prompts the latency benefit is just as dramatic.
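
For reference, here's roughly what capturing per-layer KV tensors could look like with the cache manager from Step 2. Treat this as a sketch: depending on your transformers version, past_key_values may be a legacy tuple of (key, value) pairs per layer or a Cache object you'd need to convert first.

import torch

def cache_layer_kv(model, tokenizer, cache_manager, prompt: str):
    """Run one forward pass and store each layer's KV tensors (sketch)."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, use_cache=True)

    # Assumes the legacy tuple format: one (key, value) pair per layer,
    # each shaped (batch, num_heads, seq_len, head_dim)
    for layer_idx, (key, value) in enumerate(outputs.past_key_values):
        cache_manager.store_kv(
            prompt,
            layer_idx,
            key_cache=key.cpu().numpy(),
            value_cache=value.cpu().numpy(),
        )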

Step 4: Test the Cache System

Create test_cache.py to demonstrate cache hits vs misses:

from kv_cache_manager import KVCacheManager
from cached_inference import CachedInferenceModel

def main():
    # Initialize cache manager
    cache_mgr = KVCacheManager(host='localhost', port=6379, ttl=3600)
    
    # Clear any existing cache for clean test
    print("Clearing cache...")
    cache_mgr.clear_cache()
    
    # Load model (using GPT-2 for speed - works with any causal LM)
    print("\nLoading model...")
    model = CachedInferenceModel('gpt2', cache_mgr)
    
    # Test prompt
    prompt = "The future of AI infrastructure is"
    
    # First run - cache miss expected
    print(f"\n{'='*60}")
    print(f"TEST 1: First inference (cache miss expected)")
    print(f"{'='*60}")
    output1 = model.generate_with_cache(prompt, max_new_tokens=30)
    print(f"Output: {output1[:100]}...")
    
    # Second run - cache hit expected
    print(f"\n{'='*60}")
    print(f"TEST 2: Second inference with same prompt (cache hit expected)")
    print(f"{'='*60}")
    output2 = model.generate_with_cache(prompt, max_new_tokens=30)
    print(f"Output: {output2[:100]}...")
    
    # Third run - different prompt, cache miss expected
    print(f"\n{'='*60}")
    print(f"TEST 3: Different prompt (cache miss expected)")
    print(f"{'='*60}")
    prompt2 = "AI models require"
    output3 = model.generate_with_cache(prompt2, max_new_tokens=30)
    print(f"Output: {output3[:100]}...")
    
    # Fourth run - back to first prompt, cache hit expected
    print(f"\n{'='*60}")
    print(f"TEST 4: Back to first prompt (cache hit expected)")
    print(f"{'='*60}")
    output4 = model.generate_with_cache(prompt, max_new_tokens=30)
    print(f"Output: {output4[:100]}...")
    
    # Print final statistics
    model.print_stats()

if __name__ == "__main__":
    main()

Run the test:

python test_cache.py

Performance analysis: You'll observe dramatic performance differences between cache misses (full inference) and cache hits (Redis lookup). Cache hits typically complete in milliseconds while full inference takes seconds—demonstrating how caching reduces latency for repeated requests. In production systems serving thousands of requests, this translates directly to reduced GPU costs and improved user experience.

Step 5: Monitor Redis Memory Usage

Check current memory usage:

docker exec inference-cache redis-cli INFO memory | grep used_memory_human

For continuous monitoring, create monitor_cache.py:

import redis
import time

def monitor_cache(host='localhost', port=6379, interval=5):
    """
    Monitor Redis cache metrics in real-time.
    
    Args:
        host: Redis hostname
        port: Redis port
        interval: Seconds between updates
    """
    client = redis.Redis(host=host, port=port)
    
    print("Monitoring Redis cache (Ctrl+C to stop)...")
    print(f"{'Time':<20} {'Keys':<10} {'Memory':<15} {'Hit Rate':<10}")
    print("-" * 60)
    
    try:
        while True:
            info = client.info()
            
            # Gather metrics
            keys = client.dbsize()
            memory_mb = info['used_memory'] / (1024 * 1024)
            hits = info.get('keyspace_hits', 0)
            misses = info.get('keyspace_misses', 0)
            total = hits + misses
            hit_rate = (hits / total * 100) if total > 0 else 0
            
            # Display row
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
            print(f"{timestamp:<20} {keys:<10} {memory_mb:>10.2f} MB   {hit_rate:>6.1f}%")
            
            time.sleep(interval)
            
    except KeyboardInterrupt:
        print("\nMonitoring stopped.")

if __name__ == "__main__":
    monitor_cache()

Run in a separate terminal while testing:

python monitor_cache.py

This gives you real-time visibility into cache performance and memory consumption—critical for production deployments.

Verification

Confirm everything works correctly with these checks:

1. Verify Redis Container Status

docker ps | grep inference-cache

Container should show "Up" status.

2. Verify Cache Keys Exist

docker exec inference-cache redis-cli KEYS "kv:*"

After running test_cache.py, you should see keys matching the pattern kv:layer-1:{hash}.

3. Test Cache Hit Rate

Run test_cache.py again:

python test_cache.py

Within a single run, tests 2 and 4 should report cache hits and tests 1 and 3 cache misses (2 hits, 2 misses, a 50% hit rate). Because test_cache.py calls cache_mgr.clear_cache() at startup, a repeat run starts from an empty cache; comment out that call if you want a second execution to show 100% hits (4 hits, 0 misses).

4. Verify TTL is Working

Check time-to-live on a cache key (replace the hash with an actual key from step 2):

docker exec inference-cache redis-cli TTL "kv:layer-1:a3f7c8b4e9d2c1f5"

Should return a positive integer less than 3600 (seconds remaining until expiry). A return of -1 means the key exists but has no TTL set; -2 means the key doesn't exist.

5. Test Cache Persistence

Restart Redis and verify cache survives:

# Restart container
docker restart inference-cache

# Wait 5 seconds for startup
sleep 5

# Check if keys still exist
docker exec inference-cache redis-cli KEYS "kv:*"

Keys should still be present, confirming AOF persistence is working.

Troubleshooting

Issue 1: "ConnectionRefusedError: [Errno 111] Connection refused"

Cause: Redis isn't running or isn't accessible on port 6379.

Fix:

# Check if Redis container is running
docker ps -a | grep inference-cache

# If stopped, start it
docker start inference-cache

# If it doesn't exist, recreate it
docker run -d --name inference-cache -p 6379:6379 -v redis-data:/data redis:7.2-alpine redis-server --appendonly yes

# Verify it's accepting connections
docker logs inference-cache | grep -i "ready"

Issue 2: "ModuleNotFoundError: No module named 'transformers'"

Cause: Python dependencies not installed or wrong Python environment active.

Fix:

# Reinstall dependencies
pip install torch transformers redis numpy

# Verify installation
python -c "import transformers; print(transformers.__version__)"

Issue 3: Cache hits not occurring on repeated prompts

Cause: TTL expired, or cache was cleared between runs.

Fix:

# Check if keys exist
docker exec inference-cache redis-cli KEYS "kv:*"

# If no keys, run test_cache.py again
python test_cache.py

# Then immediately run it again to see cache hits
python test_cache.py

Issue 4: "RuntimeError: CUDA out of memory"

Cause: GPU doesn't have enough memory for the model.

Fix:

The code already uses torch.float16 for memory efficiency. If still encountering issues:

# In cached_inference.py, modify model loading:
self.model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='cpu'  # Force CPU usage
)

Or use a smaller model like distilgpt2 instead of gpt2.

Issue 5: "pickle.UnpicklingError: invalid load key"

Cause: Corrupted cache data or version mismatch between pickle writes and reads.

Fix:

# Clear the cache completely
docker exec inference-cache redis-cli FLUSHDB

# Run test again
python test_cache.py

Next Steps

Now that you have a working inference cache, consider these enhancements:

1. Implement Semantic Caching

Instead of exact prompt matching, use embedding similarity to cache semantically similar prompts. This increases cache hit rates for paraphrased queries.

from sentence_transformers import SentenceTransformer
import numpy as np

# Add to KVCacheManager.__init__ (requires: pip install sentence-transformers)
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
self.prompt_embeddings = {}  # in-memory index: cache key -> prompt embedding

# Alternative to _generate_key: reuse an existing key when a cached prompt
# is semantically close enough to the new one
def _generate_key_semantic(self, prompt: str, layer_idx: int, threshold: float = 0.9) -> str:
    embedding = self.embedding_model.encode(prompt)
    for key, cached_emb in self.prompt_embeddings.items():
        similarity = np.dot(embedding, cached_emb) / (
            np.linalg.norm(embedding) * np.linalg.norm(cached_emb))
        if similarity > threshold:
            return key  # similar prompt already cached: reuse its entry
    new_key = self._generate_key(prompt, layer_idx)
    self.prompt_embeddings[new_key] = embedding
    return new_key

2. Add Cache Warming

Pre-populate the cache with common queries during deployment:

def warm_cache(model, common_prompts):
    """Pre-cache frequently used prompts."""
    for prompt in common_prompts:
        model.generate_with_cache(prompt)
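
For example, calling warm_cache(model, ["The future of AI infrastructure is", "AI models require"]) at startup would pre-populate the cache with this tutorial's two test prompts, so the first real request for either one is already a hit.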

3. Implement Cache Eviction Policies

Beyond TTL, implement LRU (Least Recently Used) or LFU (Least Frequently Used) eviction so the cache keeps its most useful entries under memory pressure. Redis supports both natively via the maxmemory and maxmemory-policy settings.
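
If you'd rather let Redis enforce eviction itself, you can set a memory ceiling and policy at runtime with redis-py (a minimal sketch; the 256mb limit is just an example value):

import redis

client = redis.Redis(host='localhost', port=6379)

# Cap Redis memory and evict least-recently-used keys once the cap is reached
client.config_set('maxmemory', '256mb')
client.config_set('maxmemory-policy', 'allkeys-lru')

print(client.config_get('maxmemory-policy'))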
