Snakes and ladders: Accelerating state space model inference with speculative decoding
2024
                            
                            
    
Speculative decoding is a method for accelerating inference in large language models (LLMs) by predicting multiple tokens using a smaller ‘draft model’ and validating them against the larger ‘base model.’ If a draft token is inconsistent with what the base model would have generated, speculative decoding ‘backtracks’ to the last consistent token before resuming generation. This is straightforward in autoregressive Transformer architectures, since their state is a sliding window of past tokens. However, their baseline inference complexity is quadratic in the number of input tokens. State Space Models (SSMs) have linear inference complexity, but they maintain a separate Markov state that makes backtracking non-trivial. We propose two methods to perform speculative decoding in SSMs: “Joint Attainment and Advancement” and “Activation Replay.” Both methods use idle computational resources to speculate and verify multiple tokens, allowing us to produce 6 tokens for 1.47× the cost of one, corresponding to an average 1.82× wall-clock speed-up on three different benchmarks using a simple n-gram model for drafting. Furthermore, as model size increases, the relative overhead of speculation and verification decreases: scaling from 1.3B parameters to 13B reduces the relative overhead from 1.98× to 1.22×. Unlike Transformers, speculative decoding in SSMs can be easily applied to batches of sequences, allowing dynamic allocation of resources to fill gaps in compute utilization and thereby improving efficiency and throughput under variable inference traffic.
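The core difficulty described above, rolling back a recurrent Markov state after a rejected draft token, can be illustrated with a minimal sketch. Everything in it is a hypothetical stand-in: a toy integer recurrence (`step`) plays the SSM, `base_predict` plays its greedy output head, and `ngram_draft` is a bigram drafter in the spirit of the simple n-gram drafting mentioned above. The rollback strategy, storing the state after every draft token and restoring the one after the last accepted token, is the naive baseline. It is not the paper's Joint Attainment and Advancement or Activation Replay method. The sketch also checks the standard property that greedy speculative decoding reproduces plain greedy output exactly.

```python
from typing import Dict, List, Tuple

VOCAB = 8  # hypothetical toy vocabulary size
MOD = 61   # hypothetical small state space, so the toy output eventually repeats


def step(state: int, token: int) -> int:
    """Toy SSM recurrence (assumption): consume one token, return the new fixed-size state."""
    return (state * 5 + token + 1) % MOD


def base_predict(state: int) -> int:
    """Toy greedy 'base model' head (assumption): next token depends only on the state."""
    return (state * 3 + 1) % VOCAB


def greedy_decode(prompt: List[int], n_new: int) -> List[int]:
    """Reference: ordinary one-token-at-a-time greedy decoding."""
    state = 0
    for tok in prompt:
        state = step(state, tok)
    out = list(prompt)
    for _ in range(n_new):
        tok = base_predict(state)
        out.append(tok)
        state = step(state, tok)
    return out


def ngram_draft(tokens: List[int], k: int) -> List[int]:
    """Hypothetical bigram drafter: chain the most recent successor of each token."""
    succ: Dict[int, int] = {}
    for a, b in zip(tokens, tokens[1:]):
        succ[a] = b  # later occurrences overwrite earlier ones
    draft, last = [], tokens[-1]
    for _ in range(k):
        if last not in succ:
            break
        last = succ[last]
        draft.append(last)
    return draft


def speculative_decode(prompt: List[int], n_new: int, k: int = 5) -> Tuple[List[int], int]:
    """Greedy speculative decoding with naive per-token state checkpoints for backtracking."""
    state = 0
    for tok in prompt:
        state = step(state, tok)
    out = list(prompt)
    passes = 0
    while len(out) - len(prompt) < n_new:
        remaining = n_new - (len(out) - len(prompt))
        # Draft at most remaining - 1 tokens so the draft plus one bonus token fits the budget.
        draft = ngram_draft(out, min(k, remaining - 1))
        # Verification: advance the state over the draft, keeping every intermediate state.
        # A real SSM would do this in one batched pass; here it is a plain loop.
        states = [state]
        for tok in draft:
            states.append(step(states[-1], tok))
        preds = [base_predict(s) for s in states]  # base model's choice at each position
        passes += 1
        # Accept the longest draft prefix that agrees with the base model.
        n_acc = 0
        while n_acc < len(draft) and draft[n_acc] == preds[n_acc]:
            n_acc += 1
        # Backtrack: restore the state after the last accepted token, then take the
        # base model's own token at the first disagreement (or after a fully accepted draft).
        bonus = preds[n_acc]
        out.extend(draft[:n_acc] + [bonus])
        state = step(states[n_acc], bonus)
    return out, passes


if __name__ == "__main__":
    spec, n_passes = speculative_decode([1, 2, 3], 40)
    print(spec == greedy_decode([1, 2, 3], 40), n_passes)
```

Every verification pass emits at least one token, so `passes` never exceeds `n_new`; it drops below `n_new` whenever drafts are accepted. Keeping one saved state per draft token is only the simplest way to make rollback possible; in a real SSM each saved state is a full multi-layer hidden state rather than a single integer.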
                                
                            
                            
                                