function [MLE,HST] = HoopStats_RunEM(Responses,Targets,IncUniform,nVonmComps,MuVonmComps,ShowFigure)
% "HoopStats_RunEM"
% R02.00
% By Sam Berens (sam.berens@york.ac.uk).
%
% Estimates a circular mixture model using Expectation-Maximization (EM).
% EM is restarted NRuns times from different starting priors (see
% DefineStart). The iteration with the lowest negative log-likelihood
% (NLL) across all converged runs is returned.
%
% [MLE,HST] = HoopStats_RunEM(...
%   Responses,Targets,IncUniform,nVonmComps,[MuVonmComps],[ShowFigure]);
%
% INPUT:
%   - "Responses" : an nx1 numeric denoting n angles (in radians) of
%                   circular responses;
%   - "Targets" : an nxt numeric denoting n angles (in radians) of t
%                 target distributions. A NaN in Targets(r,t) means that
%                 response r is not mapped to target t. If every row has
%                 exactly one non-NaN entry (sum(~isnan(Targets),2) == 1),
%                 response-to-target mappings are mutually exclusive;
%   - "IncUniform" : a 1x1 logical specifying whether the model should
%                    include a uniform distribution;
%   - "nVonmComps" : a 1x1 numeric denoting the number of non-target von
%                    Mises distributions to be modelled;
%   - "MuVonmComps" : [Optional] a vector of nVonmComps numerics denoting
%                     the preferred direction (in radians) of each
%                     non-target von Mises distribution. NaN entries are
%                     estimated freely; other entries are held fixed.
%                     Default: all NaNs (every mean is estimated);
%   - "ShowFigure" : [Optional] a 1x1 numeric denoting the figure
%                    preference. If ShowFigure == 0 (default), no figures
%                    will be drawn. If ShowFigure == 1, a figure of the
%                    final EM step will be drawn. If ShowFigure == 2,
%                    figures will be drawn to show each EM step;
%
% OUTPUT:
%   - "MLE" : The best fitting model. This is the iteration-history entry
%             with the lowest NLL across all converged runs. It is extended
%             with the fields BIC_H0 (uniform-only null model), BIC_H1
%             (fitted model) and deltaBIC (= BIC_H1 - BIC_H0). Empty ([]) if
%             no run converged;
%   - "HST" : Iteration history of the best run, from the starting state
%             up to and including the MLE. Empty ([]) if no run converged.

%% Define global variables:
% These are shared with the local helper functions of this file.
global Responses_A Responses_X Responses_Y;
global ScatterPointRadius LineThickness;
global NRuns;
global EstimatePriors PriorBreakIn;
global NoVonmC;

%% Handle optional inputs:
if nargin < 5 || isempty(MuVonmComps)
    % NaN marks a freely estimated mean, so by default all are free:
    MuVonmComps = nan(nVonmComps,1);
end
if nargin < 6 || isempty(ShowFigure)
    ShowFigure = 0;
end
% Store as a column vector so that row-vector input is accepted too:
MuVonmComps = MuVonmComps(:);
if numel(MuVonmComps) ~= nVonmComps
    error('HoopStats_RunEM:BadMuVonmComps',...
        'MuVonmComps must contain exactly nVonmComps (%d) elements.',...
        nVonmComps);
end

%% Initial set-up

% Number of target distributions:
nTargets = size(Targets,2);

% Number of free parameters (used in the BIC computation below):
%   - each target: K and Prior (its Mu is fixed at 0)     -> 2 each;
%   - the uniform (if included): Prior                    -> 1;
%   - each non-target von Mises: Mu, K and Prior          -> 3 each,
%     minus one for every Mu held fixed by MuVonmComps;
%   - minus 1 because the priors are constrained to sum to one.
NfreeParams = 2*nTargets + double(IncUniform) ...
    + 3*nVonmComps - sum(~isnan(MuVonmComps)) - 1;

% PlotLastEmStep draws responses differently when there are no non-target
% von Mises components:
NoVonmC = (nVonmComps == 0);

% Responses_A is read by DefineStart (for the number of responses) and by
% PlotLastEmStep; Responses_X/Y are only used for plotting:
Responses_A = exp(wrapToPi(Responses - nansum(Targets,2)).*1i);
Responses_X = real(exp(Responses.*1i));
Responses_Y = imag(exp(Responses.*1i));

%% Estimation options:

% Number of EM starting positions:
NRuns = 17;

% Bool to specify whether we are estimating the priors or whether they
% remain fixed at equal values:
EstimatePriors = true;

% How much the parameters Mu and K^-1 need to converge by before we start
% estimating the priors:
PriorBreakIn = Inf; % Inf means that Priors are estimated immediately.

% Maximum number of iterations:
MaxIter = 10^3;

% Estimation tolerance for change in negative log-likelihood.
Tol_DeltaNLL = -1e-3; % This has to be negative!...
% Because we want to ensure that the model fit is getting better.

%% Plot appearance:
ScreenSize = get(0,'ScreenSize');
% Position is [left bottom width height]; leave room for the task bar:
FigureSize = [1,ScreenSize(2)+40,ScreenSize(3),ScreenSize(4)-120];
ScatterPointRadius = 30;
LineThickness  = 3;

%% Loop through different EM starting positions
UniqueRuns = struct;
for iRun = 1:1:NRuns
    % Handle of this run's per-step figure. It is reset on every run so
    % that a failing run never closes a figure that it did not open (or
    % that a previous run already closed):
    FigureH = [];
    try % Try running the EM NRun number of times.
        
        % Specify starting parameters in ParamStruct:
        ParamStruct = DefineStart(iRun,nTargets,IncUniform,...
            nVonmComps,MuVonmComps);
        
        %% Copy the starting parameters into CurrentState:
        % Remember that the first entry of 'ParamStruct' (i = 1) contains
        % model starting parameters. As such, the data for iteration 'a' is
        % contained in index i = a + 1.
        CurrentState = struct;
        
        % For each target, copy starting values for Prior and K:
        for iTarget = 1:1:nTargets
            TargetId = sprintf('T%i',iTarget-1);
            CurrentState.(TargetId).Prior = ...
                ParamStruct(1,1).(TargetId).Prior;
            CurrentState.(TargetId).K = ParamStruct(1,1).(TargetId).K;
        end
        
        % If the model includes a uniform distribution, set its prior and
        % the observation likelihoods in Pdents. These likelihoods never
        % change during estimation (fixed at 1/(2pi) per observation):
        if IncUniform
            CurrentState.C0.Prior = ParamStruct(1,1).C0.Prior;
            CurrentState.C0.Pdents = ones(size(Responses,1),1).*(1/(2*pi));
        end
        
        % For each component, copy starting values for Prior, Mu and K:
        for iVonmComp = 1:1:nVonmComps
            ComponentId = sprintf('C%i',iVonmComp);
            CurrentState.(ComponentId).Prior = ...
                ParamStruct(1,1).(ComponentId).Prior;
            CurrentState.(ComponentId).Mu = ...
                ParamStruct(1,1).(ComponentId).Mu;
            CurrentState.(ComponentId).K = ...
                ParamStruct(1,1).(ComponentId).K;
        end
        
        %% Get the numbers of the components being fitted:
        % Component fields of ParamStruct are named 'C<number>' (C0 is the
        % uniform):
        FieldNames = fieldnames(ParamStruct);
        Selector_C = cellfun(@(s) strcmp(s(1,1),'C'),FieldNames);
        ComponentNames = FieldNames(Selector_C);
        
        % Parse every character after the leading 'C' so that components
        % C10, C11, ... are identified correctly:
        ComponentNums = cellfun(@(s) str2double(s(2:end)),ComponentNames);
        
        % Make ComponentNums a ROW vector: a MATLAB for-loop iterates over
        % columns, and an empty 1x0 vector gives zero iterations:
        ComponentNums = reshape(ComponentNums,1,[]);
        
        %% Draw an initial figure:
        % If the user has selected to draw each EM step, show the starting
        % parameters:
        if ShowFigure == 2
            FigureH = figure;
            set(FigureH,'Position',FigureSize);
            set(FigureH,'color','white');
            PlotLastEmStep(ParamStruct,FigureH);
            pause(1); % Pause for a second.
        end
        
        %% Start EM:
        iIter = 0;
        Finished = false;
        Converged = false;
        while ~Finished
            
            % Increment iteration count:
            iIter = iIter + 1;
            
            % E-Step (first iteration only). On later iterations the E-step
            % that follows the previous M-step is reused:
            if iIter == 1
                CurrentState = EStep(CurrentState,...
                    Responses,Targets,ComponentNums,nVonmComps);
            end
            
            % M-Step:
            CurrentState = MStep(CurrentState,...
                Responses,Targets,ComponentNums,MuVonmComps);
            
            % E-Step (refreshes likelihoods/posteriors used by the NLL):
            CurrentState = EStep(CurrentState,...
                Responses,Targets,ComponentNums,nVonmComps);
            
            % Negative log-likelihood (NLL), i.e. the model fit:
            NLL = CalculateNLL(CurrentState,Responses,Targets,ComponentNums);
            
            % How much each free parameter changed in this EM iteration:
            [MaxDeltaParam_IncPriors,... Maximum change in all parameters.
                MaxDeltaParam_ExcPriors] = ... Change in Mu and K^-1 alone.
                GetMaxDeltaParam(CurrentState,...
                ParamStruct,Targets,ComponentNums);
            
            % Record this iteration:
            ParamStruct = UpdateParamStruct(...
                ParamStruct,iIter,CurrentState,...
                NLL,MaxDeltaParam_IncPriors,MaxDeltaParam_ExcPriors,...
                Targets,ComponentNums);
            
            % Plot EM iteration if requested to do so:
            if ShowFigure == 2
                PlotLastEmStep(ParamStruct,FigureH)
            end
            
            %% Termination
            % DeltaNLL is negative when the fit improved and positive when
            % it got worse. On iIter == 1 it is NaN, because the starting
            % state has NLL = NaN, so neither test below fires.
            DeltaNLL = ...
                ParamStruct(iIter+1,1).NLL - ParamStruct(iIter,1).NLL;
            
            if DeltaNLL >= Tol_DeltaNLL
                % The fit no longer improves by more than the tolerance
                % (or it got worse): stop, and count the run as converged.
                % Any worsening step is excluded later because the best
                % iteration is picked by its minimum NLL.
                Finished = true;
                Converged = true;
            elseif iIter >= MaxIter
                % The iteration limit was reached while the fit was still
                % improving: stop, but do NOT mark the run as converged.
                Finished = true;
            end
            
        end % End the While finished loop.
        
        %% Record the results of the last EM starting position:
        UniqueRuns(iRun,1).ParamStruct = ParamStruct;
        [UniqueRuns(iRun,1).NLL_Min,...
            UniqueRuns(iRun,1).NLL_MinIndex] = ...
            nanmin([ParamStruct.NLL],[],2);
        UniqueRuns(iRun,1).Converged = Converged;
        
        %% Close any remaining figures:
        if ShowFigure == 2 && ~isempty(FigureH) && ishandle(FigureH)
            close(FigureH);
        end
        
    catch ME
        %% Handle crashes in the EM procedure.
        % Record empty results for this run so that it is never selected.
        UniqueRuns(iRun,1).ParamStruct = [];
        UniqueRuns(iRun,1).NLL_Min = NaN;
        UniqueRuns(iRun,1).NLL_MinIndex = NaN;
        UniqueRuns(iRun,1).Converged = false;
        
        % Play a warning sound and report why the run failed:
        Notification_A02;
        warning('HoopStats_RunEM:RunFailed',...
            'EM run %d crashed and was discarded: %s (%04d%02d%02d-%02d%02d%02d)',...
            iRun, ME.message, fix(clock));
        
        % Close this run's figure, but only if it was actually opened:
        if ShowFigure == 2 && ~isempty(FigureH) && ishandle(FigureH)
            close(FigureH);
        end
        
    end % End the try statement.
    
end % End the iRun loop

%% Deal with failures to converge:
% Failures to converge are not always accompanied by crashes in the EM
% algorithm. Make sure that the NLL statistics of every run that did not
% converge are NaNs so that it cannot be selected below.
for iRun = 1:1:NRuns
    if ~UniqueRuns(iRun,1).Converged
        UniqueRuns(iRun,1).NLL_Min = NaN;
        UniqueRuns(iRun,1).NLL_MinIndex = NaN;
    end
end

%% Select the best EM run and the best iteration within it:
% Run with the lowest NLL (if all are NaN, index 1 is returned and its
% NLL_MinIndex is NaN):
[~,BestStart_Index] = nanmin([UniqueRuns.NLL_Min],[],2);
BestIte_Index = UniqueRuns(BestStart_Index,1).NLL_MinIndex;

% A non-NaN BestIte_Index means at least one run converged:
Success = ~isnan(BestIte_Index);

if Success
    % HST runs from the starting state up to the iteration with the lowest
    % NLL. Truncating at BestIte_Index drops any final step in which the
    % NLL increased (e.g. due to runaway singularities):
    HST = UniqueRuns(BestStart_Index,1).ParamStruct(1:BestIte_Index,1);
    
    % The maximum-likelihood estimate is the last entry of HST:
    MLE = HST(end,1);
    
    %% Compute BIC statistics:
    nResponses = size(Responses,1);
    
    % Null model: a uniform distribution alone. It has no free parameters,
    % so its BIC equals two times its NLL:
    BIC_H0 = 0 - (2 * log(1/(2*pi)) * nResponses);
    
    % BIC for the fitted model:
    BIC_H1 = (log(nResponses)*NfreeParams) + 2*MLE.NLL;
    
    MLE.BIC_H0 = BIC_H0;
    MLE.BIC_H1 = BIC_H1;
    MLE.deltaBIC = BIC_H1 - BIC_H0;
else
    % No run converged:
    HST = [];
    MLE = [];
end

%% If the ShowFigure option is greater than zero, show the best EM run:
if Success && (ShowFigure > 0)
    FigureH = figure;
    set(FigureH,'Position',FigureSize);
    set(FigureH,'color','white');
    PlotLastEmStep(HST,FigureH);
    Notification_A01;
end

return

function [ParamStruct] = DefineStart(RunIndex,nTargets,IncUniform,nVonmComps,MuVonmComps)
% Build the starting state (ParamStruct(1,1), Iteration 0) for EM run
% number RunIndex.
%
% Reads the globals Responses_A (its row count is used as the number of
% responses) and NRuns.
%
% The total starting prior handed to the targets depends on RunIndex:
%   - run 1   : nTargets/nResponses (small, but never zero);
%   - run r>1 : r/(NRuns+1), i.e. growing linearly across runs.
% The targets share that mass equally. The remainder is shared equally
% between the uniform (C0, if IncUniform) and the non-target von Mises
% components C1..C<nVonmComps>.
%
% Free component means (NaN in MuVonmComps) start at sorted random angles.
% All von Mises distributions start with K = 2, and all posterior weights
% and goodness-of-fit fields start as NaN.

global Responses_A NRuns;

nResponses = size(Responses_A,1);

% Prior mass given to the targets for this run:
if RunIndex ~= 1
    TargetMass = RunIndex / (NRuns+1);
else
    TargetMass = nTargets / nResponses;
end

ParamStruct = struct('Iteration',0);

% Targets: mean fixed at zero, K starts at 2, equal share of TargetMass.
for t = 1:nTargets
    ParamStruct(1,1).(sprintf('T%i',t-1)) = struct(...
        'DistributionType','vonMises',...
        'Mu',0,...
        'K',2,...
        'Prior',TargetMass / nTargets);
end

% Uniform component: one equal share of what the targets left over.
UsedMass = TargetMass;
if IncUniform
    UniformPrior = (1-TargetMass)/(1+nVonmComps);
    ParamStruct(1,1).C0 = struct(...
        'DistributionType','Uniform',...
        'Prior',UniformPrior);
    UsedMass = TargetMass + UniformPrior;
end

% Non-target von Mises components: free means start at random angles,
% fixed means are taken from MuVonmComps.
RandomMu = sort(wrapToPi(rand(nVonmComps,1)*2*pi));
for c = 1:nVonmComps
    if isnan(MuVonmComps(c))
        ThisMu = RandomMu(c);
    else
        ThisMu = MuVonmComps(c);
    end
    ParamStruct(1,1).(sprintf('C%i',c)) = struct(...
        'DistributionType','vonMises',...
        'Mu',ThisMu,...
        'K',2,...
        'Prior',(1-UsedMass) / nVonmComps);
end

% Posterior weights: one NaN column per distribution.
BlankWeights = nan(nResponses,1);
for t = 1:nTargets
    ParamStruct(1,1).(sprintf('W_T%i',t-1)) = BlankWeights;
end
if IncUniform
    ParamStruct(1,1).W_C0 = BlankWeights;
end
for c = 1:nVonmComps
    ParamStruct(1,1).(sprintf('W_C%i',c)) = BlankWeights;
end

% Goodness-of-fit placeholders:
ParamStruct(1,1).MaxDeltaParam = NaN;
ParamStruct(1,1).NLL = NaN;

return

function [CurrentState] = EStep(CurrentState,Responses,Targets,ComponentNums,nVonmComps)
% E-step: refresh each distribution's likelihoods ("Pdents") under the
% current parameters, then turn them into posterior weights ("W").
%
% Target t only considers responses with a non-NaN Targets(:,t). Its
% Pdents hold one value per such response, and its W is zero for every
% other response. The uniform's (C0) Pdents are not recomputed here.
% nVonmComps selects which components (C1..) get new Pdents.
% ComponentNums (a row vector, 0 = uniform) lists every component that
% enters the normaliser.
%
% NOTE(review): the first output of HoopStats_VonmFit is assumed to be
% the per-observation von Mises density - confirm.

nTargets = size(Targets,2);

% HasTarget(r,t) is true when response r is mapped to target t:
HasTarget = ~isnan(Targets);

% Likelihood of each mapped response under each target, using its
% accuracy relative to the target (mean fixed at zero):
for t = 1:nTargets
    Name = sprintf('T%i',t-1);
    Rows = HasTarget(:,t);
    Err = wrapToPi(Responses(Rows) - Targets(Rows,t));
    [CurrentState.(Name).Pdents,~] = ...
        HoopStats_VonmFit(0,CurrentState.(Name).K,Err);
end

% Likelihood of every response under each non-target von Mises:
for c = 1:nVonmComps
    Name = sprintf('C%i',c);
    [CurrentState.(Name).Pdents,~] = HoopStats_VonmFit(...
        CurrentState.(Name).Mu,CurrentState.(Name).K,Responses);
end

% Normaliser: sum of prior-weighted likelihoods for each response
% (targets first, then components in ComponentNums order):
Evidence = zeros(size(Responses));
for t = 1:nTargets
    Name = sprintf('T%i',t-1);
    Contribution = zeros(size(Responses));
    Contribution(HasTarget(:,t)) = ...
        CurrentState.(Name).Prior .* CurrentState.(Name).Pdents;
    Evidence = Evidence + Contribution;
end
for c = ComponentNums
    Name = sprintf('C%i',c);
    Evidence = Evidence + ...
        CurrentState.(Name).Prior .* CurrentState.(Name).Pdents;
end

% Posterior weights = prior-weighted likelihood / normaliser:
for t = 1:nTargets
    Name = sprintf('T%i',t-1);
    Rows = HasTarget(:,t);
    Posterior = zeros(size(Responses));
    Posterior(Rows) = ...
        (CurrentState.(Name).Prior .* CurrentState.(Name).Pdents) ./ ...
        Evidence(Rows);
    CurrentState.(Name).W = Posterior;
end
for c = ComponentNums
    Name = sprintf('C%i',c);
    CurrentState.(Name).W = ...
        (CurrentState.(Name).Prior .* CurrentState.(Name).Pdents) ./ ...
        Evidence;
end
return

function [CurrentState] = MStep(CurrentState,Responses,Targets,ComponentNums,MuVonmComps)
% M-step: re-estimate every distribution's prior and shape parameters
% from the posterior weights ("W") of the last E-step.
%
% INPUT:
%   - CurrentState  : model state. Reads .W; writes .Prior and .K, and .Mu
%                     for components whose mean is free;
%   - Responses     : response angles (radians);
%   - Targets       : target angles (radians); NaN = response not mapped;
%   - ComponentNums : row vector of component numbers (0 = uniform);
%   - MuVonmComps   : fixed component means, NaN = freely estimated.
%                     Indexed linearly, so row or column vectors work.
% OUTPUT:
%   - CurrentState  : the updated model state.

%% *** Re-estimate the parameters of target distributions. ***
for iTarget = 1:1:size(Targets,2)
    
    % Target identifier and the rows mapped to this target:
    TargetId = sprintf('T%i',iTarget-1);
    Selector = ~isnan(Targets(:,iTarget));
    W = CurrentState.(TargetId).W(Selector);
    
    % Prior = mean posterior weight over the responses mapped to it:
    CurrentState.(TargetId).Prior = mean(W,1);
    
    % Accuracy of each response relative to its target:
    Acc = wrapToPi(Responses(Selector) - Targets(Selector,iTarget));
    
    % Posterior-weighted mean resultant vector of the accuracies on the
    % complex plane:
    WeightedVector = sum(exp(Acc.*1i) .* W) ./ sum(W);
    
    % Convert to a concentration parameter (K). Only the real part is
    % used because the target mean is fixed at zero, so only the
    % projection onto 0 rad is relevant. This projection can be negative,
    % so negative K is clamped to zero.
    K = HoopStats_R2K(real(WeightedVector));
    if K < 0
        K = 0;
    end
    
    % Correct estimates of K to minimise bias when p is small (see Best &
    % Fisher, 1981):
    CurrentState.(TargetId).K = HoopStats_KCorrection(numel(Acc),...
        CurrentState.(TargetId).Prior, K);
end

%% *** Re-estimate the parameters of non-target components ***
% Includes the uniform (C0) if present; only its prior is estimated.
for nComp = ComponentNums
    
    ComponentId = sprintf('C%i',nComp);
    
    % Prior = mean posterior weight over all responses:
    CurrentState.(ComponentId).Prior = ...
        mean(CurrentState.(ComponentId).W,1);
    
    if nComp ~= 0
        % Estimate Mu and K from the posterior-weighted responses. Mu is
        % only kept when it is free to vary. Linear indexing accepts
        % MuVonmComps as a row or a column vector.
        [MuHat,KHat] = HoopStats_CalcParams(Responses,...
            CurrentState.(ComponentId).W);
        if isnan(MuVonmComps(nComp))
            CurrentState.(ComponentId).Mu = MuHat;
        end
        CurrentState.(ComponentId).K = KHat;
    end
end
return

function [NLL] = CalculateNLL(CurrentState,Responses,Targets,ComponentNums)
% Negative log-likelihood (NLL) of the responses under the current mixture.
%
% For each response r the mixture density is
%   p(r) = sum_d Prior_d * Pdents_d(r),
% where d runs over:
%   - the targets that r is mapped to (non-NaN Targets(r,t));
%   - every component in ComponentNums, including the uniform (C0).
% NLL = -sum_r log p(r). Smaller values indicate a better fit.
%
% Prior weights (not the E-step posteriors "W") are used. sum_d W_d*Pdents_d
% is not a likelihood, whereas sum_d Prior_d*Pdents_d is the observed-data
% likelihood that EM is guaranteed not to worsen. It is also the
% quantity used when comparing against the uniform-only BIC.

nTargets = size(Targets,2);

% Prior-weighted likelihoods: one row per response, one column per
% distribution. NaN marks response/target pairs that do not apply.
Lprio = nan(size(Responses,1),nTargets+numel(ComponentNums));

% Target distributions (only the rows mapped to each target):
for iTarget = 1:1:nTargets
    TargetId = sprintf('T%i',iTarget-1);
    Selector = ~isnan(Targets(:,iTarget));
    Lprio(Selector,iTarget) = CurrentState.(TargetId).Pdents ...
        .* CurrentState.(TargetId).Prior;
end

% Component distributions (including the uniform), in the columns after
% the targets:
iDist = nTargets;
for nComp = ComponentNums
    iDist = iDist + 1;
    ComponentId = sprintf('C%i',nComp);
    Lprio(:,iDist) = CurrentState.(ComponentId).Pdents ...
        .* CurrentState.(ComponentId).Prior;
end

% Sum over distributions (ignoring non-applicable NaNs), take logs, sum
% over responses and negate:
NLL = -1 * sum(log(nansum(Lprio,2)),1);
return

function [MaxDeltaParam_IncPriors,MaxDeltaParam_ExcPriors] = GetMaxDeltaParam(CurrentState,ParamStruct,Targets,ComponentNums)
% Find how much each free parameter changed between the previous recorded
% state (ParamStruct(end,1)) and CurrentState.
%
% OUTPUT:
%   - MaxDeltaParam_IncPriors : largest absolute change over Mu, K^-1 and
%                               the priors;
%   - MaxDeltaParam_ExcPriors : largest absolute change over Mu and K^-1
%                               only.
% NaN/Inf changes (e.g. 1/K with K = 0) are ignored. Mu changes are
% measured as circular distances.

% One row per target/component. The columns of "DeltaParams" are:
%   1) Mu (components only; target means are fixed);
%   2) K^-1;
%   3) Prior.
iEntry = 0;
DeltaParams = nan(size(Targets,2)+numel(ComponentNums),3);

% Target distributions: K^-1 and Prior.
for iTarget = 1:1:size(Targets,2)
    iEntry = iEntry + 1;
    TargetId = sprintf('T%i',iTarget-1);
    DeltaParams(iEntry,3) = abs(CurrentState.(TargetId).Prior - ...
        ParamStruct(end,1).(TargetId).Prior);
    DeltaParams(iEntry,2) = abs((1/CurrentState.(TargetId).K) - ...
        (1/ParamStruct(end,1).(TargetId).K));
end

% Components: Mu and K^-1 (von Mises only) plus the Prior.
for nComp = ComponentNums
    iEntry = iEntry + 1;
    ComponentId = sprintf('C%i',nComp);
    
    if nComp ~= 0
        % Mu is an angle. Wrap the difference into (-pi,pi] so that a mean
        % moving across the +/-pi boundary is not reported as a ~2*pi jump:
        DeltaMu = CurrentState.(ComponentId).Mu - ...
            ParamStruct(end,1).(ComponentId).Mu;
        DeltaParams(iEntry,1) = abs(angle(exp(1i*DeltaMu)));
        DeltaParams(iEntry,2) = abs((1/CurrentState.(ComponentId).K) - ...
            (1/ParamStruct(end,1).(ComponentId).K));
    end
    
    DeltaParams(iEntry,3) = abs(CurrentState.(ComponentId).Prior - ...
        ParamStruct(end,1).(ComponentId).Prior);
end

% Replace Infs with NaN (generated when K = 0):
DeltaParams(~isfinite(DeltaParams)) = NaN;

% max ignores NaN values by default (returns NaN only if all are NaN):
MaxDeltaParam_IncPriors = max(DeltaParams(:));

ExcPriors = DeltaParams(:,1:2);
MaxDeltaParam_ExcPriors = max(ExcPriors(:));
return

function [ParamStruct] = UpdateParamStruct(ParamStruct,iIter,CurrentState,NLL,MaxDeltaParam_IncPriors,MaxDeltaParam_ExcPriors,Targets,ComponentNums)
% Append the model state after EM iteration iIter as ParamStruct(iIter+1,1).
%
% What is stored:
%   - targets: Mu (always 0), K and posterior weights W_T*;
%   - components: Mu and K (von Mises only) and posterior weights W_C*;
%   - MaxDeltaParam (= MaxDeltaParam_IncPriors) and the NLL.
% The priors from CurrentState are stored only when the global
% EstimatePriors is true and MaxDeltaParam_ExcPriors < global PriorBreakIn.
% Otherwise the previous iteration's priors are copied forward.
%
% NOTE(review): this only changes the recorded priors. The priors held in
% CurrentState are not modified here (CurrentState is not returned), so
% confirm that callers do not expect fixed priors to be enforced.
global EstimatePriors PriorBreakIn;

Row = iIter + 1;
ParamStruct(Row,1).Iteration = iIter;

% Targets:
for t = 1:size(Targets,2)
    Name = sprintf('T%i',t-1);
    ParamStruct(Row,1).(Name).Mu = 0; % Target means are fixed at zero.
    ParamStruct(Row,1).(Name).K = CurrentState.(Name).K;
    ParamStruct(Row,1).(['W_',Name]) = CurrentState.(Name).W;
    if EstimatePriors && (MaxDeltaParam_ExcPriors < PriorBreakIn)
        NewPrior = CurrentState.(Name).Prior;
    else
        NewPrior = ParamStruct(iIter,1).(Name).Prior;
    end
    ParamStruct(Row,1).(Name).Prior = NewPrior;
end

% Components (C0 = uniform, which has no Mu/K):
for c = ComponentNums
    Name = sprintf('C%i',c);
    if c ~= 0
        % A fixed Mu was left unchanged by the M-step, so copying is safe:
        ParamStruct(Row,1).(Name).Mu = CurrentState.(Name).Mu;
        ParamStruct(Row,1).(Name).K = CurrentState.(Name).K;
    end
    ParamStruct(Row,1).(['W_',Name]) = CurrentState.(Name).W;
    if EstimatePriors && (MaxDeltaParam_ExcPriors < PriorBreakIn)
        NewPrior = CurrentState.(Name).Prior;
    else
        NewPrior = ParamStruct(iIter,1).(Name).Prior;
    end
    ParamStruct(Row,1).(Name).Prior = NewPrior;
end

% Goodness-of-fit bookkeeping:
ParamStruct(Row,1).MaxDeltaParam = MaxDeltaParam_IncPriors;
ParamStruct(Row,1).NLL = NLL;

return

function [] = PlotLastEmStep(ParamStruct,FigureH)
% Draw a summary of the most recent EM step, ParamStruct(end,1), on the
% figure FigureH. The figure is cleared and split into three panels:
%   - subplot(2,2,[1,3]) : the responses scattered around a unit circle,
%     coloured by their posterior weights, with the resultant vector of
%     each von Mises component (and of each target when there are no von
%     Mises components) and a circle of radius = prior for the uniform;
%   - subplot(2,2,2) : a bar graph of each distribution's information
%     content, annotated with its parameter estimates;
%   - subplot(2,2,4) : the NLL and log10(MaxDeltaParam) of every step.
%
% INPUT:
%   - "ParamStruct" : an ix1 struct array with one element per EM step.
%                     Fields starting with "T" are targets and fields
%                     starting with "C" are components (C0 is the
%                     uniform). For each such field <Name> there is a
%                     field "W_<Name>" holding that distribution's
%                     posterior weights. Each element also carries the
%                     fields "Iteration", "NLL" and "MaxDeltaParam";
%   - "FigureH" : the handle of the figure to draw on.

% Import some global variables:
global Responses_A Responses_X Responses_Y
global ScatterPointRadius LineThickness
global NoVonmC;

%% Open figure and set focus on scatter circle:
set(0,'CurrentFigure',FigureH);
clf;
subplot(2,2,[1,3]);
DrawCircle;
hold on;

%% Identify the distributions (targets and components):
% DistNames lists every field of ParamStruct whose name starts with "T"
% (target) or "C" (component), in field order. Its index, iDist, is shared
% by every per-distribution quantity below (weights, colours, bars and
% tick labels).
Fields = fieldnames(ParamStruct);
DistNames = cell(0,1);
for iFields = 1:1:size(Fields,1)
    if any(Fields{iFields,1}(1) == 'TC')
        DistNames{end+1,1} = Fields{iFields,1}; %#ok<AGROW>
    end
end
nDists = size(DistNames,1);

%% Get PointColours and Stereotypes:
% W is an r*d matrix of posterior weights encoding the probability that
% each response r belongs to each distribution d. Its columns are read from
% the W_<Name> fields of the last EM iteration in DistNames order, so that
% column d of W and row d of Stereotypes_RGB refer to the same
% distribution.
W = [];
for iDist = 1:1:nDists
    W = [W,ParamStruct(end,1).(['W_',DistNames{iDist,1}])]; %#ok<AGROW>
end

% PointColours_RGB is an r by 3 matrix of RGB triplets, one per response.
% Stereotypes_RGB is a d by 3 matrix holding the pure colour of each
% distribution. A response only takes on a pure colour if its posterior
% probability for that distribution is 1.
[PointColours_RGB,Stereotypes_RGB] = MakeColours(W);

%% Scatter responses:
if ~NoVonmC
    % With at least one von Mises component, each response is drawn at its
    % world-centred location around the circle:
    scatter(Responses_X,Responses_Y,ScatterPointRadius,PointColours_RGB,...
        'filled','MarkerEdgeColor',[0,0,0],'LineWidth',1);
    
    % Label angles:
    text(0,1.1,'\pi/2'); % +pi/2 at the top.
    text(-1.15,0,'\pi'); % pi to the left.
    text(0,-1.1,'-\pi/2'); % -pi/2 at the bottom.
    text(1.1,0,'0'); % Zero to the right.
else
    % With no von Mises components, each response is drawn at a location
    % representing its angular error relative to the target (i.e.
    % target-centred). Multiplying Responses_A by 1i rotates every point by
    % +pi/2 so that a zero error sits at the top of the circle.
    % NOTE(review): this presumes Responses_A holds complex, unit-circle
    % positions; if it held real angles, real(Responses_A.*1i) would be 0.
    Responses_AX = real((Responses_A.*1i));
    Responses_AY = imag((Responses_A.*1i));
    scatter(Responses_AX,Responses_AY,ScatterPointRadius,...
        PointColours_RGB,'filled','MarkerEdgeColor',[0,0,0],'LineWidth',1);
    
    % Label angles:
    text(0,1.1,'0'); % Zero at the top.
    text(-1.15,0,'\pi/2'); % +pi/2 to the left.
    text(0,-1.1,'\pi'); % pi at the bottom.
    text(1.1,0,'-\pi/2'); % -pi/2 to the right.
end

% Adjust the appearance of the scatter plot:
axis equal;
axis off;
text(0.5,1,sprintf('Iteration: %i;',ParamStruct(end,1).Iteration),...
    'FontSize',10);

%% Define ParamsToDisplay & InfoContent AND plot mean vectors:
% ParamsToDisplay is a d by 3 matrix of [Mu, K, Prior] and InfoContent a d
% by 1 vector. Both start as NaN so that a parameter a distribution does
% not have (Mu for targets; Mu and K for the uniform) stays NaN. Growing an
% empty matrix element by element would instead zero-fill the unset
% columns, making targets look like von Mises components with Mu = 0.
ParamsToDisplay = nan(nDists,3);
InfoContent = nan(nDists,1);

for iDist = 1:1:nDists
    DistName = DistNames{iDist,1};
    Dist = ParamStruct(end,1).(DistName);
    
    % Every distribution has a prior; read it before it is stored or used:
    Prior = Dist.Prior;
    ParamsToDisplay(iDist,3) = Prior;
    
    if DistName(1) == 'T'
        % Target distribution: only K (and the prior) are reported.
        K = Dist.K;
        ParamsToDisplay(iDist,2) = K;
        InfoContent(iDist,1) = HoopStats_InfoContent(Prior,K);
        
        % If there are no von Mises components in the model, draw the
        % target's resultant vector, fixed at a 0 angle (top of circle):
        if NoVonmC
            plot([0;0],[0;HoopStats_K2R(K)],...
                'Color',Stereotypes_RGB(iDist,:),...
                'LineStyle','-',...
                'LineWidth',LineThickness);
        end
        
    elseif str2double(DistName(2:end)) == 0
        % Uniform component (C0): no Mu or K.
        InfoContent(iDist,1) = Prior * log(2*pi);
        
        % Draw a circle on the scatter plot whose radius is the prior of
        % the uniform distribution:
        Thetas = linspace(0,2*pi);
        plot(cos(Thetas).*Prior,sin(Thetas).*Prior,...
            'Color',Stereotypes_RGB(iDist,:),...
            'LineWidth',LineThickness./2);
        
    else
        % Non-target von Mises component (number parsed from all digits,
        % so C10 is not mistaken for C1):
        Mu = Dist.Mu;
        K = Dist.K;
        ParamsToDisplay(iDist,1) = Mu;
        ParamsToDisplay(iDist,2) = K;
        InfoContent(iDist,1) = HoopStats_InfoContent(Prior,K);
        
        % Draw the resultant vector: direction Mu, length HoopStats_K2R(K).
        Vector = exp(Mu*1i) * HoopStats_K2R(K);
        plot([0;real(Vector)],[0;imag(Vector)],...
            'Color',Stereotypes_RGB(iDist,:),...
            'LineStyle','-',...
            'LineWidth',LineThickness);
    end
end

%% Draw a bar graph of information content statistics:
subplot(2,2,2);
for iDist = 1:1:nDists
    % Draw one bar per distribution, in its stereotype colour:
    BarHandle = bar(iDist,InfoContent(iDist,1));
    if iDist == 1
        hold on; % Hold on once the first bar exists.
    end
    set(BarHandle,'FaceColor',Stereotypes_RGB(iDist,:));
end

% Annotate each bar with its parameter estimates. NaN entries mark
% parameters that the distribution does not have:
for iDist = 1:1:nDists
    Mu = ParamsToDisplay(iDist,1);
    K = ParamsToDisplay(iDist,2);
    Prior = ParamsToDisplay(iDist,3);
    
    if isnan(Mu) && isnan(K)
        % Uniform component: display the prior only.
        Label = ['\it{p} = ',sprintf('%0.2f',Prior),';'];
    elseif isnan(Mu)
        % Target distribution: display K and the prior.
        Label = ['\it{k} = ',sprintf('%0.2f',K),';',10,...
            '\it{p} = ',sprintf('%0.2f',Prior),';'];
    else
        % von Mises component: display Mu, K and the prior.
        Label = ['\mu = ',sprintf('%0.2f',Mu),';',10,...
            '\it{k} = ',sprintf('%0.2f',K),';',10,...
            '\it{p} = ',sprintf('%0.2f',Prior),';'];
    end
    text(iDist,-0.5,Label,'HorizontalAlignment','center');
end

% Adjust the appearance of the bar graph:
title('Parameter estimates');
ylabel('Information content (nats)');
ylim([-1,2.2]);
set(gca,'XTick',1:nDists);
set(gca,'XTickLabel',DistNames');
% Dashed reference line at log(2*pi):
plot(xlim,[log(2*pi),log(2*pi)],'LineStyle','--',...
    'LineWidth',LineThickness./4,'Color','k');

%% Draw the iteration history:
subplot(2,2,4);
nIters = size(ParamStruct,1);
XaxisData = (1:nIters)';

% Even with fewer than 100 iterations, the history is drawn on axes with
% room for 100 iterations by padding the Y data with NaNs:
Y2Add = nan(0,1);
if nIters < 100
    XaxisData = (1:1:100)';
    Y2Add = nan(100 - nIters,1);
end

% Plot both the NLL and log10(MaxDeltaParam):
HFit_Ax = plotyy(...
    XaxisData,[[ParamStruct.NLL]';Y2Add],...
    XaxisData,[log10([ParamStruct.MaxDeltaParam]');Y2Add]);

% Adjust the appearance of the line graph:
title('Iteration history');
ylabel(HFit_Ax(1),'Negative log-likelihood');
ylabel(HFit_Ax(2),'Log_1_0 max \Delta params');
xlabel('Iteration number');

% Set the Y Tick marks for the log10(MaxDeltaParam) plot:
HFit_Ax(1,2).YTick = [-10,-8,-6,-4,-2,0];

% With fewer than 100 iterations, fix the x-axis limits at [0,100]:
if nIters < 100
    HFit_Ax(1,1).XLim = [0,100];
    HFit_Ax(1,2).XLim = [0,100];
end

% Pause for a little bit (so the figure refreshes) and hold off:
pause(0.05);
hold off;

return

function [] = DrawCircle()
% DrawCircle: plot a black unit circle (100 points) on the axes that
% currently have focus, with a line width of one third of the global
% LineThickness.
global LineThickness;
Angles = linspace(0,2*pi,100);
plot(cos(Angles),sin(Angles),'k-','LineWidth',LineThickness/3);
return

function [PointColours_RGB,Stereotypes_RGB,FSDisp] = MakeColours(W)
% Map posterior weights onto colours for plotting.
%
% Each of the d distributions is given a "stereotype" colour: a fully
% saturated hue, with hues spaced equally around the colour wheel starting
% at red. Each response is coloured by the weighted sum of the stereotype
% positions on the colour wheel (as unit vectors in the complex plane):
% the phase of that sum sets the hue and its magnitude sets the saturation
% (value is always 1). A response therefore only takes a pure stereotype
% colour when its posterior probability for that distribution is 1, and
% becomes paler as its weight is spread over differently-hued
% distributions.
%
% INPUT:
%   - "W" : an r by d numeric of posterior weights (the probability that
%           each of r responses belongs to each of d distributions). If W
%           is entirely NaN, every response is given equal weight on every
%           distribution; otherwise individual NaNs are treated as 0;
%
% OUTPUT:
%   - "PointColours_RGB" : an r by 3 matrix of RGB triplets, one per
%                          response;
%   - "Stereotypes_RGB" : a d by 3 matrix of RGB triplets, one per
%                         distribution;
%   - "FSDisp" : a function handle that, when called, displays
%                Stereotypes_RGB as an image.

% Replace missing weights:
if all(isnan(W(:)))
    W = ones(size(W)) ./ size(W,2);
else
    W(isnan(W)) = 0;
end

%% Compute the Stereotypes_Angles:
% Set the total number of colours to be the number of distributions:
nColours = size(W,2);

% Range is the fraction of the colour wheel that is used and Offset is the
% hue (as a fraction of the wheel, 0 = red) of the first stereotype. The
% stereotype hues are sampled equidistantly from there:
Range = 1;
Offset = 0;
Stereotypes_Angles = ...
    mod((((Range/nColours).*((1:1:nColours)-1))+Offset),1)' .* (2*pi);

%% Compute Stereotypes_RGB (saturation and value are always 1):
Stereotypes_HSV = [Stereotypes_Angles./(2*pi),...
    ones(size(Stereotypes_Angles)),...
    ones(size(Stereotypes_Angles))];
Stereotypes_RGB = hsv2rgb(Stereotypes_HSV);

% FSDisp is an anonymous function that can display the stereotype colours:
FSDisp = @() image(reshape(Stereotypes_RGB,size(Stereotypes_RGB,1),1,3));

%% Compute PointColours_Angles (hue) and PointColours_Mags (saturation):
% Weighted sum of the stereotype positions on the unit circle, one complex
% value per response:
PointColours_Complex = W * exp(Stereotypes_Angles*1i);

% Phase, wrapped from (-pi,pi] into [0,2*pi) and scaled into [0,1) as
% expected by hsv2rgb:
PointColours_Angles = mod(angle(PointColours_Complex),2*pi) ./ (2*pi);

% When a row of W sums to at most 1 the magnitude cannot exceed 1 (triangle
% inequality), but floating-point rounding can leave it marginally above 1.
% A saturation above 1 can make hsv2rgb return slightly negative channel
% values, so clamp it:
PointColours_Mags = min(abs(PointColours_Complex),1);

%% Compute PointColours_RGB (value is always 1):
PointColours_HSV = [PointColours_Angles,...
    PointColours_Mags,...
    ones(size(PointColours_Complex,1),1)];
PointColours_RGB = hsv2rgb(PointColours_HSV);
return

function [] = Notification_A01()
% Make notification sound A01... Happy (i.e. something good happened).
%
% Plays the tone sequence S1, S2, S2, S1. Each sound() call is followed by
% a pause equal to the length of the sample just played, so the next tone
% starts once the previous one has finished (sound() is presumed not to
% block until playback ends).

Fs = 48000; % Sampling frequency (Hz).
Vol = 0.4; % Max amplitude.
L1 = 0.1; % Length of tone 1 (s).
P1 = 1000; % Frequency of tone 1 (Hz).
L2 = 0.1; % Length of tone 2 (s).
P2 = 1200; % Frequency of tone 2 (Hz).

% Call MakeSound to produce the sound samples S1 and S2:
S1 = MakeSound(Fs,L1,P1,Vol);
S2 = MakeSound(Fs,L2,P2,Vol);

% Play the sequence S1, S2, S2, S1:
sound(S1,Fs);
pause(L1);
sound(S2,Fs);
pause(L2);
sound(S2,Fs);
pause(L2); % S2 was just played, so wait for its length (L2).
sound(S1,Fs);
return

function [] = Notification_A02()
% Make notification sound A02... Sad (i.e. something bad happened).
%
% Plays the tone sequence S1, S2, S2, S2. Each sound() call is followed by
% a pause equal to the length of the sample just played, so the next tone
% starts once the previous one has finished (sound() is presumed not to
% block until playback ends).

Fs = 48000; % Sampling frequency (Hz).
Vol = 0.4; % Max amplitude.
L1 = 0.1; % Length of tone 1 (s).
P1 = 1200; % Frequency of tone 1 (Hz).
L2 = 0.1; % Length of tone 2 (s).
P2 = 900; % Frequency of tone 2 (Hz).

% Call MakeSound to produce the sound samples S1 and S2:
S1 = MakeSound(Fs,L1,P1,Vol);
S2 = MakeSound(Fs,L2,P2,Vol);

% Play the sequence S1, S2, S2, S2:
sound(S1,Fs);
pause(L1);
sound(S2,Fs);
pause(L2);
sound(S2,Fs);
pause(L2); % S2 was just played, so wait for its length (L2).
sound(S2,Fs);
return

function [Sample] = MakeSound(Fs,Length,Pitch,Vol)
% Make a sinusoidal sound sample with linear fade-in and fade-out.
%
% INPUT:
%   - "Fs" : sampling frequency in Hz;
%   - "Length" : duration of the sample in seconds;
%   - "Pitch" : frequency of the tone in Hz (0 gives silence);
%   - "Vol" : peak amplitude of the sample.
%
% OUTPUT:
%   - "Sample" : a 1 by (round(Fs*Length)+1) row vector of samples.
%
% A linear fade of FadeTime seconds is applied at each end of the sample.
% If the sample is shorter than the fade, each fade is shortened to the
% length of the sample (so for short samples the two fades overlap).

% Set the fade in/out time (in seconds):
FadeTime = 0.07;

% Time of each sample. It is built from an integer sample count so the
% number of samples depends only on Fs and Length (not on Pitch):
nSteps = round(Fs*Length);
Time = (0:1:nSteps) ./ Fs;
Sample = sin(2.*pi.*Pitch.*Time);

% Generate linear fade-in and fade-out envelopes, no longer than the sample:
nFade = min(round(Fs*FadeTime) + 1, numel(Sample));
InEnv = linspace(0,1,nFade);
OutEnv = fliplr(InEnv);

% Multiply the envelopes by the start and end of the sample:
Sample(1:nFade) = Sample(1:nFade) .* InEnv;
Sample(end-nFade+1:end) = Sample(end-nFade+1:end) .* OutEnv;

% Globally adjust the amplitude of the sample:
Sample = Sample .* Vol;
return