Bayesian reconstruction of memories stored in neural networks from their connectivity

Sebastian Goldt (International School of Advanced Studies (SISSA), Trieste), Florent Krzakala (IdePHICS Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)), Lenka Zdeborová (SPOC Laboratory)
Date: 2023-03
Source: https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1010813 (CC BY 4.0)

The inference problem considered here, where we aim to recover a symmetric low-rank matrix from noisy observations, can be solved using a class of approximate message passing (AMP) algorithms for low-rank matrix factorisation called Low-RAMP. It was derived by Lesieur et al. [52], building on previous works [44, 53, 54] that provided AMP algorithms for particular instances of low-rank matrix factorisation problems. Low-RAMP is an iterative algorithm that produces estimates of the means of the marginal distributions p(x_i) and of their covariance matrices σ_i, where x_i is in general the ith column of the low-rank matrix X that we are estimating by evaluating the posterior distribution (6). In the present case, this mean is the estimated 'tuning curve' of the ith neuron (see above). Using this framework, we will derive variants of the algorithm for the pattern reconstruction problem outlined in the previous section. We present the algorithm in detail in the Methods section.

We also provide a reference implementation of Low-RAMP for symmetric and bipartite matrix factorisation that is applicable to a range of problems. It is designed to be easily extended to other problems and also provides a number of further utility functions. All the results in this paper can be reproduced using this code.

Since we are adopting a probabilistic approach to estimating the patterns, we take as our reconstruction of the patterns the mean of the posterior distribution, which we denote by a hat. Our goal is to track the mean-squared error mse_X of the reconstruction of the true signal after t steps of the algorithm, Eq (7), where ‖⋅‖ denotes the Euclidean norm of a vector. The mse can be expressed in terms of a single matrix-valued order parameter M^t, defined in Eq (8). Here and throughout this paper, we write averages with respect to the prior distribution p_X(x) of the corresponding model as 〈⋅〉. We write x_0 with the subscript to underline that the random variable x_0 is not a column of the matrix X that we are trying to estimate; instead, it is a variable that is drawn from the prior and averaged over.

The AMP algorithm has the distinctive advantage over other algorithms, such as Monte Carlo methods, that its behaviour in the limit N → ∞, for a separable prior on X*, random i.i.d. noise ζ_ij, and a number of patterns P = O(1), can be tracked exactly and at all times using the "state evolution" technique [50, 55]. The roots of this method go back to ideas originally introduced in physics to deal with a class of disordered systems called glasses [56, 57]. For the low-rank matrix factorisation problems we consider here, state evolution was derived and analysed in detail by Lesieur et al. [52], building on previous works that derived and analysed state evolution for other specific problems [44, 45, 53]. The last few years in particular have seen a surge of interest in using state evolution to understand the properties of approximate Bayesian inference algorithms for a number of problems [58]. The goal is thus to find an update equation for the order parameter M^t that mirrors the progress of the algorithm; this update equation is the state evolution equation [50, 55].

Remarkably, following [59], the two constants defining our problem, τ and ν, do not appear explicitly in the state evolution equations. Instead, the behaviour of the algorithm, and hence its performance, depends only on an effective signal-to-noise ratio (SNR) of the problem, which is a function of the threshold τ and the noise variance ν appearing in the connectivity model (3). Formally, it can be expressed as the inverse of the Fisher score matrix [60] of the generative model we use to describe how the network is connected, Eqs (5a) and (5b), evaluated at zero signal; see Eqs (9) and (10). Here and throughout, 𝔼 denotes the expectation over the random variables. In fact, at the level of the algorithm, everything about the output channel (5) can be summarised in this single scalar quantity Δ. This remarkable universality of the state evolution, and hence of the AMP algorithm, with respect to the output channel was first observed in [59] and dubbed "channel universality".

State evolution provides an update equation for the order parameter M^t that mirrors the progress of the algorithm. We first define an auxiliary function f, Eq (11). If A = 0 and b = 0, this function would simply compute the average over the prior distribution p_X(x). Instead, b and A are estimated from the data (see the algorithm for details), so f computes an average over a distribution that combines the prior with a data-dependent part. This structure reflects the Gaussian approximation of the posterior density that we apply here, or, more broadly speaking, the interplay between prior information and data-dependent likelihood that is typical of Bayesian inference. The remaining factor in Eq (11) is a normalisation. The update equation for the order parameter M^t can be written using this auxiliary function for all the cases considered in this paper; it reads [52], Eq (12), where z is a P-dimensional vector of Gaussian random variables with mean zero and variance 1. The average over x_0 is taken with respect to the prior distribution p_X(x), as discussed above.

To summarise, statistical physics gives us an algorithm to perform approximate inference of the patterns, and the state evolution Eq (12) allows us to track the behaviour of the algorithm over time. We can thus analyse the performance of the algorithm in high-dimensional inference by studying the fixed points of the low-dimensional state evolution (12). This is the key idea behind this approach, and we will now demonstrate the usefulness of this machinery by applying it to several specific cases.
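To make the state-evolution update concrete, here is a minimal numerical sketch of a single update of the order parameter M^t. It is not the paper's reference implementation: it assumes the standard Low-RAMP conventions of Lesieur et al. [52], namely that f(A, b) is the mean of the prior tilted by exp(b·x − x·Ax/2) and that its arguments in Eq (12) are A = M^t/Δ and b = (M^t/Δ) x_0 + (M^t/Δ)^{1/2} z; the function names and the Monte Carlo estimation are ours.

```python
import numpy as np

def f_denoiser(A, b, prior_samples):
    # Auxiliary function f(A, b): mean of x under the prior p_X tilted by
    # exp(b.x - x.A.x / 2), estimated by reweighting samples drawn from the prior.
    log_w = prior_samples @ b - 0.5 * np.einsum("ip,pq,iq->i",
                                                prior_samples, A, prior_samples)
    w = np.exp(log_w - log_w.max())                  # normalisation Z(A, b) cancels here
    return (w[:, None] * prior_samples).sum(axis=0) / w.sum()

def se_update(M, Delta, sample_prior, n_mc=1000, n_prior=4000, seed=0):
    # One state-evolution update of the order parameter M (cf. Eq 12): average
    # f(M/Delta, (M/Delta) x0 + (M/Delta)^{1/2} z) x0^T over x0 ~ p_X and z ~ N(0, I).
    rng = np.random.default_rng(seed)
    A = (M + M.T) / (2.0 * Delta)
    evals, evecs = np.linalg.eigh(A)
    L = evecs * np.sqrt(np.clip(evals, 0.0, None))   # square root of M/Delta for sampling
    prior_samples = sample_prior(rng, n_prior)
    M_new = np.zeros_like(M)
    for _ in range(n_mc):
        x0 = sample_prior(rng, 1)[0]
        z = rng.standard_normal(M.shape[0])
        M_new += np.outer(f_denoiser(A, A @ x0 + L @ z, prior_samples), x0)
    return (M_new + M_new.T) / (2.0 * n_mc)          # symmetrise the Monte Carlo estimate

# Example: two +-1 patterns, iterated from a small initial overlap.
binary_prior = lambda rng, n: rng.choice([-1.0, 1.0], size=(n, 2))
M = 0.05 * np.eye(2)
for t in range(10):
    M = se_update(M, Delta=0.5, sample_prior=binary_prior, seed=t)
print(M)   # approaches a scaled identity, cf. the diagonal ansatz of Eq (20) below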
Reconstructing binary patterns

As a first application of the algorithm and of the analysis tools outlined so far, we consider the reconstruction of a set of binary patterns whose components take the values ±1. We will assume that positive and negative values are equiprobable and that the components of a pattern vector are independent of each other, so the prior on a column x_i of the matrix of stored patterns is simply given by Eq (13).

A single pattern (P = 1). It is instructive, and helpful for the following discussions, to first consider the case P = 1, i.e. there is only a single pattern stored in the network that we are trying to recover from J. The threshold function for the model then becomes f(A, B) = tanh(B), and the state evolution for the now scalar parameter m^t simplifies to Eq (14), where w is a scalar Gaussian random variable with zero mean and unit variance.
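For illustration, a single step of this scalar recursion can be evaluated with a one-dimensional Gaussian average. The sketch below assumes that Eq (14) takes the standard form m^{t+1} = E_w[tanh(m^t/Δ + sqrt(m^t/Δ) w)], which is what the general update reduces to for the ±1 prior with f(A, B) = tanh(B); the function name is ours.

```python
import numpy as np

def se_step_binary(m, Delta, n_nodes=80):
    # One step of the scalar state evolution (cf. Eq 14) for +-1 patterns, assuming
    # m^{t+1} = E_w[ tanh(m/Delta + sqrt(m/Delta) * w) ] with w ~ N(0, 1).
    # The Gaussian average is computed with Gauss-Hermite quadrature.
    nodes, weights = np.polynomial.hermite_e.hermegauss(n_nodes)
    snr = m / Delta
    return np.sum(weights * np.tanh(snr + np.sqrt(snr) * nodes)) / np.sqrt(2.0 * np.pi)

m = 0.5
print(se_step_binary(m, Delta=0.5))   # a single update of the overlap m^t
```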
We can now iterate the state evolution Eq (14) with a given noise level Δ(ν, τ) until convergence and then compute the mse corresponding to that fixed point. The fixed point we converge to reveals information about the performance of the AMP algorithm. We plot the results on the left-hand side of Fig 1 for the two different initialisations of the algorithm: in blue, we plot the mse obtained by iterating the state evolution starting from the random initialisation (15), where δ > 0 is a very small random number. The error obtained in this way is the one reached by the AMP algorithm when initialised with a random draw from the prior distribution, in other words with a random guess for the patterns. This is confirmed by the blue crosses, which show the mean and standard deviation of the mse obtained from five independent runs of the algorithm on actual instances of the problem. The dashed orange line in Fig 1 shows the final mse obtained from the informed initialisation (16), which corresponds to initialising the algorithm with the solution, i.e. with the stored patterns themselves.

Fig 1. (Left) Reconstruction performance of the message-passing algorithm for binary patterns. We plot the mse (7) obtained by the AMP algorithm (32) as a function of the effective noise Δ (9) (blue crosses), for the algorithm starting from the random (15) and informed (16) initialisations. Solid lines depict the prediction obtained by iterating the state evolution Eq (14). Having Δ/Δ_c > 1 corresponds to the white region in the phase diagram in the centre panel. We also plot the mse of the reconstruction obtained by applying PCA to the weight matrix J and to the Fisher matrix S (31) (green and red, respectively). Parameters: τ = 0; N = 5000 for AMP, N = 20000 for PCA. (Center) Phase diagram for the rectified Hopfield channel with P = 1. We plot whether reconstruction of the patterns better than a random guess is easy (blue) or impossible (white) using the message-passing algorithm, as a function of the constant threshold τ and the variance ν of the Gaussian noise appearing in the connectivity model (3). The solid lines are contours of the connection probability p_C(ν, τ) (4). (Right) Critical noise ν* as a function of the connection probability p_C. We plot ν*, the largest variance of the additive Gaussian noise ζ_ij at which reconstruction remains possible, against the probability p_C (4) that any two neurons are connected. https://doi.org/10.1371/journal.pcbi.1010813.g001

In this model, we find that the AMP algorithm starting from a random guess performs just as well as the algorithm starting from the informed initialisation. This need not always be the case, and we will indeed find a different behaviour in the sparse and skewed models considered next.
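The qualitative content of the left panel of Fig 1 can be reproduced in a few lines by iterating the scalar update to its fixed point from both initialisations and reading off the mse. The sketch again assumes the tanh form of the update and uses mse = 1 − m, i.e. the prior second moment minus the overlap, for ±1 patterns; the grid of Δ values is an arbitrary illustration, not the one used for the figure.

```python
import numpy as np

nodes, weights = np.polynomial.hermite_e.hermegauss(80)   # quadrature for E_w, w ~ N(0, 1)

def se_step(m, Delta):
    # Scalar state evolution for +-1 patterns (cf. Eq 14), assuming the tanh form.
    snr = m / Delta
    return np.sum(weights * np.tanh(snr + np.sqrt(snr) * nodes)) / np.sqrt(2.0 * np.pi)

def mse_at_fixed_point(Delta, m0, n_iter=500):
    # Iterate the state evolution to convergence and return mse = 1 - m for +-1 patterns.
    m = m0
    for _ in range(n_iter):
        m = se_step(m, Delta)
    return 1.0 - m

for Delta in (0.25, 0.5, 0.75, 0.9, 1.1):
    mse_rand = mse_at_fixed_point(Delta, m0=1e-6)   # random initialisation, cf. Eq (15)
    mse_inf = mse_at_fixed_point(Delta, m0=1.0)     # informed initialisation, cf. Eq (16)
    print(f"Delta = {Delta:4.2f}   mse (random init) = {mse_rand:.3f}   mse (informed init) = {mse_inf:.3f}")
```

For this prior the two initialisations converge to the same fixed point, which is the behaviour reported in the text.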
When is recovery possible? We can see from the middle plot of Fig 1 that recovery of the memories from the connectivity J is not always possible; there exists a critical value Δ_c of the effective noise above which the mean-squared error of the solution obtained by the algorithm is the same as we would have obtained by making a random guess for the solution based on the prior distribution (13) alone, without looking at the data. We can calculate this critical noise level Δ_c using the state evolution (12). We can see from that equation that m^t = 0 is a trivial fixed point, in the sense that the mse corresponding to that fixed point is equal to the mse obtained by making a random guess.

Expanding Eq (14) around this fixed point yields m^{t+1} = m^t/Δ. There are hence two regimes for recovery, separated by a critical value (17) of the effective noise (9). If Δ > Δ_c, the uniform fixed point is stable and recovery is impossible. On the other hand, for Δ < Δ_c the uniform fixed point is unstable, and hence AMP returns an estimate of the patterns whose mse is lower than that of random guessing. The phase diagram in the middle of Fig 1 delineates the easy and the impossible phases for the rectified Hopfield channel with the symmetric prior (13). While there could in principle be other fixed points of the state evolution equations for other priors and channels [52], it is always one of the fixed points reached from either the informed or the uninformed initialisation that describes the behaviour of the algorithm.

At first sight, the impact of the additive Gaussian noise ζ_ij on the phase diagram in Fig 1 appears counter-intuitive. If we fix the threshold to, say, τ = 0.5, reconstruction is impossible for small variances ν of ζ_ij. As we increase ν, i.e. as we add more noise to the system, recovery becomes possible. The key to understanding this behaviour is that for a single stimulus (P = 1), a weight in the network takes one of two possible values ±a, which are symmetric around the origin. After applying the rectification, for any cut-off τ > W_ij the resulting weight matrix J without additive noise is trivially zero and no recovery is possible. We can only hope to detect something when the added noise ζ_ij pushes the value of the weight before rectification above the cut-off. Recovery then becomes possible if the added noise is large enough that the noisy weight exceeds the cut-off, a + ζ_ij > τ, while remaining small enough that it is significantly more likely that the noiseless weight is positive than negative. As the noise variance increases even further, its detrimental effects dominate, and recovery becomes impossible again. This mechanism is reminiscent of stochastic resonance (SR), in which a weak signal is amplified by the presence of noise. Indeed, our problem contains the three ingredients for SR (e.g. [61]): a threshold mechanism, given by the rectification in the connectivity model; a weak signal, the stored patterns; and a noise term, ζ. As already mentioned, when the noise is too large, recovery becomes impossible.

We show on the right of Fig 1 the critical variance of ζ_ij above which reconstruction becomes impossible, ν*, as a function of the connection probability p_C given by Eq (4). This plot can be obtained by solving, for a given value p_C = c, the two-dimensional system (18), (19) for (τ, ν). As expected, the critical variance increases with the connection probability, and it goes to zero as the connection probability goes to zero.
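Returning to the linear-stability criterion above, it is easy to check numerically with the same assumed scalar update: the growth factor of a vanishingly small overlap approaches 1/Δ, so the trivial fixed point destabilises, and recovery from an uninformed start becomes possible, precisely when Δ drops below one.

```python
import numpy as np

nodes, weights = np.polynomial.hermite_e.hermegauss(80)

def se_step(m, Delta):
    # Scalar state evolution (cf. Eq 14), assuming the tanh form used above.
    snr = m / Delta
    return np.sum(weights * np.tanh(snr + np.sqrt(snr) * nodes)) / np.sqrt(2.0 * np.pi)

m0 = 1e-6                       # a vanishingly small overlap with the stored pattern
for Delta in (0.5, 0.9, 1.1, 2.0):
    ratio = se_step(m0, Delta) / m0
    print(f"Delta = {Delta}:  m^(t+1)/m^t = {ratio:.4f}   (prediction 1/Delta = {1/Delta:.4f})")
```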
Comparison to principal component analysis (PCA). Principal component analysis (PCA) is another method to reconstruct the stored patterns from the network connectivity. PCA and other spectral methods have some advantages: they are non-parametric, and their implementation in the case of a single pattern is straightforward, since the PCA prediction for the stored pattern is simply the leading eigenvector of J. We plot the mean-squared error (7) of this estimate as the green line on the left of Fig 1, where we see that the reconstruction error of PCA is larger than that of AMP, especially for large values of the noise.

This is also borne out by theory: the reconstruction mean-squared error of PCA can be shown to be strictly larger than that of the AMP estimate, since the latter is the Bayes-optimal predictor [52]. An alternative PCA algorithm can be found by linearising the AMP equations around the trivial fixed point [58, 62]. This linearisation yields an equation that can be interpreted as PCA applied to the Fisher matrix S (31) instead of J. Since the Fisher matrix depends on the generative model of the data that is used when deriving the message-passing equations, its leading eigenvector offers a spectral algorithm that is better adapted to the problem at hand. Indeed, we find that its error (red line in Fig 1) is slightly lower than the error obtained from PCA applied directly to the weights. In either case, the performance of PCA is worse than that of AMP.

The large value of the PCA error compared to the AMP error at large noise levels in Fig 1 reveals a fundamental weakness of PCA: even at noise levels above the critical noise Δ_c, where no reconstruction is possible for any algorithm, PCA can be applied and will return a prediction; there is no concept of uncertainty in PCA. Hence the mse of PCA tends to a constant as the noise increases and the leading eigenvector of J becomes just a random vector; for the Hopfield prior, and when rescaling the eigenvectors to have the same length as draws from the Hopfield prior, this constant is 2. AMP, on the other hand, returns a vector of zeros if Δ > Δ_c (provided the prior has an average of 0, as is the case for all the priors we consider). AMP thus expresses its uncertainty about the planted pattern, yielding an mse of 1 for inputs with x_i = ±1. The advantage of the Bayesian approach is thus that it prevents over-confident predictions in the high-noise regime.

The weaker performance of PCA compared to AMP is due to the fact that spectral methods do not offer a natural way to incorporate the prior knowledge we have about the structure and distribution of the stimuli into the recovery algorithm. The Bayesian framework incorporates this domain knowledge in a transparent way through the generative model of the stored patterns p_X(x). We will see that this creates an even larger performance gap for sparse patterns and for patterns with a low coding level.
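For concreteness, here is what the single-pattern PCA baseline discussed above might look like in code: take the leading eigenvector of the symmetric observation matrix, rescale it to the length of a ±1 pattern, and score it up to the global sign. The toy matrix below is generated from a plain Gaussian channel of variance Δ purely to illustrate the mechanics; it is not the rectified channel (3) used for the green curve in Fig 1.

```python
import numpy as np

def pca_estimate(J):
    # PCA prediction for a single stored pattern: the leading eigenvector of the
    # symmetric matrix J, rescaled to the norm sqrt(N) of a +-1 pattern.
    eigvals, eigvecs = np.linalg.eigh(J)
    v = eigvecs[:, -1]                       # eigenvector of the largest eigenvalue
    return np.sqrt(len(v)) * v

def mse_up_to_sign(x_true, x_hat):
    # Per-component squared error, minimised over the global sign ambiguity x -> -x.
    return min(np.mean((x_true - s * x_hat) ** 2) for s in (1.0, -1.0))

# Toy instance: one +-1 pattern observed through a Gaussian channel of variance Delta.
rng = np.random.default_rng(0)
N, Delta = 1000, 0.5
x = rng.choice([-1.0, 1.0], size=N)
xi = rng.standard_normal((N, N)) * np.sqrt(Delta / 2.0)
J = np.outer(x, x) / np.sqrt(N) + xi + xi.T          # symmetric noisy observation
print("PCA mse:", mse_up_to_sign(x, pca_estimate(J)))
```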
Many patterns (P > 1). For the general case of several patterns P > 1, with P finite, we can significantly simplify the state evolution by noticing that the matrix M^t will interpolate between a matrix full of zeros at time t = 0 and a suitably scaled identity matrix in the case of perfect recovery, Eq (20), where I_P is the identity matrix in P dimensions. In other words, for uncorrelated patterns the different input patterns do not interact during the reconstruction, and so the off-diagonal matrix elements remain zero as long as we only store a few patterns and the connectivity structure remains low-rank. Once we overload the model by storing many more patterns, the off-diagonal elements would become non-zero, meaning that reconstructions converge to spurious patterns, for example linear combinations of the stored patterns; in this regime, however, the state evolution derived above also breaks down. Under this diagonal ansatz, the threshold function becomes Eq (21), where z_k is again a standard Gaussian variable. Substituting into the state evolution gives an update equation for the parameter m^t, namely Eq (22), where m^t is the overlap parameter introduced in (20). This update has the same form as the state evolution in the P = 1 case, Eq (14).

So we find, remarkably, that recovering P distinct patterns is exactly equivalent to recovering a single pattern P times in the thermodynamic limit where N → ∞ while the number of patterns remains of order one. This approximation will eventually break down in practical applications with finite network sizes, and we investigate the breaking point of this behaviour below.

Recovering many patterns with PCA poses an additional challenge. While it is easy to recover the leading rank-P subspace of the matrices J or S by simply computing their P leading eigenvectors, it is not clear how to recover the exact patterns from these eigenvectors, which can be any rotation of the input patterns due to the rotational symmetry of W. This can be seen from the fact that the patterns X* could be multiplied by any rotation matrix O, with O O^T equal to the identity, without changing the resulting weight matrix J, see Eq (2). The best way to recover the exact stimuli from the principal components is thus not clear a priori (see [63]); in other problems, PCA has to be combined with further methods, such as k-means or gradient descent. Since we have already seen that AMP outperforms PCA on binary patterns, and since this gap only increases for the other types of patterns we study below, we do not investigate this direction further.
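The rotational ambiguity faced by spectral methods is easy to verify numerically. Assuming the Hebbian form W = X X^T/√N suggested by Eq (2), replacing the pattern matrix X by X O for any orthogonal O leaves W, and hence J, unchanged, so the spectrum of J alone cannot single out the true patterns within the recovered subspace.

```python
import numpy as np

rng = np.random.default_rng(0)
N, P = 500, 3
X = rng.choice([-1.0, 1.0], size=(N, P))            # P binary patterns
O, _ = np.linalg.qr(rng.standard_normal((P, P)))    # a random P x P orthogonal matrix
W = X @ X.T / np.sqrt(N)                            # Hebbian weight matrix (cf. Eq 2)
W_rotated = (X @ O) @ (X @ O).T / np.sqrt(N)        # rotated patterns, same weights
print(np.allclose(W, W_rotated))                    # True: W identifies X only up to a rotation
```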