Mixture Models Review and Intuition to State Space Models
A mixture model has observations X and a corresponding hidden latent variable Y that indicates the source of each observation. X can be discrete, like word frequency, or continuous, like temperature. Y denotes which mixture component (distribution) X comes from. When both X and Y are continuous, the model is called factor analysis.
The advantage of mixture models is that smaller graphs (mixture models) can be used as building blocks to make bigger graphs (HMMs).
Common inference tasks for HMMs range from inferring one hidden state, to inferring all hidden states, to finding the most probable sequence of hidden states. Example algorithms are the Viterbi algorithm, the Forward/Backward algorithm, and the Baum-Welch algorithm.
All of the above algorithms have a counterpart in State Space Models: although the mathematical techniques differ, they share the same essence. Like their HMM counterparts, they can be broken down into local operations/subproblems whose results are built up into the whole solution. We explain this further in the coming sections.
Stories/Intuitions for the various models
HMM - Dishonest casino story
Factorial HMM - Multiple dealers at the dishonest casino
State Space Model (SSM) - X: signal of an aircraft on the radar, Y: actual physical location of the aircraft
Switching SSM - multiple aircraft (state S is the indicator of which aircraft is appearing on the radar)
Basic Math Review
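A multivariate Gaussian over x \in \mathbb{R}^n with mean \mu and covariance \Sigma is denoted by the following PDF -
p(x | \mu, \Sigma) = \frac{1}{(2\pi)^{n/2}|\Sigma|^{1/2}} \exp\left\{-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right\}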
Consider a block partitioned matrix M =
\begin{bmatrix} E & F \\ G & H \end{bmatrix}
First we diagonalize M -
\begin{bmatrix} I & -FH^{-1} \\ 0 & I \end{bmatrix}\begin{bmatrix} E & F \\ G & H \end{bmatrix}\begin{bmatrix} I & 0 \\ -H^{-1}G & I \end{bmatrix} = \begin{bmatrix} E-FH^{-1}G & 0 \\ 0 & H \end{bmatrix}
The top-left block of the result is called the Schur complement of H in M -
M/H = E-FH^{-1}G
Then we invert M using the formula -
\begin{aligned}
XYZ = W &\implies Y^{-1} = ZW^{-1}X \\
&\implies M^{-1} = \begin{bmatrix} E & F \\ G & H \end{bmatrix}^{-1} = \begin{bmatrix} I & 0 \\ -H^{-1}G & I \end{bmatrix}\begin{bmatrix} (M/H)^{-1} & 0 \\ 0 & H^{-1} \end{bmatrix}\begin{bmatrix} I & -FH^{-1} \\ 0 & I \end{bmatrix} \\
&= \begin{bmatrix} (M/H)^{-1} & -(M/H)^{-1}FH^{-1} \\ -H^{-1}G(M/H)^{-1} & H^{-1} + H^{-1}G(M/H)^{-1}FH^{-1} \end{bmatrix} \\
&= \begin{bmatrix} E^{-1} + E^{-1}F(M/E)^{-1}GE^{-1} & -E^{-1}F(M/E)^{-1} \\ -(M/E)^{-1}GE^{-1} & (M/E)^{-1} \end{bmatrix}
\end{aligned}
Hence, equating the top-left blocks of the last two expressions, we obtain the matrix inversion lemma,
(E - FH^{-1}G)^{-1} = E^{-1}+ E^{-1}F(H-GE^{-1}F)^{-1}GE^{-1}
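The block-inverse formula and the matrix inversion lemma above are easy to check numerically. The following is a minimal NumPy sketch, assuming randomly generated blocks; the block sizes and the diagonal shift that keeps E and H well conditioned are illustrative choices, not part of the derivation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical block sizes; E is p x p, H is q x q.
p, q = 3, 2
E = rng.normal(size=(p, p)) + 5 * np.eye(p)   # diagonal shift keeps E invertible
H = rng.normal(size=(q, q)) + 5 * np.eye(q)   # same for H
F = rng.normal(size=(p, q))
G = rng.normal(size=(q, p))
M = np.block([[E, F], [G, H]])

E_inv = np.linalg.inv(E)
H_inv = np.linalg.inv(H)
schur_H = E - F @ H_inv @ G        # M/H
schur_E = H - G @ E_inv @ F        # M/E

# Block inverse assembled from the Schur complement M/H.
S_inv = np.linalg.inv(schur_H)
M_inv_blocks = np.block([
    [S_inv, -S_inv @ F @ H_inv],
    [-H_inv @ G @ S_inv, H_inv + H_inv @ G @ S_inv @ F @ H_inv],
])

# Matrix inversion lemma: second expression for (E - F H^{-1} G)^{-1}.
lemma_rhs = E_inv + E_inv @ F @ np.linalg.inv(schur_E) @ G @ E_inv

block_inverse_ok = np.allclose(M_inv_blocks, np.linalg.inv(M))
lemma_ok = np.allclose(S_inv, lemma_rhs)
print(block_inverse_ok, lemma_ok)
```

Both checks compare against `np.linalg.inv` applied to the full matrix, so they only confirm the identities for this particular random draw.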
Imagine a point on a sheet of paper, x \in \mathbb{R}^2.
If the orientation or the axes are changed, the same point can also be viewed as a point in 3D space. Depending on whether one is sitting on the paper or in the room, one sees the point in 2D or in 3D. Hence Y is what is observed from our view, and X (the original point) is the hidden/latent variable.
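We know that a marginal Gaussian p(X) times a conditional Gaussian p(Y|X) is a joint Gaussian, and that a marginal p(Y) of a joint Gaussian p(X, Y) is also Gaussian. Hence, we can compute the mean and variance of Y.
Writing Y = \mu + \Lambda X + W with X \sim \mathcal{N}(0, I) and W \sim \mathcal{N}(0, \Psi), and assuming the noise W is uncorrelated with the data,
\begin{aligned}
E[Y] &= E[\mu + \Lambda X + W] = \mu \\
Var[Y] &= E[(\Lambda X + W)(\Lambda X + W)^T] = \Lambda\Lambda^T + \Psi
\end{aligned}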
We can also say that the effective covariance matrix is a low-rank outer product of two long skinny matrices plus a diagonal matrix. In other words, factor analysis is just a constrained Gaussian model.
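The "low-rank plus diagonal" structure of the marginal covariance can be checked by simulation. The sketch below assumes the generative model Y = \mu + \Lambda X + W with X \sim \mathcal{N}(0, I) and W \sim \mathcal{N}(0, \Psi); the dimensions, sample count, and parameter values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: latent dim k, observed dim d, number of samples n.
k, d, n = 2, 5, 200_000
Lam = rng.normal(size=(d, k))            # loading matrix Lambda
mu = rng.normal(size=d)                  # offset mu
psi = rng.uniform(0.1, 0.5, size=d)      # diagonal entries of Psi

X = rng.normal(size=(n, k))                   # X ~ N(0, I)
W = rng.normal(size=(n, d)) * np.sqrt(psi)    # W ~ N(0, Psi), independent of X
Y = mu + X @ Lam.T + W                        # Y = mu + Lambda X + W, one sample per row

model_cov = Lam @ Lam.T + np.diag(psi)        # low-rank outer product plus diagonal
sample_cov = np.cov(Y, rowvar=False)

max_cov_err = np.abs(sample_cov - model_cov).max()
max_mean_err = np.abs(Y.mean(axis=0) - mu).max()
print(max_cov_err, max_mean_err)
```

Both errors shrink as n grows, which is what one would expect if the sample moments converge to \mu and \Lambda\Lambda^T + \Psi.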
We will now analyse the Factor analysis joint distribution assuming noise is uncorrelated with the data or the latent variables -
The distributions as we derived/assumed above are -
\begin{aligned}
&P(X) = \mathcal{N}(X;0, I) \\
&P(Y | X) = \mathcal{N}(Y; \mu+\Lambda X, \Psi)
\end{aligned}
The covariance between X and Y can be derived as follows -
\begin{aligned}
Cov[X, Y] = E[(X - 0)(Y - \mu)^T] = E[X(\mu + \Lambda X + W -\mu)^T] \\
= E[XX^T\Lambda^T + XW^T] = \Lambda^T
\end{aligned}
Hence the joint distribution of X and Y is -
P(\begin{bmatrix} X \\ Y \end{bmatrix}) = \mathcal{N}(\begin{bmatrix} X \\ Y \end{bmatrix} | \begin{bmatrix} 0 \\ \mu \end{bmatrix}, \begin{bmatrix} I & \Lambda^T \\ \Lambda & \Lambda \Lambda^T + \Psi \end{bmatrix})
Now we can say -
\begin{aligned}
&\Sigma_{11} = I \\
&\Sigma_{12} = \Sigma_{21}^T = \Lambda^T \\
&\Sigma_{22} = \Lambda \Lambda^T + \Psi
\end{aligned}
Given all of the above, we can now derive the posterior of the latent variable X given Y, where
\begin{aligned}
&P(X|Y) = \mathcal{N}(X | \mu_{1|2}, V_{1|2}) \\
&\mu_{1|2} = \mu_1 + \Sigma_{12}\Sigma_{22}^{-1}(X_2 - \mu_2) \\
& = \Lambda^T(\Lambda\Lambda^T + \Psi)^{-1}(Y - \mu) \\
&V_{1|2} = \Sigma_{11} - \Sigma_{12}\Sigma_{22}^{-1}\Sigma_{21} \\
& = I - \Lambda^T(\Lambda\Lambda^T + \Psi)^{-1}\Lambda
\end{aligned}
Applying the matrix inversion lemma we learnt above, we get -
\begin{aligned}
&(E - FH^{-1}G)^{-1} = E^{-1}+ E^{-1}F(H-GE^{-1}F)^{-1}GE^{-1} \\
&V_{1|2} = (I + \Lambda^T\Psi^{-1}\Lambda)^{-1} \\
&\mu_{1|2} = V_{1|2}\Lambda^T\Psi^{-1}(Y-\mu)
\end{aligned}
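To see that the two forms of the posterior agree, and why the second is cheaper (it inverts a k x k matrix rather than a d x d one when \Psi is diagonal), here is a small numerical check. It is a sketch with randomly chosen parameters; the sizes k and d and the observation y are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

k, d = 2, 6                                # hypothetical latent / observed dims
Lam = rng.normal(size=(d, k))              # Lambda
mu = rng.normal(size=d)
Psi = np.diag(rng.uniform(0.1, 0.5, size=d))
y = rng.normal(size=d)                     # an arbitrary observation

# Direct Gaussian conditioning: requires the inverse of a d x d matrix.
S22_inv = np.linalg.inv(Lam @ Lam.T + Psi)
V_direct = np.eye(k) - Lam.T @ S22_inv @ Lam
m_direct = Lam.T @ S22_inv @ (y - mu)

# After the matrix inversion lemma: requires the inverse of a k x k matrix.
Psi_inv = np.diag(1.0 / np.diag(Psi))      # Psi is diagonal, so this is cheap
V_lemma = np.linalg.inv(np.eye(k) + Lam.T @ Psi_inv @ Lam)
m_lemma = V_lemma @ Lam.T @ Psi_inv @ (y - mu)

print(np.allclose(V_direct, V_lemma), np.allclose(m_direct, m_lemma))
```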
Learning in Factor Analysis
The inference problem, as shown above, can be thought of as a linear projection,
since the posterior covariance (V_{1|2} above) does not depend on the observed
data Y and the posterior mean \mu_{1|2} is just a linear operation on Y - \mu.
The learning problem in Factor Analysis corresponds to learning the parameters of the model,
i.e., the loading matrix \Lambda, manifold center \mu and the diagonal covariance
matrix of the noise, \Psi. As always, we will solve this problem by maximizing the
log-likelihood of the observed data,
i.e., (\mu^*, \Lambda^*, \Psi^*) = \arg\max_{\mu, \Lambda, \Psi} \ell(\mathcal{D}; \mu, \Lambda, \Psi).
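Thanks to the derivation above, we have a closed-form expression for the
incomplete log-likelihood of the data
\mathcal{D} = \{ y^{(i)} : i = 1, \ldots, n \}, as shown below:
\ell(\mathcal{D}; \mu, \Lambda, \Psi) = -\frac{n}{2}\log|\Lambda\Lambda^T + \Psi| - \frac{1}{2}\sum_{i=1}^{n}(y^{(i)} - \mu)^T(\Lambda\Lambda^T + \Psi)^{-1}(y^{(i)} - \mu) + \text{const}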
Estimation of \mu is straightforward; however, \Lambda and \Psi are tightly coupled
non-linearly in the log-likelihood. In order to make this problem tractable, we will use the
same trick used in Gaussian Mixture Models, i.e., optimize the complete log-likelihood
using the Expectation Maximization (EM) algorithm.
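In the E-step, we compute the posterior moments <X> and <XX^T> of the latent variables
using the inference results above, which gives the expected complete log-likelihood.
In the M-step, we take derivatives of this expected complete log-likelihood
with respect to the parameters and set them to zero. We use the trace and determinant
derivative rules in this step.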
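The EM loop can be sketched in a few lines of NumPy. The closed-form M-step updates used here are the standard ones for factor analysis and are my assumption, since the notes do not spell them out; the function name `fa_em`, the initialisation, and the fixed iteration count are likewise illustrative choices.

```python
import numpy as np

def fa_em(Y, k, n_iter=500, seed=0):
    """EM for factor analysis (sketch). Y: (n, d) data matrix, k: latent dimension."""
    rng = np.random.default_rng(seed)
    n, d = Y.shape
    mu = Y.mean(axis=0)                  # ML estimate of mu is the sample mean
    Yc = Y - mu                          # centred data
    S = Yc.T @ Yc / n                    # sample covariance
    Lam = rng.normal(size=(d, k))        # random initial loading matrix
    psi = np.diag(S).copy()              # initial diagonal noise variances
    for _ in range(n_iter):
        # E-step: posterior moments of X for every data point.
        LtPinv = Lam.T / psi                           # Lambda^T Psi^{-1}, shape (k, d)
        V = np.linalg.inv(np.eye(k) + LtPinv @ Lam)    # V_{1|2}, shared by all points
        Ex = Yc @ (V @ LtPinv).T                       # rows are <X_i>, shape (n, k)
        sum_Exx = n * V + Ex.T @ Ex                    # sum_i <X_i X_i^T>
        # M-step: stationary point of the expected complete log-likelihood.
        YEx = Yc.T @ Ex                                # sum_i (y_i - mu) <X_i>^T
        Lam = YEx @ np.linalg.inv(sum_Exx)
        psi = np.diag(S - Lam @ YEx.T / n).copy()      # keep only the diagonal
    return mu, Lam, psi
```

A typical call would be `mu, Lam, psi = fa_em(Y, k=2)`, after which `Lam @ Lam.T + np.diag(psi)` can be compared against the sample covariance of Y. A convergence test on the log-likelihood would be a natural replacement for the fixed iteration count.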
It should be noted that this model is “unidentifiable” in the sense that different runs
for learning parameters for the same dataset are not guaranteed to obtain the same solution.
This is because there is degeneracy in the model, since \Lambda only occurs in the product
of the form \Lambda \Lambda^T, making the model invariant to rotations and axis flips in the
latent space. To see this more clearly, if we replace \Lambda by \Lambda Q for any
orthogonal matrix Q, the model remains the same because
(\Lambda Q)(\Lambda Q)^T = \Lambda (QQ^T) \Lambda^T = \Lambda \Lambda^T.
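The rotation invariance is easy to demonstrate numerically. In this sketch the orthogonal matrix Q is drawn via a QR decomposition of a random Gaussian matrix, which is just one convenient way to get an orthogonal matrix; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(2)

d, k = 5, 3                               # hypothetical observed / latent dims
Lam = rng.normal(size=(d, k))

# A random orthogonal matrix from the QR decomposition of a Gaussian matrix.
Q, _ = np.linalg.qr(rng.normal(size=(k, k)))

Lam_rot = Lam @ Q
same_cov = np.allclose(Lam @ Lam.T, Lam_rot @ Lam_rot.T)   # Lambda Lambda^T unchanged
different_loadings = not np.allclose(Lam, Lam_rot)         # but the loadings differ
print(same_cov, different_loadings)
```

Two different loading matrices therefore give exactly the same distribution over Y, so the data cannot distinguish between them.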
Introduction to State Space Models (SSM)
As the figure above shows, State Space Models (SSMs) are the dynamical (sequential) continuous
counterpart to HMMs. In fact, they are the sequential extension of the Factor Analysis model
discussed so far. An SSM can be thought of as a sequential Factor Analysis or as a continuous-state HMM.
Mathematically, let f(\cdot) be an arbitrary dynamic model, and let g(\cdot) be
an arbitrary observation model. Then, we obtain the following equations for the dynamic system:
\begin{aligned}
x_t &= f(x_{t-1}) + Gw_t \\
y_t &= g(x_t) + v_t
\end{aligned}
Here w_t \sim \mathcal{N}(0, Q) and v_t \sim \mathcal{N}(0, R) are zero-mean Gaussian noise terms.
Further, if we assume that f(\cdot) and g(\cdot) are linear (matrices), then we get the equations for a linear
dynamical system (LDS):
\begin{aligned}
x_t &= Ax_{t-1} + Gw_t \\
y_t &= Cx_t + v_t
\end{aligned}
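An example of an LDS for 2D object tracking has observations y_t that are 2D positions
of the object in space and a 4-dimensional state x_t, with the first two dimensions corresponding
to the position and the next two dimensions corresponding to the velocity, i.e., the first derivative
of the position. Assuming a constant velocity model with time step \Delta between observations, we obtain
the state space equations for this model, as shown below:
\begin{aligned}
x_t &= \begin{bmatrix} 1 & 0 & \Delta & 0 \\ 0 & 1 & 0 & \Delta \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} x_{t-1} + w_t \\
y_t &= \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix} x_t + v_t
\end{aligned}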
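A minimal simulation of such a constant-velocity tracker might look as follows. The time step `dt`, the noise covariances, the initial state, and the helper name `simulate_lds` are all assumptions made for illustration, and G is taken to be the identity.

```python
import numpy as np

dt = 1.0   # assumed sampling interval
# State: (position_1, position_2, velocity_1, velocity_2).
A = np.array([[1, 0, dt, 0],
              [0, 1, 0, dt],
              [0, 0, 1, 0],
              [0, 0, 0, 1]], dtype=float)
# Only the two position coordinates are observed.
C = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0]], dtype=float)
Q = 0.01 * np.eye(4)   # process noise covariance (assumed)
R = 0.25 * np.eye(2)   # observation noise covariance (assumed)

def simulate_lds(A, C, Q, R, x0, T, seed=0):
    """Sample T steps of x_t = A x_{t-1} + w_t, y_t = C x_t + v_t (G = I)."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    xs, ys = [], []
    for _ in range(T):
        x = A @ x + rng.multivariate_normal(np.zeros(A.shape[0]), Q)
        y = C @ x + rng.multivariate_normal(np.zeros(C.shape[0]), R)
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)

xs, ys = simulate_lds(A, C, Q, R, x0=[0.0, 0.0, 1.0, 0.5], T=50)
print(xs.shape, ys.shape)
```

The returned `ys` are noisy versions of the first two columns of `xs`, which is the kind of input the filtering and smoothing algorithms work on.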
There are two different inference problems that can be considered for SSMs:
The first problem is to infer the current hidden state (x_t) given the observations up to
the current time t, i.e., y_1, y_2, ..., y_t. This exact online inference problem, known as the
Filtering problem, is analogous to the forward algorithm in HMMs and will be discussed next.
The second problem is the offline inference problem, i.e., given the entire sequence y_1, y_2, ..., y_T,
estimate x_t for t < T. This problem, called the Smoothing problem, can be solved using
the Rauch-Tung-Striebel algorithm, which is the Gaussian analog of the forward-backward (alpha-gamma) algorithm
in HMMs.
Kalman Filter
The Kalman Filter is an algorithm analogous to the forward algorithm for HMMs. For the state space model (SSM), the goal of the Kalman filter is to estimate the belief state P(X_{t}|y_{1:t}) given the data \{y_{1}, \ldots, y_{t}\}. To do so, it follows a recursive procedure with two steps, namely the time update and the measurement update. Here, we derive the two steps.
Derivation
Time Update:
(1) Goal: Compute P(X_{t+1}|y_{1:t}) (the distribution for X_{t+1|t}) using prior belief P(X_{t}|y_{1:t}) and the dynamical model P(X_{t+1}|X_{t}).
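(2) Derivation: Recall that the dynamical model states that X_{t+1} = AX_{t} + Gw_{t}, where w_{t} is the noise term with a Gaussian distribution \mathcal{N}(0, Q). Then, we can compute the mean and variance of X_{t+1|t} using this formula:
\begin{aligned}
\hat{x}_{t+1|t} &= E[AX_{t} + Gw_{t} | y_{1:t}] = A\hat{x}_{t|t} \\
P_{t+1|t} &= E[(X_{t+1} - \hat{x}_{t+1|t})(X_{t+1} - \hat{x}_{t+1|t})^T | y_{1:t}] = AP_{t|t}A^T + GQG^T
\end{aligned}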
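Measurement Update:
(1) Goal: Compute P(X_{t+1}|y_{1:t+1}) using P(X_{t+1}|y_{1:t}) computed from the time update, the observation y_{t+1} and the observation model P(y_{t+1}|X_{t+1}).
(2) Derivation: The key here is to first compute the joint distribution P(X_{t+1},y_{t+1}|y_{1:t}), then derive the conditional distribution P(X_{t+1}|y_{1:t+1}). Recall that the observation model states that y_{t} = CX_{t} + v_{t}, where v_{t} is the noise term with a Gaussian distribution \mathcal{N}(0, R). We have already computed the mean and the variance of X_{t+1|t} in the time update. Now we need to compute the mean and the variance of y_{t+1|t} and the covariance of X_{t+1|t} and y_{t+1|t}:
\begin{aligned}
\hat{y}_{t+1|t} &= C\hat{x}_{t+1|t} \\
Var[y_{t+1} | y_{1:t}] &= CP_{t+1|t}C^T + R \\
Cov[X_{t+1}, y_{t+1} | y_{1:t}] &= P_{t+1|t}C^T
\end{aligned}
Conditioning this joint Gaussian on y_{t+1} gives the measurement update:
\begin{aligned}
\hat{x}_{t+1|t+1} &= \hat{x}_{t+1|t} + K_{t+1}(y_{t+1} - C\hat{x}_{t+1|t}) \\
P_{t+1|t+1} &= P_{t+1|t} - K_{t+1}CP_{t+1|t}
\end{aligned}
where K_{t+1} = P_{t+1|t}C^T(CP_{t+1|t}C^T+R)^{-1} is referred to as the Kalman gain matrix.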
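The two steps translate almost line by line into code. The sketch below assumes the linear-Gaussian model X_{t+1} = AX_t + Gw_t, y_t = CX_t + v_t with the Kalman gain defined above; the function names `kf_predict` and `kf_update` are hypothetical, and the update uses the standard Gaussian conditioning formulas.

```python
import numpy as np

def kf_predict(x, P, A, G, Q):
    """Time update: from (x_{t|t}, P_{t|t}) to (x_{t+1|t}, P_{t+1|t})."""
    x_pred = A @ x
    P_pred = A @ P @ A.T + G @ Q @ G.T
    return x_pred, P_pred

def kf_update(x_pred, P_pred, y, C, R):
    """Measurement update: fold in the observation y_{t+1}."""
    S = C @ P_pred @ C.T + R                 # variance of y_{t+1|t}
    K = P_pred @ C.T @ np.linalg.inv(S)      # Kalman gain K_{t+1}
    x_new = x_pred + K @ (y - C @ x_pred)    # correct the mean by the innovation
    P_new = P_pred - K @ C @ P_pred          # shrink the covariance
    return x_new, P_new, K

# Tiny scalar usage example with made-up numbers.
I1 = np.eye(1)
x_pred, P_pred = kf_predict(np.array([0.0]), I1, A=I1, G=I1, Q=I1)
x_new, P_new, K = kf_update(x_pred, P_pred, y=np.array([2.0]), C=I1, R=2 * I1)
print(x_pred, P_pred, K, x_new, P_new)
```

For larger systems, solving a linear system with `np.linalg.solve` instead of forming the explicit inverse would be the more numerically careful choice; the explicit inverse is kept here to mirror the formula.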
Example
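We now look at an example of the Kalman filter applied to noisy observations of a 1D particle doing a random walk. The SSM is given as:
\begin{aligned}
X_{t+1} & = X_{t} + w \\
y_{t} & = X_{t} + v
\end{aligned}
where w \sim \mathcal{N}(0, \sigma_{x}) and v \sim \mathcal{N}(0, \sigma_{y}) are noise terms with variances \sigma_{x} and \sigma_{y}. In other words, A=G=C=I. Thus, using the update rules derived above, we have at time t:
\begin{aligned}
\hat{x}_{t+1|t} &= \hat{x}_{t|t} \\
P_{t+1|t} &= P_{t|t} + \sigma_{x} \\
K_{t+1} &= \frac{P_{t+1|t}}{P_{t+1|t} + \sigma_{y}} \\
\hat{x}_{t+1|t+1} &= \hat{x}_{t+1|t} + K_{t+1}(y_{t+1} - \hat{x}_{t+1|t}) \\
P_{t+1|t+1} &= (1 - K_{t+1})P_{t+1|t}
\end{aligned}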
As demonstrated in the figure below, given the prior belief P(X_0), our estimate of P(X_1) without an observation has the same mean and a larger variance, as shown by the first two lines of the equations above. However, given an observation of 2.5, the new estimated distribution shifts to the right according to the last three lines of the equations above.
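The behaviour described here can be reproduced with a few lines of arithmetic. Only the observation value 2.5 comes from the text; the prior mean and variance and the two noise variances below are assumed values chosen for illustration.

```python
# Assumed values: only the observation 2.5 comes from the example above.
x0, P0 = 0.0, 1.0               # mean and variance of the prior belief P(X_0)
sigma_x, sigma_y = 1.0, 1.0     # process and observation noise variances
y1 = 2.5                        # the observation

# Time update: same mean, larger variance.
x_pred = x0
P_pred = P0 + sigma_x

# Measurement update: the mean moves towards the observation, the variance shrinks.
K = P_pred / (P_pred + sigma_y)
x1 = x_pred + K * (y1 - x_pred)
P1 = (1 - K) * P_pred

print(x_pred, P_pred, K, x1, P1)
```

With these assumed numbers the gain is 2/3, so the mean moves from 0 to about 1.67 and the variance drops from 2 to about 0.67.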
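Complexity of one KF step:
Let X_t \in \mathbb{R}^{N_x} and Y_t \in \mathbb{R}^{N_y}.
Predict step:
P_{t+1|t} = AP_{t|t}A^T + GQG^T
This involves multiplying an N_x \times N_x matrix with another N_x \times N_x matrix. With standard dense multiplication this costs O(N_{x}^{3}); it drops to O(N_{x}^{2}) only if A is sparse with a constant number of nonzeros per row.
Measurement update step:
K_{t+1} = P_{t+1|t}C^T(CP_{t+1|t}C^T+R)^{-1}
This involves the inversion of an N_y \times N_y matrix, hence the time complexity is O(N_{y}^{3}).
Taking these two operations as the dominant costs, the overall time is O(\max\{N_{x}^{3}, N_{y}^{3}\}) for dense matrices.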
Rauch-Tung-Striebel
The Rauch-Tung-Striebel (RTS) algorithm is a Gaussian analog of the forward-backward (alpha-gamma) algorithm for HMMs. The intent is to estimate P(X_t|y_{1:T}).
Since P(X_t|y_{1:T}) is a Gaussian distribution, we intend to derive its mean \hat{x}_{t|T} and variance P_{t|T}.
Let’s first define the joint distribution P(X_{t},X_{t+1}|y_{1:t}).
m = \begin{bmatrix} \hat{x}_{t|t} \\ \hat{x}_{t+1|t} \end{bmatrix}, V = \begin{bmatrix} P_{t|t} & P_{t|t}A^T \\ AP_{t|t} & P_{t+1|t} \end{bmatrix}
where m and V are mean and covariance of joint distribution respectively.
We are going to use a neat trick to arrive at the RTS inference.
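The resulting backward recursion can be sketched in code. The smoothing equations used below are the standard RTS ones, stated here as an assumption since the derivation is not shown: L_t = P_{t|t}A^TP_{t+1|t}^{-1}, \hat{x}_{t|T} = \hat{x}_{t|t} + L_t(\hat{x}_{t+1|T} - \hat{x}_{t+1|t}), and P_{t|T} = P_{t|t} + L_t(P_{t+1|T} - P_{t+1|t})L_t^T. The function names are hypothetical and G is taken to be the identity.

```python
import numpy as np

def kalman_filter(ys, A, C, Q, R, x0, P0):
    """Forward pass. (x0, P0) is the belief about the state before the first observation."""
    x_pred, P_pred = np.asarray(x0, dtype=float), np.asarray(P0, dtype=float)
    xf, Pf, xp, Pp = [], [], [], []
    for y in ys:
        S = C @ P_pred @ C.T + R
        K = P_pred @ C.T @ np.linalg.inv(S)
        x = x_pred + K @ (y - C @ x_pred)       # filtered mean x_{t|t}
        P = P_pred - K @ C @ P_pred             # filtered covariance P_{t|t}
        xp.append(x_pred)
        Pp.append(P_pred)
        xf.append(x)
        Pf.append(P)
        x_pred, P_pred = A @ x, A @ P @ A.T + Q  # predict x_{t+1|t}, P_{t+1|t}
    return np.array(xf), np.array(Pf), np.array(xp), np.array(Pp)

def rts_smoother(xf, Pf, xp, Pp, A):
    """Backward pass. xp[t], Pp[t] are the one-step predictions for time t."""
    xs, Ps = xf.copy(), Pf.copy()               # the last step is already smoothed
    for t in range(len(xf) - 2, -1, -1):
        L = Pf[t] @ A.T @ np.linalg.inv(Pp[t + 1])
        xs[t] = xf[t] + L @ (xs[t + 1] - xp[t + 1])
        Ps[t] = Pf[t] + L @ (Ps[t + 1] - Pp[t + 1]) @ L.T
    return xs, Ps
```

Calling `kalman_filter` and then `rts_smoother` on a sequence of observations returns the smoothed means and covariances; at the final time step they coincide with the filtered ones, since no future observations remain.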
Learning the parameters of the SSM follows the same EM recipe:
E-step: Compute <X_tX_{t-1}^T>, <X_tX_t^T> and <X_t> using the Kalman Filter and Rauch-Tung-Striebel methods.
M-step: Same as explained in the Factor Analysis section.
Non-Linear Systems
In some real-world scenarios, the relations between successive states, as well as between states and observations, may be non-linear. This makes a closed-form solution to the problem almost impossible. In recent work, these non-linear relations are captured using deep neural networks, and stochastic gradient descent based methods are used to obtain the solutions. These state space models can be represented as:
x_t = f(x_{t-1}) + w_t
y_t = g(x_t) + v_t
Since the non-linearity does not change the role of the noise covariance matrices Q and R, they have been omitted from the discussion for convenience.
An approximate solution that does not use deep neural networks is to approximate the non-linear functions by a Taylor expansion. The order of the expansion depends on the use case and the extent of the non-linearity. These are referred to as Extended Kalman Filters. The following equations show the second order Taylor expansion for Extended Kalman Filters:
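As an illustration of the idea, here is a sketch of one Extended Kalman Filter step. Note that it uses only a first-order Taylor expansion (the usual EKF linearisation), not the second-order expansion mentioned above. The function name `ekf_step` and the convention that the caller supplies the Jacobians of f and g as functions are assumptions of this sketch.

```python
import numpy as np

def ekf_step(x, P, y, f, g, f_jac, g_jac, Q, R):
    """One first-order EKF step for x_t = f(x_{t-1}) + w_t, y_t = g(x_t) + v_t."""
    # Time update: push the mean through f, the covariance through its Jacobian.
    A = f_jac(x)                              # linearise f around the current mean
    x_pred = f(x)
    P_pred = A @ P @ A.T + Q
    # Measurement update: linearise g around the predicted mean.
    C = g_jac(x_pred)
    S = C @ P_pred @ C.T + R
    K = P_pred @ C.T @ np.linalg.inv(S)
    x_new = x_pred + K @ (y - g(x_pred))
    P_new = P_pred - K @ C @ P_pred
    return x_new, P_new

# Usage with a made-up scalar model: f(x) = sin(x), g(x) = x.
x_new, P_new = ekf_step(
    x=np.array([0.0]), P=np.eye(1), y=np.array([2.0]),
    f=np.sin, g=lambda x: x,
    f_jac=lambda x: np.diag(np.cos(x)), g_jac=lambda x: np.eye(1),
    Q=np.eye(1), R=2 * np.eye(1),
)
print(x_new, P_new)
```

When f and g are linear, the Jacobians are the constant matrices A and C and the step reduces to the ordinary Kalman filter update.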
Online Inference: Inference methods that look only at the observations up to the current time-step. They generally take the form P(X|y_{1:t}). The inference objective may vary depending on the algorithm. E.g., the Kalman Filter.
Offline Inference: Inference methods that consider all the observations from all the time-steps. They take the form P(X|y_{1:T}). E.g., Rauch-Tung-Striebel smoothing.
Recursive Least Squares and Least Mean Squares
Consider a special case where the state x_t remains constant while the observation coefficients (C) vary with time. In this scenario, observations are affected by the coefficients rather than the state itself. This turns the estimation of the state into a linear regression problem.
Let the state be \theta (constant over time) and let C = x_t^T in the KF update equation, where x_t now denotes the vector of inputs (regressors) at time t. Then we can represent the observation model as
y_t = x_t^T\theta + v_t
which takes exactly the form of a linear regression problem to estimate \theta.
Now, using the Kalman Filter idea, we can formulate the update equation for this problem as:
\hat{\theta}_{t+1} = \hat{\theta}_{t} + P_{t+1}R^{-1}(y_{t+1} - x_{t+1}^T\hat{\theta}_{t})x_{t+1}
This is the recursive least squares algorithm: the update term is proportional to the negative gradient of the squared difference (y_{t+1} - x_{t+1}^T\theta)^2 with respect to \theta. If we treat \eta = P_{t+1}R^{-1} as a constant, this becomes exactly the least mean squares algorithm.
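A short sketch makes the connection concrete: running the Kalman measurement update with a constant state and C = x_t^T processes one regression example at a time. The function name `rls`, the zero prior mean, and the broad prior covariance `P0_scale * I` are assumptions of this sketch; with that Gaussian prior, the final estimate should coincide with the corresponding batch ridge-regression solution.

```python
import numpy as np

def rls(X, y, R=1.0, P0_scale=100.0):
    """Recursive least squares as a Kalman filter with a constant state theta."""
    n, d = X.shape
    theta = np.zeros(d)                  # prior mean (assumed zero)
    P = P0_scale * np.eye(d)             # broad prior covariance (assumed)
    for x_t, y_t in zip(X, y):
        # The state is constant, so there is no time update; C = x_t^T.
        K = P @ x_t / (x_t @ P @ x_t + R)          # Kalman gain (a vector here)
        theta = theta + K * (y_t - x_t @ theta)    # correct by the prediction error
        P = P - np.outer(K, x_t) @ P               # P - K C P
    return theta, P

# Usage on synthetic data with made-up parameters.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
theta_true = np.array([1.0, -2.0, 0.5])
y = X @ theta_true + 0.1 * rng.normal(size=200)
theta_hat, P_hat = rls(X, y, R=0.01)
print(theta_hat)
```

Replacing the matrix `P` by a fixed scalar step size in the `theta` update would turn this loop into least mean squares.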