forked from mobeets/fa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fastfa.m
126 lines (112 loc) · 3.14 KB
/
fastfa.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
function [estParams, LL] = fastfa(X, zDim, varargin)
%
% [estParams, LL] = fastfa(X, zDim, ...)
%
% Factor analysis and probabilistic PCA, fit by EM.
%
% xDim: data dimensionality
% zDim: latent dimensionality
% N:    number of data points
%
% INPUTS:
%
% X    - data matrix (xDim x N)
% zDim - number of factors
%
% OUTPUTS:
%
% estParams.L  - factor loadings (xDim x zDim)
% estParams.Ph - diagonal of uniqueness matrix (xDim x 1)
% estParams.d  - data mean (xDim x 1)
% LL           - log likelihood at each EM iteration (row vector). Each entry
%                is evaluated in that iteration's E-step, i.e. for the
%                parameters in effect BEFORE that iteration's M-step.
%
% OPTIONAL ARGUMENTS:
% (parsed by assignopts, a project helper not shown here; presumably given
%  as name/value pairs, e.g. fastfa(X, zDim, 'typ', 'ppca') -- confirm)
%
% typ        - 'fa' (default) or 'ppca'. Any other value raises an error.
% tol        - stopping criterion for EM (default: 1e-8). EM stops once the
%              log-likelihood gain of one iteration is smaller than tol times
%              the total gain accumulated since the second iteration.
% cyc        - maximum number of EM iterations (default: 1e8)
% minVarFrac - fraction of overall data variance for each observed dimension
%              to set as the private variance floor. This is used to combat
%              Heywood cases, where ML parameter learning returns one or more
%              zero private variances. Applied only when typ is 'fa'.
%              (default: 0.01)
%              (See Martin & McDonald, Psychometrika, Dec 1975.)
% verbose    - logical that specifies whether to display per-iteration status
%              messages (default: false). The rank-deficiency and
%              variance-floor warnings are printed regardless of verbose.
%
% NOTES:
% - randn('state', 0) is called before the loadings are initialized, so
%   results are reproducible; as a side effect this resets the caller's
%   global (legacy) normal random generator.
% - Only core MATLAB functions are used (no Statistics Toolbox needed).
%
% Code adapted from ffa.m by Zoubin Ghahramani.
%
% @ 2009 Byron Yu -- [email protected]

typ        = 'fa';
tol        = 1e-8;
cyc        = 1e8;
minVarFrac = 0.01;
verbose    = false;
assignopts(who, varargin);

% Without this check an unrecognized typ would match neither M-step branch
% below and EM would silently run with neither the PPCA constraint nor the
% FA variance floor. Validate before touching the global random state.
if ~(isequal(typ, 'fa') || isequal(typ, 'ppca'))
    error('fastfa:badType', 'Unknown typ; expected ''fa'' or ''ppca''.');
end

randn('state', 0);
[xDim, N] = size(X);

% ---------------------------------------------------------------
% Initialization of parameters
% ---------------------------------------------------------------
cX = cov(X', 1);   % sample covariance normalized by N
r  = rank(cX);     % computed once: rank() performs an SVD
if r == xDim
    % Geometric mean of the eigenvalues, det(cX)^(1/xDim), via Cholesky.
    scale = exp(2*sum(log(diag(chol(cX))))/xDim);
else
    % cX may not be full rank, e.g. because N < xDim
    fprintf('WARNING in fastfa.m: Data matrix is not full rank.\n');
    e = sort(eig(cX), 'descend');
    % Geometric mean of the r largest eigenvalues (the value geomean would
    % return, without requiring the Statistics Toolbox).
    scale = exp(mean(log(e(1:r))));
end
L        = randn(xDim, zDim) * sqrt(scale/zDim);
Ph       = diag(cX);
d        = mean(X, 2);
varFloor = minVarFrac * diag(cX);

I     = eye(zDim);
const = -xDim/2*log(2*pi);
LLi   = 0;
LL    = [];

for i = 1:cyc
    % ===========================================================
    % E-step
    % ===========================================================
    iPh  = diag(1./Ph);
    iPhL = iPh * L;
    % MM = inv(L*L' + diag(Ph)), written with the matrix inversion lemma so
    % that only a zDim x zDim system has to be solved.
    MM      = iPh - iPhL / (I + L' * iPhL) * iPhL';
    beta    = L' * MM;     % zDim x xDim
    cX_beta = cX * beta';  % xDim x zDim
    EZZ     = I - beta * L + beta * cX_beta;

    % Log likelihood of the current (pre-M-step) parameters.
    LLold = LLi;
    ldM   = sum(log(diag(chol(MM))));  % = 0.5 * log(det(MM))
    LLi   = N*const + N*ldM - 0.5*N*sum(sum(MM .* cX));
    if verbose
        fprintf('EM iteration %5i lik %8.1f \r', i, LLi);
    end
    % Number of iterations is not known in advance, so LL grows.
    LL = [LL LLi]; %#ok<AGROW>

    % ===========================================================
    % M-step
    % ===========================================================
    L  = cX_beta / EZZ;
    Ph = diag(cX) - sum(cX_beta .* L, 2);
    if isequal(typ, 'ppca')
        % PPCA: a single isotropic private variance shared by all dimensions.
        Ph = mean(Ph) * ones(xDim, 1);
    end
    if isequal(typ, 'fa')
        % Set minimum private variance
        Ph = max(varFloor, Ph);
    end

    % ===========================================================
    % Convergence check
    % ===========================================================
    if i <= 2
        % LLbase ends up as the log likelihood from iteration 2.
        LLbase = LLi;
    elseif LLi < LLold
        % Report a decrease of the log likelihood between iterations.
        disp('VIOLATION');
    elseif (LLi - LLbase) < (1+tol)*(LLold - LLbase)
        % Equivalent to: LLi - LLold < tol * (LLold - LLbase).
        break;
    end
end
if verbose
    fprintf('\n');
end

% The floor is only ever applied for 'fa', so only report it in that case.
if isequal(typ, 'fa') && any(Ph == varFloor)
    fprintf('Warning: Private variance floor used for one or more observed dimensions in FA.\n');
end

estParams.L  = L;
estParams.Ph = Ph;
estParams.d  = d;