I wonder if you get any numerical instability here in high dimensions from summing exponentials. Probably not, since they're Gaussian (no long tails), but after looking at scipy.special.logsumexp [1] I'm a bit wary of sums of exponentials in float32. I'd be curious to see whether there's any characterization of this (the paper cited in the article only considers the low-dimensional case).
[1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.s...
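To make the worry concrete, here's a rough sketch (my own illustration, not from the article) of the failure mode logsumexp guards against. The assumptions are mine: Gaussian-kernel log-weights of the form -||x - y||^2 / (2 sigma^2) with sigma = 1 and standard-normal points in d = 1000. Squared distances grow roughly linearly with d, so the exponents end up very negative and exp() underflows to 0 in float32. The max-shift trick is the standard one scipy.special.logsumexp is built on. The function names here are just hypothetical helpers.

```python
import numpy as np

def naive_log_sum_exp(x):
    # Exponentiate, sum, then take the log. Underflows to log(0) = -inf
    # when every exponent is very negative.
    with np.errstate(divide="ignore"):  # silence the log(0) warning
        return np.log(np.sum(np.exp(x)))

def stable_log_sum_exp(x):
    # Shift by the max so the largest term becomes exp(0) = 1,
    # then add the shift back in log space.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

# Hypothetical setup: one query point and 5 "centers" in d = 1000, float32.
rng = np.random.default_rng(0)
d = 1000
x = rng.standard_normal(d).astype(np.float32)
ys = rng.standard_normal((5, d)).astype(np.float32)

# Gaussian-kernel log-weights with sigma = 1 (assumption).
sq_dists = np.sum((ys - x) ** 2, axis=1)
log_w = -sq_dists / np.float32(2.0)

naive = naive_log_sum_exp(log_w)
stable = stable_log_sum_exp(log_w)
print(log_w.dtype, naive, stable)
```

With these settings the log-weights land around -1000, so the naive version returns -inf while the shifted one stays finite. scipy.special.logsumexp(log_w) should agree with the stable version up to rounding. This only covers over/underflow of the sum itself; it doesn't say anything about how errors propagate through the rest of the method in high dimensions, which is the part I'd want characterized.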