Adaptive Sampling for Efficient Softmax Approximation

Part of Advances in Neural Information Processing Systems 37 (NeurIPS 2024) Main Conference Track


Authors

Tavor Baharav, Ryan Kang, Colin Sullivan, Mo Tiwari, Eric Luxenberg, David Tse, Mert Pilanci

Abstract

The softmax function is ubiquitous in machine learning and optimization applications. Computing the full softmax of a matrix-vector product can be computationally expensive in high-dimensional settings. In many applications, however, it suffices to calculate only the top few outputs of the softmax function. In this work, we present an algorithm, dubbed AdaptiveSoftmax, that adaptively computes the top-k softmax values more efficiently than the full softmax computation, with probabilistic guarantees. We demonstrate the sample-efficiency improvements afforded by AdaptiveSoftmax on real and synthetic data, corroborating our theoretical results. AdaptiveSoftmax yields a >10x gain over full softmax computation on most datasets, with up to a 30x improvement for Mistral7B evaluated on the Wikitext dataset. The adaptive method we propose for estimating the partition function (the softmax denominator) is of independent interest and can be used in other applications such as kernel density estimation.
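
For intuition: given a matrix A with rows a_i and a query vector x, the softmax outputs are softmax(Ax)_i = exp(a_i^T x) / sum_j exp(a_j^T x), where the denominator is the partition function the abstract refers to. The sketch below illustrates the general idea of estimating logits from sampled coordinates of each inner product rather than computing them exactly; it is a minimal Monte Carlo simplification for illustration, not the paper's AdaptiveSoftmax algorithm, and the function names and the fixed sample size m are assumptions made for this example.

    import numpy as np

    def sampled_logit(A, x, i, m, rng):
        # Unbiased estimate of the logit a_i^T x from m sampled coordinates.
        # (Illustrative only; the paper chooses sample sizes adaptively.)
        d = A.shape[1]
        idx = rng.choice(d, size=m, replace=True)
        # Scale by d/m so the expectation equals the exact inner product.
        return (d / m) * np.dot(A[i, idx], x[idx])

    def approx_topk_softmax(A, x, k, m=64, seed=0):
        # Estimate the top-k outputs of softmax(Ax) from sampled logits.
        # A hypothetical sketch of the sampling idea, not AdaptiveSoftmax.
        rng = np.random.default_rng(seed)
        n = A.shape[0]
        logits = np.array([sampled_logit(A, x, i, m, rng) for i in range(n)])
        top = np.argsort(logits)[-k:][::-1]      # candidate top-k indices
        shifted = np.exp(logits - logits.max())  # shift for numerical stability
        z = shifted.sum()                        # estimated partition function
        return top, shifted[top] / z

    # Usage: estimate the top-5 softmax outputs for a random instance.
    A = np.random.default_rng(1).standard_normal((1000, 512))
    x = np.random.default_rng(2).standard_normal(512)
    idx, probs = approx_topk_softmax(A, x, k=5)

With a fixed sample size m this saves work per inner product but offers no guarantee; the paper's contribution is to allocate samples adaptively across rows so that the top-k values are identified with probabilistic guarantees.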