The Baum-Welch algorithm for fMRI data

The Baum–Welch algorithm uses Expectation Maximization (EM) to estimate the parameters of a hidden Markov model. Here is a demo implementation for hypothetical fMRI data that estimates the neural signatures of K = 3 states from V = 3 voxels over NTime = 256 time points. Such models are useful, for example, for capturing the changes in neural activity associated with a sequence of different situational events; see Baldassano, Chen, Zadbood, Pillow, Hasson & Norman (2017). Discovering event structure in continuous narrative perception and memory. Neuron, 95(3), 709-721.

Generate ground truth variables (GT_*) for hypothetical data

% GT_State2Feature is a K by V matrix that lists the ground truth mean voxel
% intensities (number of voxels = V) related to each of the K states.
GT_State2Feature = ...
[...
1,-2,-2;
2, 1, 2;
-1, 2, 0];
% GT_TransMat is a K by K transition matrix that details the probability of
% jumping from one state to another. GT_CumTransMat represents the
% cumulative probability of transitioning from state i to any of states 1
% to j.
GT_TransMat = ...
[...
.980, .020, .000;
.000, .990, .010;
.000, .000, 1.0];
GT_CumTransMat = cumsum(GT_TransMat,2);
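Note that the expected number of time steps spent in a state, E[n], is very sensitive to changes in that state's self-transition probability (p) when p is high (above about 0.9). This is because the number of time steps spent in a state follows a geometric distribution, so its expectation is given by

$$E[n] = \frac{1}{1-p}$$

which grows without bound as p approaches 1: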
% Plot p vs E[n]:
Line = fplot(@(p) 1./(1-p),[0,1]);
Line.Color = 'r';
Line.LineWidth = 2;
ylim([0,144]);
set(gca,'FontSize',14);
xlabel('p');
ylabel('E[n]');
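As a worked check of the formula, the sketch below evaluates E[n] for the self-transition probabilities on the diagonal of GT_TransMat. It also confirms the p = 0.98 value by simulation. The seed and the number of simulated runs are arbitrary choices made for illustration, and the inverse-CDF trick for drawing geometric run lengths is a standard technique. It is not part of the original demo.

```matlab
% Expected dwell time E[n] = 1/(1-p) for the self-transition probabilities
% on the diagonal of GT_TransMat.
p  = [0.98; 0.99; 1.0];
En = 1 ./ (1 - p);                     % approx. 50, 100 and Inf time steps

% Sensitivity dE[n]/dp = 1/(1-p)^2 grows rapidly as p approaches 1.
pCheck = [0.5; 0.9; 0.99];
dEdp   = 1 ./ (1 - pCheck).^2;         % approx. 4, 100 and 10000

% Monte Carlo check for p = 0.98. The number of steps n spent in a state
% with self-transition probability p satisfies P(n > k) = p^k, so
% inverse-CDF sampling gives n = ceil(log(U) / log(p)), U ~ Uniform(0,1).
rng(1);                                % arbitrary seed
nRuns  = 20000;                        % arbitrary number of simulated runs
runLen = ceil(log(rand(nRuns,1)) ./ log(p(1)));
fprintf('Simulated mean dwell time: %.1f (theory: %.1f)\n', ...
    mean(runLen), En(1));
```

With p = 0.98 for state 1 and p = 0.99 for state 2, the ground truth model is expected to stay about 50 and 100 time points in those states before moving on. State 3 is absorbing.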
% GT_Time2State (generated below) is an NTime-by-1 vector representing the
% ground truth sequence of states. It is generated recursively: starting in
% state i = 1 at the first time point, at each subsequent time point we
% sample a uniformly distributed random number (x) in the range (0,1) and
% jump to the first state j for which GT_CumTransMat(i,j) > x (which may be
% state i itself).
NTime = 256;
GT_Time2State = nan(NTime,1);
GT_Time2State(1,1) = 1;
rng(331);
for iTime = 2:1:NTime
x = rand(1,1);
xx = GT_CumTransMat(GT_Time2State(iTime-1),:);
xx = find(xx > x); % Returns a subset of [1,2,3];
GT_Time2State(iTime) = xx(1);
end
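The sampling rule can be checked in isolation. The sketch below applies the same `find(cumsum(...) > x)` rule to the first row of GT_TransMat and shows an equivalent vectorised form. It then confirms that the empirical next-state frequencies approximate the transition probabilities. The seed and the number of draws are arbitrary illustrative choices.

```matlab
% Inverse-CDF sampling of the next state, mirroring the loop above.
row    = [0.98, 0.02, 0.00];           % GT_TransMat(1,:)
cumRow = cumsum(row);                  % [0.98, 1.00, 1.00]
cumRow(end) = 1;                       % guard against rounding just below 1

% First state j whose cumulative probability exceeds x:
nextState = @(x) find(cumRow > x, 1);
s1 = nextState(0.50);                  % state 1
s2 = nextState(0.985);                 % state 2

% Vectorised equivalent: count the cumulative values <= x, then add one
% (valid because cumRow is non-decreasing).
rng(2);                                % arbitrary seed
nDraw = 1e5;
x     = rand(nDraw,1);
draws = sum(x >= cumRow, 2) + 1;       % implicit expansion (R2016b+)

% Empirical next-state frequencies should approximate row.
freq = accumarray(draws, 1, [3,1])' / nDraw;
disp(freq);
```

Forcing the last cumulative value to exactly 1 is a defensive choice. If rounding left it slightly below 1, `find` could return an empty result for x very close to 1.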
% GT_Time2Mu is an NTime by V matrix that lists the ground truth mean voxel
% intensities for each time point. GT_BoundaryIndex is a vector containing
% the last time point before each of the (in this case, K-1) state
% transitions.
GT_Time2Mu = GT_State2Feature(GT_Time2State,:);
GT_BoundaryIndex = find(diff(GT_Time2State));
% The Time2Voxel matrix lists each of the NTime observations for each of
% the V voxels.
% Note that we use an identity matrix to specify the covariance.
Time2Voxel = mvnrnd(GT_Time2Mu,eye(3),size(GT_Time2Mu,1));
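The covariance is the identity, so each observation is simply its state's mean plus independent standard-normal noise on every voxel. This means `mvnrnd` (from the Statistics and Machine Learning Toolbox) can be replaced by `randn`. The sketch below uses a hypothetical block-structured state sequence rather than the sampled GT_Time2State. It checks that the per-state sample means recover the ground truth features.

```matlab
% Ground truth state means (K = 3 states by V = 3 voxels), as in GT_State2Feature.
State2Feature = [ 1, -2, -2;
                  2,  1,  2;
                 -1,  2,  0];
% Hypothetical state sequence: 100 time points in each state.
Time2State = [ones(100,1); 2*ones(100,1); 3*ones(100,1)];
Time2Mu    = State2Feature(Time2State,:);

% With identity covariance, x_t = mu_t + e_t where e_t ~ N(0, I).
rng(3);                                % arbitrary seed
Time2Voxel = Time2Mu + randn(size(Time2Mu));

% Per-state sample means should lie close to the ground truth means.
K = size(State2Feature,1);
EstMean = zeros(size(State2Feature));
for k = 1:K
    EstMean(k,:) = mean(Time2Voxel(Time2State == k,:), 1);
end
disp(EstMean - State2Feature);
```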

Plot voxel time series and voxel-by-voxel scatter

% Plot the voxel intensity time series as a heat-map with time along the
% vertical axis and voxel along the horizontal axis.
figure;
subplot(1,4,1);
imagesc(Time2Voxel); colormap(hot);
xticks([1,2,3]);
xlabel('Voxel');
ylabel('Timepoint');
title('Voxel activity time series');
% Plot the intensities of each voxel in a 3D scatter plot with one point
% per time step (black dots). This illustrates a voxel-space depicting
% three clusters of observations that correspond to the 3 states. The
% scatter plot also includes large red points that mark the ground truth
% mean intensities for each state.
subplot(1,4,2:4);
hold on;
scatter3(Time2Voxel(:,1),Time2Voxel(:,2),Time2Voxel(:,3),30,'k','filled');
scatter3(...
GT_State2Feature(:,1),GT_State2Feature(:,2),GT_State2Feature(:,3),...
100,'red','filled');
axis square; xlim([-3,3]); ylim([-3,3]); zlim([-3,3]);
view([-40,12]);
title('Voxel-by-voxel scatter plot');

Plot timepoint-by-timepoint correlation & Euclidean distance

figure;
% Here we correlate the voxel intensity patterns of every pair of time
% points. If two time points belong to the same state, their voxel
% patterns should be highly correlated. Because each state is most likely
% to be followed by itself (see the diagonal of GT_TransMat above), time
% points from the same state are contiguous, so this manifests as blocks
% of high correlations along the diagonal of the correlation matrix.
subplot(1,2,1);
hold on;
imagesc(corr(Time2Voxel')); axis square;
colormap('jet');
CAxis01 = colorbar; caxis([-1,1]);
% The following four line() calls draw white lines at the ground truth
% state boundaries, outlining the correlations within each state.
line([0;NTime],[GT_BoundaryIndex(1);GT_BoundaryIndex(1)],...
'Color','w','LineStyle','-','LineW',3);
line([GT_BoundaryIndex(1);GT_BoundaryIndex(1)],[0;NTime],...
'Color','w','LineStyle','-','LineW',3);
line([0;NTime],[GT_BoundaryIndex(2);GT_BoundaryIndex(2)],...
'Color','w','LineStyle','-','LineW',3);
line([GT_BoundaryIndex(2);GT_BoundaryIndex(2)],[0;NTime],...
'Color','w','LineStyle','-','LineW',3);
xlim([0,NTime]);
ylim([0,NTime]);
xlabel('Timepoint');
ylabel('Timepoint');
CAxis01.Label.String = 'Pearson correlation coefficient';
title(sprintf('Timepoint-by-timepoint%ccorrelation matrix',10));
% Following the same logic as above, we compute the pairwise Euclidean
% distances between the voxel intensity patterns of every pair of time
% points. If two time points belong to the same state, the distance
% should be small, manifesting as blocks of low values along the
% diagonal of the matrix.
subplot(1,2,2);
hold on;
imagesc(squareform(pdist(Time2Voxel))); axis square;
colormap('jet');
CAxis01 = colorbar;
% The following four line() calls draw white lines at the ground truth
% state boundaries, outlining the distances within each state.
line([0;NTime],[GT_BoundaryIndex(1);GT_BoundaryIndex(1)],...
'Color','w','LineStyle','-','LineW',3);
line([GT_BoundaryIndex(1);GT_BoundaryIndex(1)],[0;NTime],...
'Color','w','LineStyle','-','LineW',3);
line([0;NTime],[GT_BoundaryIndex(2);GT_BoundaryIndex(2)],...
'Color','w','LineStyle','-','LineW',3);
line([GT_BoundaryIndex(2);GT_BoundaryIndex(2)],[0;NTime],...
'Color','w','LineStyle','-','LineW',3);
xlim([0,NTime]);
ylim([0,NTime]);
xlabel('Timepoint');
ylabel('Timepoint');
CAxis01.Label.String = 'Euclidean distance';
title(sprintf('Timepoint-by-timepoint%cEuclidean distance',10));
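`corr`, `pdist` and `squareform` come from the Statistics and Machine Learning Toolbox. Both matrices can also be built with base MATLAB; the sketch below is not the toolbox implementation. The correlation between two time points is the dot product of their voxel patterns after z-scoring each pattern across voxels, divided by V - 1. Squared Euclidean distances follow from ||a - b||^2 = ||a||^2 + ||b||^2 - 2a·b. The toy data are hypothetical: two tight clusters of time points, so within-cluster distances should be much smaller than between-cluster ones.

```matlab
% Hypothetical toy data: 4 time points (rows) by 3 voxels (columns),
% with time points 1-2 and 3-4 drawn from two different states.
X = [ 1.0, -2.0, -2.0;
      1.1, -1.9, -2.1;
      2.0,  1.0,  2.0;
      1.9,  1.1,  2.1];
[T, V] = size(X);

% Pearson correlation between time points, computed across voxels:
% z-score each row, then R = Z*Z' / (V - 1).
Z = (X - mean(X,2)) ./ std(X, 0, 2);
R = (Z * Z') / (V - 1);

% Pairwise Euclidean distances via the expansion of ||a - b||^2.
sq = sum(X.^2, 2);
D2 = sq + sq' - 2 * (X * X');
D  = sqrt(max(D2, 0));                  % clip tiny negative rounding errors
D(1:T+1:end) = 0;                       % exact zeros on the diagonal

% Mean within-state versus between-state distance.
within  = mean([D(1,2), D(3,4)]);
between = mean(reshape(D(1:2,3:4), [], 1));
fprintf('within = %.3f, between = %.3f\n', within, between);
```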

Set some arbitrary starting parameters

% A is the estimated state-to-state transition matrix of size K by K.
% This is initialised such that there is an equal probability of
% transitioning from each state to itself or to any other state.
A = ones(3) ./ 3;
% M is a K by V+1 matrix that represents the estimated voxel
% representations for each of the K states. The first V columns of M
% represent the mean intensities for each of the voxels. The final
% column represents the variance of intensities around the mean,
% which is assumed to be the same for each of the V voxels.
M = [...
[0.01,0.02,0.03;
0.01,-0.02,0.03;
-0.01,0.02,-0.03],...
ones(3,1)];
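Each row of M defines a Gaussian with mean M(k,1:V) and isotropic covariance M(k,V+1)*eye(V). This is how the B matrix is built inside the loop with `mvnpdf`. The sketch below writes that density out explicitly in base MATLAB for two hypothetical observations and then applies the same row normalisation as the demo. The helper name `isoPdf` is hypothetical.

```matlab
% Initial parameters in the same layout as M: V mean columns + 1 variance column.
M = [ 0.01,  0.02,  0.03, 1;
      0.01, -0.02,  0.03, 1;
     -0.01,  0.02, -0.03, 1];
[K, Vp1] = size(M);
V = Vp1 - 1;

% Hypothetical helper: density of N(mu, s2*I) in V dimensions,
% p(x) = (2*pi*s2)^(-V/2) * exp(-||x - mu||^2 / (2*s2)), one row of X per point.
isoPdf = @(X, mu, s2) (2*pi*s2)^(-V/2) .* exp(-sum((X - mu).^2, 2) ./ (2*s2));

% Hypothetical observations (rows = time points).
X = [0.01,  0.02,  0.03;
     1.00, -2.00, -2.00];

% Likelihood of every observation under every state (T by K).
B = zeros(size(X,1), K);
for k = 1:K
    B(:,k) = isoPdf(X, M(k,1:V), M(k,V+1));
end
% Row normalisation, as in the demo: each row then sums to one.
Bnorm = B ./ sum(B, 2);
disp(Bnorm);
```

Where the toolbox is available, `mvnpdf(X, M(k,1:V), M(k,V+1)*eye(V))` should give the same column as `isoPdf(X, M(k,1:V), M(k,V+1))`.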

Loop through a number of iterations until convergence

Alternate between E- and M-steps. The E-step uses the current parameter estimates to compute, for each observation and state, two quantities. The forward procedure gives a probability based on everything up to and including that observation. The backward procedure gives a probability based on everything after it. The M-step combines these into the probability of being in state k at time t (Prob_Time2State) and the probability of each state-to-state transition. It then re-estimates the K-by-K transition matrix A and the voxel representations stored in M.
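For reference, the recursions in the loop below can be summarised as follows. Here b_k(x_t) stands for B(t,k), α and β for Alpha and Beta, γ for Prob_Time2State and ξ for Prob_Time2JointState, T for NTime, and μ_k for M(k,1:V). The last line is the standard update for a shared isotropic variance. It is included only for completeness, since the loop below keeps the variance column of M fixed.

```latex
\begin{aligned}
\alpha_1(k) &= \pi_k\, b_k(x_1), \qquad \pi = (1, 0, 0) \\
\alpha_t(k) &= b_k(x_t) \sum_{j=1}^{K} \alpha_{t-1}(j)\, A_{jk}, \qquad t = 2, \dots, T \\
\beta_T(k) &= 1, \qquad
\beta_t(k) = \sum_{j=1}^{K} A_{kj}\, b_j(x_{t+1})\, \beta_{t+1}(j) \\
\gamma_t(k) &= \frac{\alpha_t(k)\, \beta_t(k)}{\sum_{j} \alpha_t(j)\, \beta_t(j)} \\
\xi_t(i,j) &= \frac{\alpha_t(i)\, A_{ij}\, b_j(x_{t+1})\, \beta_{t+1}(j)}
  {\sum_{i',j'} \alpha_t(i')\, A_{i'j'}\, b_{j'}(x_{t+1})\, \beta_{t+1}(j')} \\
A_{ij} &\leftarrow \frac{\sum_{t=1}^{T-1} \xi_t(i,j)}{\sum_{t=1}^{T-1} \gamma_t(i)}, \qquad
\mu_k \leftarrow \frac{\sum_{t=1}^{T} \gamma_t(k)\, x_t}{\sum_{t=1}^{T} \gamma_t(k)} \\
\sigma_k^2 &\leftarrow \frac{\sum_{t=1}^{T} \gamma_t(k)\, \lVert x_t - \mu_k \rVert^2}{V \sum_{t=1}^{T} \gamma_t(k)}
\end{aligned}
```

Because the sum over j of ξ_t(i,j) equals γ_t(i), renormalising each row of the summed ξ gives the same A as this formula, whichever per-row denominator is used first.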
% Set the current iteration counter:
Iter = 0;
% Set the maximum number of iterations:
MaxIter = 100;
% Preallocate an array for the Log-likelihood statistics:
LL = nan(MaxIter,1);
% While the iteration counter is less than the maximum number of
% iterations:
while Iter < MaxIter
% Increment the iteration counter:
Iter = Iter + 1;
% Alpha and Beta are NTime-by-K matrices computed from the transition
% matrix A and the state likelihoods in B (which depend on the voxel
% representations in M). Alpha is computed with the forward procedure:
% Alpha(t,k) is proportional to the probability of all observations up
% to and including time t, jointly with being in state k at time t.
% Beta is computed with the backward procedure: Beta(t,k) is
% proportional to the probability of all observations after time t,
% given state k at time t:
Alpha = nan(NTime,3);
Beta = nan(NTime,3);
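% --- Create B matrix:
% This is an NTime-by-K matrix representing the probability that each
% observation was sampled from the k'th state. It is calculated as the
% likelihood of each observation given the parameters for the k'th
% distribution in matrix M, normalised by the sum of the likelihoods
% across the K states. Because each row is divided by a constant shared
% by all states, this normalisation leaves the posterior probabilities
% computed below unchanged; it mainly keeps Alpha and Beta from
% becoming vanishingly small.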
B = [...
mvnpdf(Time2Voxel,M(1,1:3),eye(3).*M(1,4)), ...
mvnpdf(Time2Voxel,M(2,1:3),eye(3).*M(2,4)), ...
mvnpdf(Time2Voxel,M(3,1:3),eye(3).*M(3,4))];
B = B ./ sum(B,2);
% --- E-Step_Forward:
for t = 1:NTime
if t == 1
% In the forward procedure, we assume that the sequence starts in
% state 1, so Alpha at t = 1 is the likelihood of the first
% observation under state 1 and zero for the other states.
Alpha(t,:) = [B(t,1),0,0];
else
for k = 1:3
% Probability of reaching state k at time t, jointly with all
% previous observations: the sum over states j of
% Alpha(t-1,j) * A(j,k). Note, this is a dot product:
p = Alpha(t-1,:) * A(:,k);
% Multiply by the likelihood of observation t under
% state k:
Alpha(t,k) = B(t,k) * p;
end
end
end
% --- E-Step_Backward:
for t = NTime:-1:1
if t == NTime
% In the backward procedure, Beta at the last time point is set
% to one for all states, as there are no later observations.
Beta(t,:) = [1,1,1];
else
for k = 1:3
% For each possible next state j: probability of
% moving from state k to j, times the likelihood of
% observation t+1 under state j:
p = A(k,:) .* B(t+1,:);
% Probability of all subsequent observations given
% state k at time t (the sum over j of p(j) * Beta(t+1,j)).
% Note, this is a dot product:
Beta(t,k) = Beta(t+1,:) * p';
end
end
end
% --- M-Step:
% Recompute the probability of being in each state at each time point:
Prob_Time2State = (Alpha.*Beta) ./ sum(Alpha.*Beta,2);
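% Recompute the joint probability of being in state i at time t and in
% state j at time t+1, then normalise each t-slice to sum to one (done
% once, after the loop, rather than on every pass through it):
Prob_Time2JointState = nan(3,3,NTime-1);
for t = 1:(NTime-1)
for k = 1:3
Prob_Time2JointState(k,:,t) = ...
Alpha(t,k) .* A(k,:) .* Beta(t+1,:) .* B(t+1,:);
end
end
Prob_Time2JointState = ...
Prob_Time2JointState ./ sum(sum(Prob_Time2JointState,1),2);
% Re-estimate the transition matrix A: the expected number of i-to-j
% transitions divided by the expected time spent in state i, with each
% row then renormalised to sum to one:
A = nansum(Prob_Time2JointState,3) ./ ...
repmat(nansum(Prob_Time2State,1)',1,3);
A = A ./ sum(A,2);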
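% Recompute the voxel means in M as posterior-weighted averages of the
% observations (the variance column, M(:,4), is not re-estimated and
% stays at its initial value of 1):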
for k = 1:3
M(k,1:3) = ...
nansum(Prob_Time2State(:,k).*Time2Voxel,1) ./ ...
nansum(Prob_Time2State(:,k));
end
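% --- Calculate an approximate model log-likelihood: at each time point,
% the log of the state densities weighted by the posterior state
% probabilities, summed over time points. This serves as a convergence
% monitor; it is not the exact HMM log-likelihood.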
LogLikelihood = log([...
mvnpdf(Time2Voxel,M(1,1:3),eye(3).*M(1,4)), ...
mvnpdf(Time2Voxel,M(2,1:3),eye(3).*M(2,4)), ...
mvnpdf(Time2Voxel,M(3,1:3),eye(3).*M(3,4))]);
LogLikelihood = LogLikelihood + log(Prob_Time2State);
LogLikelihood = sum(log(sum(exp(LogLikelihood),2)),1);
LL(Iter) = LogLikelihood;
% After more than three iterations, terminate the algorithm if the
% change in log-likelihood is very small:
if Iter > 3
if abs(LL(Iter)-LL(Iter-1)) < 1e-6
break
end
end
end
% Trim log-likelihood vector to remove NaNs:
LL = LL(~isnan(LL));
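The log-likelihood tracked in the loop weights each state's density by its posterior probability. That makes it a convenient convergence monitor, but it is not the exact HMM likelihood p(x_1, ..., x_T). A standard alternative is the scaled forward-backward procedure. It normalises Alpha at every time point and keeps the normalising constants c_t, whose logs sum to the exact log-likelihood. The sketch below uses a hypothetical 3-state, 4-time-point toy problem with a made-up matrix of unnormalised emission likelihoods, unlike the demo's row-normalised B. It checks the result against brute-force enumeration of every state path.

```matlab
% Hypothetical toy HMM: K = 3 states, T = 4 time points.
A   = [0.8, 0.2, 0.0;
       0.0, 0.7, 0.3;
       0.0, 0.0, 1.0];          % transition matrix
pi0 = [1, 0, 0];                % start in state 1, as in the demo
Lik = [0.9, 0.1, 0.2;           % Lik(t,k) = p(x_t | state k), made up
       0.5, 0.4, 0.1;
       0.2, 0.8, 0.3;
       0.1, 0.3, 0.7];
[T, K] = size(Lik);

% Scaled forward pass: each row of Alpha sums to one; c(t) is the scale.
Alpha = zeros(T, K);
c = zeros(T, 1);
a = pi0 .* Lik(1,:);
c(1) = sum(a);  Alpha(1,:) = a / c(1);
for t = 2:T
    a = (Alpha(t-1,:) * A) .* Lik(t,:);
    c(t) = sum(a);  Alpha(t,:) = a / c(t);
end
logLikScaled = sum(log(c));     % exact log p(x_1, ..., x_T)

% Scaled backward pass using the same constants.
Beta = ones(T, K);
for t = T-1:-1:1
    Beta(t,:) = (A * (Lik(t+1,:) .* Beta(t+1,:))')' / c(t+1);
end
Gamma = Alpha .* Beta;          % posterior state probabilities; rows sum to one

% Brute-force check: sum the joint probability over all K^T state paths.
total = 0;
for idx = 0:K^T-1
    states = mod(floor(idx ./ K.^(0:T-1)), K) + 1;   % base-K digits of idx
    pr = pi0(states(1)) * Lik(1, states(1));
    for t = 2:T
        pr = pr * A(states(t-1), states(t)) * Lik(t, states(t));
    end
    total = total + pr;
end
fprintf('scaled: %.6f, brute force: %.6f\n', logLikScaled, log(total));
```

If the demo's row-normalised B were used instead of raw likelihoods, the exact log-likelihood would be sum(log(c)) plus the sum of the logs of the row sums that were divided out.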

Plot LogLikelihood, Time2State & State2Voxel

figure;
% LogLikelihood as a function of iteration:
subplot(4,3,10:12);
plot((1:Iter)',LL,'Color',[0.5,0,1],'LineW',2);
xlabel('Iteration');
ylabel('Log-likelihood');
% Time-State matrix:
subplot(4,3,[1,4,7]);
hold on;
imagesc(Prob_Time2State);
colormap(bone); %colorbar;
% State 1
line([0.5,1.5],[0,0],...
'Color','red','LineStyle','-','LineW',2);
line([0.5,1.5],[GT_BoundaryIndex(1),GT_BoundaryIndex(1)],...
'Color','red','LineStyle','-','LineW',2);
line([0.5,0.5],[0,GT_BoundaryIndex(1)],...
'Color','red','LineStyle','-','LineW',2);
line([1.5,1.5],[0,GT_BoundaryIndex(1)],...
'Color','red','LineStyle','-','LineW',2);
% State 2
line([1.5,2.5],[GT_BoundaryIndex(1),GT_BoundaryIndex(1)],...
'Color','red','LineStyle','-','LineW',2);
line([1.5,2.5],[GT_BoundaryIndex(2),GT_BoundaryIndex(2)],...
'Color','red','LineStyle','-','LineW',2);
line([1.5,1.5],[GT_BoundaryIndex(1),GT_BoundaryIndex(2)],...
'Color','red','LineStyle','-','LineW',2);
line([2.5,2.5],[GT_BoundaryIndex(1),GT_BoundaryIndex(2)],...
'Color','red','LineStyle','-','LineW',2);
% State 3
line([2.5,3.5],[GT_BoundaryIndex(2),GT_BoundaryIndex(2)],...
'Color','red','LineStyle','-','LineW',2);
line([2.5,3.5],[NTime,NTime],...
'Color','red','LineStyle','-','LineW',2);
line([2.5,2.5],[GT_BoundaryIndex(2),NTime],...
'Color','red','LineStyle','-','LineW',2);
line([3.5,3.5],[GT_BoundaryIndex(2),NTime],...
'Color','red','LineStyle','-','LineW',2);
% Display options:
Ax01 = gca;
Ax01.YDir = 'reverse';
xticks([1,2,3]);
xlim([0.5,3.5]);
ylim([0,NTime]);
%xlabel('State number');
ylabel('Time point');
title('Time point to state');
subplot(4,3,[2,3,5,6,8,9]);
hold on;
scatter3(Time2Voxel(:,1),Time2Voxel(:,2),Time2Voxel(:,3),30,'k','filled');
scatter3(GT_State2Feature(:,1),GT_State2Feature(:,2),GT_State2Feature(:,3),...
100,'red','filled');
scatter3(M(1:3,1),M(1:3,2),M(1:3,3),400,[0.4,0.9,0.2],'LineWidth',4);
axis square; xlim([-3,3]); ylim([-3,3]); zlim([-3,3]);
view([-40,12]);
title('Voxel space');
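Prob_Time2State can also be reduced to a single state per time point by taking the most probable state at each time point. This is a per-time-point maximum a posteriori assignment, not the Viterbi path. The estimated boundaries can then be compared with GT_BoundaryIndex. The sketch below uses a small hypothetical posterior matrix and ground truth so that it runs on its own; in the demo one would pass Prob_Time2State and GT_Time2State instead. Because the demo forces the first time point into state 1, estimated state 1 is pinned to the first ground truth state. The remaining labels are not guaranteed to match and may need to be permuted.

```matlab
% Hypothetical posterior state probabilities (T = 8 time points, K = 3 states).
Post = [0.95 0.05 0.00;
        0.90 0.10 0.00;
        0.80 0.20 0.00;
        0.30 0.70 0.00;
        0.05 0.90 0.05;
        0.00 0.40 0.60;
        0.00 0.20 0.80;
        0.00 0.05 0.95];
TrueState = [1; 1; 1; 2; 2; 2; 3; 3];   % hypothetical ground truth

% Most probable state at each time point.
[~, EstState] = max(Post, [], 2);

% Boundaries: index of the last time point before each state change,
% matching the definition of GT_BoundaryIndex.
EstBoundary  = find(diff(EstState));
TrueBoundary = find(diff(TrueState));

Accuracy = mean(EstState == TrueState);
fprintf('Accuracy: %.3f\n', Accuracy);
disp(EstBoundary');
disp(TrueBoundary');
```

In this toy example the second boundary is placed one time point early, which costs one of the eight assignments.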