LogSumExp (LSE) is a function often used in machine learning. For inputs \(x_1, x_2, \ldots, x_n\) it is defined as:
\[\begin{equation} LSE(x_1, x_2, \ldots, x_n) = \log\sum_{i=1}^{n}\exp(x_i) \end{equation}\]
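Evaluated literally, the definition is a one-liner in NumPy. The sketch below is only illustrative: `lse_naive` is a hypothetical name, and this direct form can overflow for large inputs (see the numerical stability section).

```python
import numpy as np


def lse_naive(x):
    """LogSumExp computed literally from the definition: log(sum(exp(x_i)))."""
    x = np.asarray(x, dtype=float)
    return np.log(np.sum(np.exp(x)))


# Two equal inputs: log(e^0 + e^0) = log 2
print(lse_naive([0.0, 0.0]))       # ≈ 0.6931
# The result sits slightly above the largest input (3.0)
print(lse_naive([1.0, 2.0, 3.0]))  # ≈ 3.4076
```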
LSE as an upper bound for max()
LSE is an upper bound for \(\max(x_1, x_2, \ldots, x_n)\), with equality only when \(n=1\). That is, we have the following inequality:
\[\begin{equation} \max(x_1, x_2, \ldots, x_n) \le \log\sum_{i=1}^{n}\exp(x_i) \end{equation}\]
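A short derivation of the inequality, which also gives a matching bound from above (this second bound is a standard fact added here, not something the plot shows directly). Let \(m = \max_i x_i\):

```latex
\begin{aligned}
\exp(m) \;\le\; \sum_{i=1}^{n}\exp(x_i) \;\le\; n\exp(m)
  &\quad\text{(all terms are positive, one equals } \exp(m)\text{, none exceeds it)}\\
m \;\le\; \log\sum_{i=1}^{n}\exp(x_i) \;\le\; m + \log n
  &\quad\text{(apply } \log\text{, which is increasing)}
\end{aligned}
```

The left inequality is strict whenever \(n \ge 2\), because the remaining terms are strictly positive. For the two-input case plotted below, LSE therefore stays within \(\log 2 \approx 0.693\) of max, and the gap is largest along the diagonal \(x_1 = x_2\).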
To verify this, I plotted LSE and max for two inputs \(x_1, x_2\), each ranging over [-1, 1]; the result is the title image. The LSE surface (red) lies above the max surface (blue) everywhere, so LSE is indeed an upper bound for max.
Here is the code used to generate the title image:
```python
import matplotlib.pyplot as plt
import numpy as np


def main():
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # 200 x 200 grid over [-1, 1] x [-1, 1]
    N = 200
    x1 = np.linspace(-1, 1, N)
    x2 = np.linspace(-1, 1, N)
    X1, X2 = np.meshgrid(x1, x2)

    # LSE surface (red)
    Y = np.log(np.exp(X1) + np.exp(X2))
    surf = ax.plot_surface(X1, X2, Y, color='red')

    # max surface (blue)
    Y_up = np.max(np.stack([X1, X2], axis=0), axis=0)
    surf2 = ax.plot_surface(X1, X2, Y_up, color='blue')

    # y = np.log(np.exp(x1) + np.exp(x2))
    # ax.plot_trisurf(x1, x2, y, color='red')
    # y_up = np.max(np.stack([x1, x2], axis=1), axis=1)
    # ax.plot_trisurf(x1, x2, y_up, color='green')

    ax.set_xlabel("X1")
    ax.set_ylabel("X2")
    ax.set_zlabel("Z")

    # change the 3D plot angle and dist, https://stackoverflow.com/q/12904912/6064933
    ax.view_init(elev=11, azim=-46)
    ax.dist = 10  # 10 is the default; this attribute is deprecated in newer Matplotlib versions

    # plt.show()
    plt.savefig("log_sum_exp_vs_max.pdf", bbox_inches='tight')


if __name__ == "__main__":
    main()
```
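The plot is a visual check; the same comparison can also be made numerically on the same 200 × 200 grid. A minimal sketch, with variable names mirroring the plotting script:

```python
import numpy as np

# Same grid as the plotting script: 200 x 200 points covering [-1, 1]^2
N = 200
x1 = np.linspace(-1, 1, N)
x2 = np.linspace(-1, 1, N)
X1, X2 = np.meshgrid(x1, x2)

lse = np.log(np.exp(X1) + np.exp(X2))
mx = np.maximum(X1, X2)
gap = lse - mx

# LSE is strictly above max for n = 2 inputs ...
print(bool(np.all(gap > 0)))                    # True
# ... and never more than log(2) above it
print(bool(np.all(gap <= np.log(2) + 1e-12)))  # True
# Largest gap, reached on the diagonal X1 == X2 (≈ log 2 ≈ 0.693)
print(gap.max())
```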
How is this useful? It lets us transform the optimization target. For example, you may want to optimize \(\max(x_1, x_2)\), which is not differentiable where \(x_1 = x_2\). Instead, we can optimize \(LSE(x_1, x_2)\), a smooth upper bound of it.
The lifted structured loss paper uses this trick to go from its equation 3 to equation 4. Without knowing it, it is hard to see how the authors arrive at equation 4.
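Part of why this substitution works for gradient-based training is that LSE is smooth everywhere, and its partial derivative with respect to \(x_i\) is \(\exp(x_i) / \sum_j \exp(x_j)\), i.e. the softmax of the inputs. The sketch below (the helper names `lse` and `lse_grad` are hypothetical) compares that formula against central finite differences at \(x_1 = x_2 = 0.5\), a point where max itself has a kink.

```python
import numpy as np


def lse(x):
    # Direct definition; safe here because the inputs are small
    return np.log(np.sum(np.exp(x)))


def lse_grad(x):
    # Analytic gradient: exp(x_i) / sum_j exp(x_j), i.e. softmax(x)
    e = np.exp(x)
    return e / np.sum(e)


x = np.array([0.5, 0.5])  # x1 == x2: max(x1, x2) is not differentiable here
eps = 1e-6
basis = np.eye(len(x))
# Central finite differences, one coordinate at a time
fd = np.array([(lse(x + eps * b) - lse(x - eps * b)) / (2 * eps) for b in basis])

print(lse_grad(x))  # [0.5 0.5]
print(fd)           # close to [0.5 0.5]
```

Away from ties, the gradient of max is 1 for the largest input and 0 for the rest; the LSE gradient instead spreads weight over all inputs in proportion to \(\exp(x_i)\).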
LSE for numerical stability
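When we use the softmax function to normalize a vector that contains very large or very small values, evaluating \(\exp\) directly causes numerical problems: large inputs overflow to infinity, and very negative inputs underflow to zero. Writing softmax in terms of LSE, and computing LSE after subtracting the largest input, avoids both issues.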
This post explains how it works in detail.
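A minimal sketch of the usual fix: the max-shift identity \(LSE(x) = m + \log\sum_i \exp(x_i - m)\) with \(m = \max_i x_i\), plus a softmax built on it as \(\mathrm{softmax}(x)_i = \exp(x_i - LSE(x))\). The function names are hypothetical; in practice SciPy already provides `scipy.special.logsumexp` and `scipy.special.softmax`.

```python
import numpy as np


def lse_stable(x):
    """LogSumExp via the max-shift identity: m + log(sum(exp(x_i - m)))."""
    x = np.asarray(x, dtype=float)
    m = np.max(x)
    # After the shift the largest exponent is exp(0) = 1, so the sum cannot
    # overflow, and the sum is at least 1, so its log cannot be -inf
    return m + np.log(np.sum(np.exp(x - m)))


def softmax_stable(x):
    """softmax(x)_i = exp(x_i - LSE(x)); every exponent is <= 0."""
    x = np.asarray(x, dtype=float)
    return np.exp(x - lse_stable(x))


big = np.array([1000.0, 1000.0])
small = np.array([-1000.0, -1000.0])

with np.errstate(over="ignore", divide="ignore"):
    naive_big = np.log(np.sum(np.exp(big)))      # exp(1000) overflows to inf
    naive_small = np.log(np.sum(np.exp(small)))  # exp(-1000) underflows to 0, log(0) = -inf

print(naive_big, naive_small)  # inf -inf
print(lse_stable(big))         # ≈ 1000.6931 (= 1000 + log 2)
print(lse_stable(small))       # ≈ -999.3069 (= -1000 + log 2)
print(softmax_stable(big))     # [0.5 0.5]
```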
References
- https://en.wikipedia.org/wiki/LogSumExp
- matplotlib plot_trisurf