Faster LLM inference with KV cache, speculative decoding and torch.compile

After successfully implementing the Qwen3 model, we need to run some inference with it to see whether it actually works. After a few debugging sessions, I got it working. To establish a baseline for optimizations, I ran the 1.7B model with a simple prompt and a maximum of 200 tokens. It averaged 3.63 tokens per second, which is a rather poor result. It's time to apply some optimizations.

KV-caching

KV caching (key-value caching) is an optimization technique used during autoregressive text generation in transformer-based LLMs. It dramatically speeds up inference by not recomputing the attention keys and values of tokens that have already been processed. When generating text token by token (autoregressively), a transformer without a cache has a fundamental inefficiency:

Input: "The quick brown fox jumps"
Generate: "over"

Step 1: Process "The" -> generate "quick"
Step 2: Process "The quick" -> generate "brown" (recomputes "The"!)
Step 3: Process "The quick brown" -> generate "fox" (recomputes "The quick"!)
Step 4: Process "The quick brown fox" -> generate "jumps" (recomputes everything!)
Step 5: Process "The quick brown fox jumps" -> generate "over" (recomputes everything!)

Each generation step processes all previous tokens again: step $i$ handles $i$ tokens, so generating $n$ tokens costs $1 + 2 + \dots + n = n(n+1)/2$ token passes, which is $O(n^2)$.
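To see how quickly the redundant work adds up, the hypothetical helper below counts how many token positions get their keys and values computed per layer, with and without a cache. It is a back-of-the-envelope sketch that assumes one forward pass per generated token and a prompt processed in a single pass.

```python
def kv_projections(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count token positions whose K/V are computed while generating new_tokens.

    Hypothetical helper: assumes one forward pass per generated token.
    """
    total = prompt_len  # the first pass always processes the whole prompt
    for step in range(1, new_tokens):
        if use_cache:
            total += 1  # only the newest token needs fresh keys/values
        else:
            total += prompt_len + step  # the whole prefix is processed again
    return total


print(kv_projections(5, 200, use_cache=False))  # 20900
print(kv_projections(5, 200, use_cache=True))   # 204
```

The measured speedup is far smaller than this ratio, partly because every decoding step still has to read all of the model's weights no matter how few tokens it processes.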

How does KV caching work, step by step?

  1. First step (prefill): the model processes the whole input once, computing its keys and values and storing them in the cache
  2. For each next token, the model computes keys and values only for that new token and appends them to the cache instead of starting the calculations all over again
  3. Attention is calculated using the cached Ks and Vs together with the new Q (query)
  4. The newly generated token is appended to the input, and we go back to step 2 until generation is finished
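The steps above can be checked with a toy, single-head attention layer. This is a minimal sketch, not the Qwen3 code: NumPy stands in for PyTorch, there is no RoPE, normalization or multi-head logic, and `CachedAttention` and `causal_attention` are hypothetical names. It shows that feeding only the new token and reusing cached keys and values gives the same output as recomputing attention over the whole sequence.

```python
import numpy as np


def causal_attention(q, k, v):
    """Single-head attention. q: (Tq, d); k, v: (Tk, d) with Tk >= Tq.

    The queries are the last Tq positions of the sequence, so the causal
    mask is aligned to the end of the key sequence (bottom-right).
    """
    tq, tk = q.shape[0], k.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    q_pos = np.arange(tk - tq, tk)[:, None]  # absolute positions of the queries
    k_pos = np.arange(tk)[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


class CachedAttention:
    """Toy single-head attention layer with a KV cache (hypothetical)."""

    def __init__(self, d, seed=0):
        rng = np.random.default_rng(seed)
        self.wq, self.wk, self.wv = (rng.standard_normal((d, d)) for _ in range(3))
        self.cache_k = None
        self.cache_v = None

    def forward(self, x):
        """x: (T_new, d), embeddings of the tokens not processed yet."""
        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        if self.cache_k is None:
            # Step 1: the first call stores keys/values of the whole input.
            self.cache_k, self.cache_v = k, v
        else:
            # Step 2: append keys/values of the new token(s) only.
            self.cache_k = np.concatenate([self.cache_k, k], axis=0)
            self.cache_v = np.concatenate([self.cache_v, v], axis=0)
        # Step 3: the new queries attend over every cached key/value.
        return causal_attention(q, self.cache_k, self.cache_v)


d = 8
tokens = np.random.default_rng(1).standard_normal((5, d))  # stand-in embeddings
layer = CachedAttention(d)
layer.forward(tokens[:4])                          # prefill with a 4-token prompt
step_out = layer.forward(tokens[4:5])              # decode one token with the cache
full_out = CachedAttention(d).forward(tokens[:5])  # same weights, recompute all
print(np.allclose(step_out, full_out[4:5]))        # True
```

Note the bottom-right alignment of the mask: once a single query attends over a longer cached key sequence, a mask aligned to the first key would hide almost everything from it.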

KV caching is easy to implement, yet the speedup is significant. With it enabled we got 24.02 tokens per second, about 6.6 times faster than inference without the cache.

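class Attention(nn.Module):
    def __init__(self, config: Qwen3Config, layer_index) -> None:
        ...
        # Per-layer KV cache, filled during generation
        # shape: (batch, cached_seq_len, n_kv_heads, head_dim)
        self.cache_k = None
        self.cache_v = None
        ...

    def forward(self, x: torch.Tensor, is_causal=True, use_cache=True):
        batch, seq_len, dim = x.shape

        # Project only the tokens passed in this call: (batch, seq_len, heads, head_dim)
        q = self.q_proj(x).reshape(batch, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).reshape(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(batch, seq_len, self.n_kv_heads, self.head_dim)

        # Qwen3 applies RMSNorm to queries and keys (QK-norm)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if use_cache:
            if self.cache_k is None:
                # First call (prompt): start the cache
                self.cache_k, self.cache_v = k, v
            else:
                # Later calls: append new keys/values along the sequence dimension
                self.cache_k = torch.cat([self.cache_k, k], dim=1)
                self.cache_v = torch.cat([self.cache_v, v], dim=1)
            # Attend over all cached positions, not only the new ones
            k, v = self.cache_k, self.cache_v

        # RoPE comes after the cache update, so cached keys are stored un-rotated
        # and the full key sequence is rotated here at every step
        q, k = apply_rope(q, k, self.cos_cache, self.sin_cache)
        ...
        # further attention calculations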

It is important to get the order of operations right. Qwen3 first applies RMSNorm to the queries and keys and only then applies RoPE; the KV cache update happens between those two operations.
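Because the Attention class above caches keys before RoPE, the rotation has to use each token's absolute position: the cached keys sit at positions 0..t-1 and the new query at position t, not at 0. The NumPy sketch below illustrates this under our own assumptions: the rotate-half RoPE convention used by Llama-style models, a base of 10000 chosen only for illustration (Qwen3's configured base may differ), and a hypothetical `rope` helper. It checks that rotating un-rotated cached keys with the right positions gives the same scores as caching already-rotated keys, and that giving the query position 0 does not.

```python
import numpy as np


def rope(x, positions, base=10000.0):
    """Rotate-half RoPE. x: (T, d) with even d; positions: T absolute positions."""
    half = x.shape[-1] // 2
    inv_freq = base ** (-np.arange(half) / half)                     # (d/2,)
    angles = np.asarray(positions, dtype=float)[:, None] * inv_freq  # (T, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


rng = np.random.default_rng(0)
d, t = 8, 5
keys = rng.standard_normal((t + 1, d))  # normalized keys for positions 0..t
query = rng.standard_normal((1, d))     # query of the newest token (position t)

# A: cache un-rotated keys and rotate the whole cache at every step.
scores_a = rope(query, [t]) @ rope(keys, np.arange(t + 1)).T

# B: rotate each key once, when it is produced, and cache the rotated key.
cache_b = np.concatenate([rope(keys[i:i + 1], [i]) for i in range(t + 1)])
scores_b = rope(query, [t]) @ cache_b.T

# Bug: the single new query treated as position 0 instead of t.
scores_bug = rope(query, [0]) @ rope(keys, np.arange(t + 1)).T

print(np.allclose(scores_a, scores_b))    # True
print(np.allclose(scores_a, scores_bug))  # False
```

Both orders give the same scores; option B avoids re-rotating the whole cache at every step, which is why caching keys after RoPE is a common alternative.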

Speculative decoding

Speculative decoding is a technique that speeds up autoregressive LLM inference by using a small, fast "draft" model to predict several tokens ahead and then having the large target model verify them in parallel. The speedup is lossless: the output follows exactly the same distribution as standard decoding with the target model (with greedy decoding, the tokens are identical), only faster. The algorithm goes as follows:

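  1. Generate $k$ tokens with the draft model and store the probability of each one
Input: "The quick brown"
Draft model generates: "fox jumps over a tree."
draft_tokens = [fox, jumps, over, a, tree.]
draft_probs = [0.95, 0.85, 0.80, 0.6, 0.5]
  2. Run a single forward pass of the target model to get its probabilities for the draft tokens
Target model processes "The quick brown [fox] [jumps] [over] [a] [tree.]"
target_probs = [0.90, 0.80, 0.75, 0.2, 0.1]
  3. Compare the two distributions to decide how many draft tokens to accept. A draft token $x$ is accepted with probability $\min(1, p_{\text{target}}(x) / p_{\text{draft}}(x))$. Here the first three ratios are about 0.95, 0.94 and 0.94, so those tokens are very likely to be kept, while "a" has a ratio of only $0.2 / 0.6 \approx 0.33$ and in this run it is rejected
 draft_probs = [0.95, 0.85, 0.80, 0.6, 0.5]
target_probs = [0.90, 0.80, 0.75, 0.2, 0.1]
is_accepted  = [yes,   yes,  yes,  no,  x ]   (x = never checked)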
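  • We stop accepting tokens after the first rejection, because every later draft token was generated conditioned on the rejected one.
  • At the rejection point we sample a replacement token from the renormalized residual distribution $\max(0, p_{\text{target}} - p_{\text{draft}})$; this keeps the output distribution identical to the target model's.
  • If all draft tokens are accepted, we get a "bonus" token: the target model's forward pass already produced a distribution for the position after the last draft token, so we sample from it and append the result to the generated output.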
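The accept/reject rule described above fits in a few lines. This is a plain-Python sketch of the standard speculative sampling rule, assuming each model's output at every draft position is available as a dictionary over a toy vocabulary; `verify_draft` and its arguments are hypothetical names, not taken from the implementation linked below. Real code works on logits tensors and gets all $k + 1$ target distributions from the single forward pass.

```python
import random


def verify_draft(draft_tokens, draft_dists, target_dists, rng=random):
    """Speculative sampling acceptance step (sketch).

    draft_dists[i], target_dists[i]: dicts mapping token -> probability at
    draft position i. target_dists has one extra entry: the target model's
    distribution after the last draft token, used for the bonus token.
    Returns the accepted draft tokens followed by one newly sampled token.
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        p_target = target_dists[i].get(tok, 0.0)
        p_draft = draft_dists[i][tok]
        # Accept with probability min(1, p_target / p_draft).
        if rng.random() < min(1.0, p_target / p_draft):
            out.append(tok)
            continue
        # Rejected: sample from the residual max(0, p_target - p_draft);
        # random.choices treats the weights as relative, so no explicit renormalizing.
        residual = {t: max(0.0, p - draft_dists[i].get(t, 0.0))
                    for t, p in target_dists[i].items()}
        tokens, weights = zip(*residual.items())
        out.append(rng.choices(tokens, weights=weights)[0])
        return out  # later draft tokens depend on the rejected one
    # All draft tokens accepted: the bonus token comes from the target model.
    tokens, weights = zip(*target_dists[-1].items())
    out.append(rng.choices(tokens, weights=weights)[0])
    return out


rng = random.Random(0)
same = [{"fox": 0.5, "cat": 0.5}] * 3
print(verify_draft(["fox", "cat"], same[:2], same, rng)[:2])  # ['fox', 'cat']
draft = [{"a": 0.9, "b": 0.1}]
target = [{"a": 0.0, "b": 1.0}, {"a": 0.5, "b": 0.5}]
print(verify_draft(["a"], draft, target, rng))  # ['b']
```

When the two models agree exactly, every draft token is kept and a bonus token is added; when the target gives a draft token zero probability, it is always replaced.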

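The KV cache plays an important role in this technique: it keeps the draft model's $k$ sequential generation steps cheap, and it lets the target model's single verification pass compute keys and values only for the $k$ draft tokens instead of recomputing them for the whole prefix.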

Implementing speculative decoding takes quite a lot of code, so I will not include it in this post; however, you can see my implementation here. When implementing it, be very careful to trim the KV cache properly: after verification, the caches contain entries for draft tokens that may have been rejected, and those entries must be removed before generation continues.
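A minimal sketch of that trimming step, under our own assumptions: each layer stores its cache like the Attention class earlier in this post, with the sequence on dimension 1; NumPy arrays stand in for tensors; `LayerCache` and `trim_kv_cache` are hypothetical names. After verifying a draft, each model's cache is cut back to the prefix plus the accepted draft tokens, and the replacement or bonus token is assumed to be fed in on the next forward pass.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class LayerCache:
    """Stand-in for one attention layer's cache_k / cache_v attributes.
    Assumed layout: (batch, seq_len, n_kv_heads, head_dim)."""
    cache_k: Optional[np.ndarray] = None
    cache_v: Optional[np.ndarray] = None


def trim_kv_cache(layers, keep_len):
    """Keep only the first keep_len positions (prefix + accepted tokens)."""
    for layer in layers:
        if layer.cache_k is not None:
            layer.cache_k = layer.cache_k[:, :keep_len]
            layer.cache_v = layer.cache_v[:, :keep_len]


# Target model after verifying 5 draft tokens on top of a 10-token prefix:
shape = (1, 15, 2, 4)
layers = [LayerCache(np.zeros(shape), np.zeros(shape)) for _ in range(3)]
accepted = 3
trim_kv_cache(layers, 10 + accepted)
print(layers[0].cache_k.shape)  # (1, 13, 2, 4)
```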

Batch inference

Batch inference is the practice of processing multiple independent requests (sequences) simultaneously in a single forward pass through the model. Instead of generating tokens for one user at a time, you generate tokens for many users in parallel.

Sequential (batch size = 1):
A: "Hello" -> forward pass -> "world"
B: "How are" -> forward pass -> "you"  
C: "The cat" -> forward pass -> "sat"

Batched (batch size = 3):
[A: "Hello", B: "How are", C: "The cat"] -> (with single forward pass) -> ["world", "you", "sat"]

During decoding, the GPU spends most of its time loading model weights from memory rather than doing actual computation. Instead of processing one prompt after another, we gather several prompts together and send them to the GPU in a single call, so each weight loaded from memory is used for every sequence in the batch. This significantly increases throughput. Variations of this method (such as continuous batching) are mostly used in inference engines, which are out of the scope of this post.
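Real batches rarely contain prompts of equal length, as in the example above ("Hello" vs "How are"). A common approach for decoder-only generation is left padding plus an attention mask. The NumPy sketch below builds such a batch; `pad_batch`, the pad id 0 and the token ids are hypothetical, and a real model would still have to use the mask to ignore pad positions. The last lines show the underlying idea: one matrix multiply applies the same weights to every row of the batch.

```python
import numpy as np


def pad_batch(prompts, pad_id=0):
    """Left-pad token-id lists into one (batch, max_len) array plus a mask.

    Left padding keeps every prompt's last token in the final column, so the
    next token for all sequences can be read from the last position.
    """
    max_len = max(len(p) for p in prompts)
    input_ids = np.full((len(prompts), max_len), pad_id, dtype=np.int64)
    attention_mask = np.zeros((len(prompts), max_len), dtype=np.int64)
    for row, ids in enumerate(prompts):
        input_ids[row, max_len - len(ids):] = ids
        attention_mask[row, max_len - len(ids):] = 1  # 1 = real token, 0 = padding
    return input_ids, attention_mask


ids, mask = pad_batch([[5], [7, 8], [9, 10]])
print(ids.tolist())   # [[0, 5], [7, 8], [9, 10]]
print(mask.tolist())  # [[0, 1], [1, 1], [1, 1]]

# One batched matmul gives the same rows as three separate ones.
rng = np.random.default_rng(0)
W = rng.standard_normal((16, 16))
xs = rng.standard_normal((3, 16))
print(np.allclose(xs @ W, np.stack([x @ W for x in xs])))  # True
```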

Torch compile

torch.compile is PyTorch's JIT (just-in-time) compiler, introduced in PyTorch 2.0. It can be called as a function or used as a decorator, and it compiles your PyTorch models into optimized code, making them run faster without requiring you to change your model architecture or training code. Compiling a model is as easy as passing it to a function:

import torch
from models.qwen3 import Qwen3Model

model = Qwen3Model() 
compiled_model = torch.compile(model)

The compilation process involves several stages:

  1. Graph capture: TorchDynamo traces the model's Python execution and captures the computational graph; code it cannot trace causes a "graph break" and runs as regular Python
  2. Graph optimization: the captured graphs are handed to a backend, TorchInductor by default, which optimizes them (for example by fusing operations)
  3. Code generation: the optimized graph is compiled into efficient kernels using Triton on GPU (or C++ on CPU)
  4. Execution: the compiled code is cached and reused for subsequent calls with compatible inputs, such as the same input shapes

With torch.compile enabled we got around 57.38 tokens per second, about 2.4 times the 24.02 tokens per second of the uncompiled model with KV cache.

Main benefits of using torch.compile are:

  1. speed improvements of typically 30-200% (in our case about 140%)
  2. optimized memory access patterns
  3. ease of use and compatibility with existing code (in most cases)
  4. automatic optimizations such as operator (kernel) fusion, dead code elimination and so on

Using torch.compile is great when:

  • running a model in a production environment
  • training large models, where small speedups add up over time
  • running models with a typical structure and architecture, because more unusual PyTorch features might not compile well, since the compiler is relatively new and still under active development

When not to use torch.compile:

  • when using advanced or non-traditional architectures, as described in the last point
  • when input shapes change from one run to the next, since a new shape can trigger recompilation of the model
  • in memory-restricted environments, since compilation itself takes additional resources
  • when the model has data-dependent control flow, which causes graph breaks and does not compile well
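One way to limit recompilation when input lengths vary is to pad inputs up to a small set of "bucket" sizes, so torch.compile sees only a few distinct shapes. This is a sketch of the idea only: `bucket_length` and the multiple of 64 are illustrative choices, and the padded positions still have to be masked out. Letting torch.compile treat shapes as dynamic (its `dynamic` argument) is the other common remedy.

```python
def bucket_length(seq_len: int, multiple: int = 64) -> int:
    """Round seq_len up to the next multiple so compiled graphs see few shapes."""
    return ((seq_len + multiple - 1) // multiple) * multiple


raw_shapes = {n for n in range(1, 201)}
bucketed_shapes = {bucket_length(n) for n in range(1, 201)}
print(len(raw_shapes), sorted(bucketed_shapes))  # 200 [64, 128, 192, 256]
```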

It is worth remembering that compilation takes time, so the first run of the model will be slower, which might be an issue in some cases.
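Since the first run pays the compilation cost, tokens-per-second figures are only comparable if warm-up runs are excluded. The helper below is a hypothetical sketch of such a measurement, not necessarily how the figures in this post were obtained; `generate` stands for any function that returns the list of newly generated tokens, and `fake_generate` is a dummy used only to show the call.

```python
import time


def tokens_per_second(generate, prompt, max_new_tokens=200, warmup_runs=1, runs=3):
    """Average decoding throughput, excluding warm-up calls.

    Warm-up runs absorb one-time costs such as torch.compile compilation.
    On a GPU, make sure generate has finished its work (for example with
    torch.cuda.synchronize()) before the clock is read.
    """
    for _ in range(warmup_runs):
        generate(prompt, max_new_tokens)
    rates = []
    for _ in range(runs):
        start = time.perf_counter()
        new_tokens = generate(prompt, max_new_tokens)
        elapsed = time.perf_counter() - start
        rates.append(len(new_tokens) / elapsed)
    return sum(rates) / len(rates)


def fake_generate(prompt, max_new_tokens):
    time.sleep(0.001)  # pretend to decode
    return ["tok"] * max_new_tokens


print(tokens_per_second(fake_generate, "The quick brown", max_new_tokens=20))
```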