function M = fixedrankfactory_3factors_preconditioned(m, n, k)
% Quotient geometry for m-by-n matrices of rank k in three-factor form.
%
% function M = fixedrankfactory_3factors_preconditioned(m, n, k)
%
% A point X is a structure with fields
%   L : m-by-k
%   S : k-by-k
%   R : n-by-k
% representing the matrix X.L*X.S*X.R'. M.retr re-orthonormalizes L and R
% through the polar factor computed by uf; M.rand draws them from
% stiefelfactory. Points also cache X.S*X.S' and X.S'*X.S in the fields
% SSt and StS (see prepare).
%
% Tangent vectors are structures with fields L, S, R of the same sizes.
% The metric (M.inner) weights the L and R components by S*S' and S'*S:
%   <eta, zeta> = trace(S*S'*eta.L'*zeta.L) + trace(S'*S*eta.R'*zeta.R)
%                 + trace(eta.S'*zeta.S)
%
% NOTE(review): this factory relies on functions defined outside this file
% (stiefelfactory, lyapunov_symmetric, hashmd5).

M.name = @() sprintf('LSR'' (tuned for least square problems) quotient manifold of %dx%d matrices of rank %d', m, n, k);

% Dimension of the set of m-by-n matrices of rank k.
M.dim = @() k*(m + n - k);
0046
0047
0048
function X = prepare(X)
% Ensure the point X carries the cached products X.SSt = X.S*X.S' and
% X.StS = X.S'*X.S. Both are recomputed unless both fields already exist.
    hasCache = isfield(X, {'StS', 'SSt'});
    if ~all(hasCache)
        X.SSt = X.S*X.S';
        X.StS = X.S'*X.S;
    end
end
0055
0056
0057
M.inner = @iproduct;
function ip = iproduct(X, eta, zeta)
% Metric at X between the tangent vectors eta and zeta:
%   trace(S*S'*eta.L'*zeta.L) + trace(S'*S*eta.R'*zeta.R) + trace(eta.S'*zeta.S)
    X = prepare(X);
    ipL = trace(X.SSt*(eta.L'*zeta.L));
    ipR = trace(X.StS*(eta.R'*zeta.R));
    ipS = trace(eta.S'*zeta.S);
    ip = ipL + ipR + ipS;
end
0065
% Norm induced by M.inner.
M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));

% No distance is available for this geometry.
M.dist = @(x, y) error('fixedrankfactory_3factors_preconditioned.dist not implemented yet.');

M.typicaldist = @() 10*k;

% Skew-symmetric and symmetric parts of a square matrix; these workspace
% variables are shared with the nested functions of this factory.
skew = @(A) .5*(A - A');
symm = @(A) .5*(A + A');
0074
M.egrad2rgrad = @egrad2rgrad;
function rgrad = egrad2rgrad(X, egrad)
% Convert the Euclidean gradient egrad (fields L, S, R) into the Riemannian
% gradient for the metric of M.inner. The L and R components are corrected
% by X.L*BL and X.R*BR, where BL and BR solve the Lyapunov equations handled
% by tangent_space_lyap. Both components are then rescaled by the inverse
% metric weights S*S' and S'*S. The S component is unchanged.
    X = prepare(X);

    rhsL = 2*symm(X.SSt*(egrad.S*X.S'));
    rhsR = 2*symm(X.StS*(egrad.S'*X.S));
    [BL, BR] = tangent_space_lyap(X.S, rhsL, rhsR);

    rgrad.L = (egrad.L - X.L*BL)/X.SSt;
    rgrad.R = (egrad.R - X.R*BR)/X.StS;
    rgrad.S = egrad.S;
end
0098
0099
0100
M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
% Convert the Euclidean Hessian-vector product ehess (along eta) into the
% Riemannian one. The steps are:
%   1. Recompute the Riemannian gradient as in egrad2rgrad, keeping the
%      Lyapunov multipliers BL and BR.
%   2. Differentiate those multipliers along eta (BLdot, BRdot).
%   3. Assemble the L, R, S components and add the metric correction terms.
%   4. Project the result with projection().
% NOTE(review): the correction terms are taken as-is; they were not
% independently re-derived here.
    X = prepare(X);

    % Step 1: Riemannian gradient and its multipliers.
    [BL, BR] = tangent_space_lyap(X.S, ...
        2*symm(X.SSt*(egrad.S*X.S')), ...
        2*symm(X.StS*(egrad.S'*X.S)));
    rgrad.L = (egrad.L - X.L*BL)/X.SSt;
    rgrad.R = (egrad.R - X.R*BR)/X.StS;
    rgrad.S = egrad.S;

    % Symmetric parts reused below. 2*symmL is the derivative of S*S' along
    % eta.S, and 2*symmR is the derivative of S'*S.
    symmL = symm(eta.S*X.S');
    symmR = symm(eta.S'*X.S);

    % Step 2: derivatives of the Lyapunov right-hand sides, then of BL and BR.
    dSStL = 2*symm(X.S*eta.S');
    dStSR = 2*symm(X.S'*eta.S);
    rhsLdot = 2*symm((dSStL*(egrad.S*X.S')) + X.SSt*(ehess.S*X.S' + egrad.S*eta.S')) ...
        - 4*symm(symmL*BL);
    rhsRdot = 2*symm((dStSR*(egrad.S'*X.S)) + X.StS*(ehess.S'*X.S + egrad.S'*eta.S)) ...
        - 4*symm(symmR*BR);
    [BLdot, BRdot] = tangent_space_lyap(X.S, rhsLdot, rhsRdot);

    % Step 3: directional derivative of the Riemannian gradient.
    Hess.L = (ehess.L - eta.L*BL - X.L*BLdot - 2*rgrad.L*symmL)/X.SSt;
    Hess.R = (ehess.R - eta.R*BR - X.R*BRdot - 2*rgrad.R*symmR)/X.StS;
    Hess.S = ehess.S;

    % Correction terms for the point-dependent metric weights.
    Hess.L = Hess.L + (eta.L*symm(rgrad.S*X.S') + rgrad.L*symmL)/X.SSt;
    Hess.R = Hess.R + (eta.R*symm(rgrad.S'*X.S) + rgrad.R*symmR)/X.StS;
    Hess.S = Hess.S - symm(rgrad.L'*eta.L)*X.S - X.S*symm(rgrad.R'*eta.R);

    % Step 4: projection (M.proj is a handle to this same nested function).
    Hess = projection(X, Hess);
end
0154
0155
0156
0157
M.proj = @projection;
function etaproj = projection(X, eta)
% Project the ambient vector eta (fields L, S, R) at the point X.
%
% Stage 1 corrects eta.L and eta.R by X.L*(BL/SSt) and X.R*(BR/StS), where
% BL and BR come from lyapunov_symmetric. NOTE(review): this presumably
% makes symm(X.L'*eta.L) and symm(X.R'*eta.R) vanish, assuming
% lyapunov_symmetric(A, C) solves A*B + B*A = C; confirm against its source.
%
% Stage 2 subtracts (X.L*Omega1, X.S*Omega2 - Omega1*X.S, X.R*Omega2),
% with Omega1 and Omega2 from coupled_lyap applied to the skew parts PU, PV.
    X = prepare(X);

    % Stage 1: symmetric parts of the L and R components.
    BL = lyapunov_symmetric(X.SSt, 2*symm(X.SSt*(X.L'*eta.L)*X.SSt));
    eta.L = eta.L - X.L*(BL/X.SSt);

    BR = lyapunov_symmetric(X.StS, 2*symm(X.StS*(X.R'*eta.R)*X.StS));
    eta.R = eta.R - X.R*(BR/X.StS);

    % Stage 2: remove the component generated by Omega1 and Omega2.
    PU = skew((X.L'*eta.L)*X.SSt) + skew(X.S*eta.S');
    PV = skew((X.R'*eta.R)*X.StS) + skew(X.S'*eta.S);
    [Omega1, Omega2] = coupled_lyap(X.S, PU, PV);

    etaproj.L = eta.L - X.L*Omega1;
    etaproj.S = eta.S - (X.S*Omega2 - Omega1*X.S);
    etaproj.R = eta.R - X.R*Omega2;
end
0194
0195
% Tangent vectors are projected with the same operator as ambient ones, and
% need no conversion to be used as ambient vectors.
M.tangent = M.proj;
M.tangent2ambient = @(X, eta) eta;

M.retr = @retraction;
function Y = retraction(X, eta, t)
% Move from X along t*eta (t defaults to 1). S is updated additively, while
% L and R are replaced by the orthonormal polar factor (uf) of the updated
% factor. The SSt and StS caches are rebuilt.
    if nargin < 3
        t = 1.0;
    end

    stepL = X.L + t*eta.L;
    stepR = X.R + t*eta.R;

    Y.S = X.S + t*eta.S;
    Y.L = uf(stepL);
    Y.R = uf(stepR);

    Y = prepare(Y);
end
0211
0212
% Hash of a point, built from its three factors.
M.hash = @(X) ['z' hashmd5([X.L(:); X.S(:); X.R(:)])];

M.rand = @random;

% Factories used by random() to draw L and R.
stiefelm = stiefelfactory(m, k);
stiefeln = stiefelfactory(n, k);
function X = random()
% Random point: L and R come from the two Stiefel factories. S is diagonal
% with entries 1 + rand, i.e. in (1, 2).
    X.L = stiefelm.rand();
    X.R = stiefeln.rand();
    diagS = 1 + rand(k, 1);
    X.S = diag(diagS);

    X = prepare(X);
end
0227
M.randvec = @randomvec;
function eta = randomvec(X)
% Random tangent vector at X with unit norm in the metric of M.norm: draw
% Gaussian components, project them, then divide every component by the
% resulting norm.
    eta.L = randn(m, k);
    eta.R = randn(n, k);
    eta.S = randn(k, k);
    eta = projection(X, eta);

    nrm = M.norm(X, eta);
    eta = structfun(@(v) v / nrm, eta, 'UniformOutput', false);
end
0240
M.lincomb = @lincomb;

% Zero tangent vector (the point X is not used).
M.zerovec = @(X) struct('L', zeros(m, k), 'S', zeros(k, k), 'R', zeros(n, k));

% Vector transport: project d at the destination point x2 (x1 is unused).
M.transp = @(x1, x2, d) projection(x2, d);

% Vectorization layout [L(:); S(:); R(:)] and its inverse.
M.vec = @(X, U) [U.L(:); U.S(:); U.R(:)];
M.mat = @(X, u) struct('L', reshape(u(1:m*k), m, k), ...
                       'S', reshape(u(m*k + (1:k*k)), k, k), ...
                       'R', reshape(u((m*k + k*k + 1):end), n, k));
M.vecmatareisometries = @() false;

end
0256
0257
function d = lincomb(x, a1, d1, a2, d2)
% Linear combination of tangent vectors (structures with fields L, R, S):
%   lincomb(x, a1, d1)         -> a1*d1
%   lincomb(x, a1, d1, a2, d2) -> a1*d1 + a2*d2
% The base point x is not used. Any other argument count is an error.
    switch nargin
        case 3
            d.L = a1*d1.L;
            d.R = a1*d1.R;
            d.S = a1*d1.S;
        case 5
            d.L = a1*d1.L + a2*d2.L;
            d.R = a1*d1.R + a2*d2.R;
            d.S = a1*d1.S + a2*d2.S;
        otherwise
            error('Bad use of fixedrankfactory_3factors_preconditioned.lincomb.');
    end
end
0273
function A = uf(A)
% Orthonormal polar factor of A: from the economy-size SVD A = U*S*V'
% (svd(A, 0)), return U*V'.
    [U, ~, V] = svd(A, 0);
    A = U*V';
end
0278
function [BU, BV] = tangent_space_lyap(R, E, F)
% Solve the two independent Lyapunov equations
%
%   (R*R')*BU + BU*(R*R') = E
%   (R'*R)*BV + BV*(R'*R) = F
%
% without a general Lyapunov solver. R is assumed square (r-by-r). With the
% SVD R = U*Sigma*V', both equations become diagonal in the singular-vector
% bases: entry (i,j) is multiplied by s_i^2 + s_j^2, so each solution is an
% elementwise division followed by the inverse similarity transform.
    [U, Sigma, V] = svd(R);
    r = size(Sigma, 1);
    s = diag(Sigma);

    % scale(i,j) = s_i^2 + s_j^2
    sRows = s*ones(1, r);
    scale = sRows.^2 + (sRows').^2;

    BU = U*((U'*E*U) ./ scale)*U';
    BV = V*((V'*F*V) ./ scale)*V';
end
0334
0335
0336
function [Omega1, Omega2] = coupled_lyap(R, E, F)
% Solve the coupled Lyapunov-type system
%
%   R*R'*Omega1 + Omega1*R*R' - R*Omega2*R' = E
%   R'*R*Omega2 + Omega2*R'*R - R'*Omega1*R = F
%
% for Omega1 and Omega2. R is assumed square (r-by-r); E and F are r-by-r.
%
% With the SVD R = U*Sigma*V' and the substitutions Omega1 = U*Y1*U',
% Omega2 = V*Y2*V', entry (i,j) decouples into the 2-by-2 system
%
%    a*y1 - c*y2 = b1,      -c*y1 + a*y2 = b2,
%
% where a = s_i^2 + s_j^2, c = s_i*s_j, b1 = (U'*E*U)(i,j) and
% b2 = (V'*F*V)(i,j). Since 0 <= c <= a/2, the determinant a^2 - c^2 is at
% least 3*a^2/4. The system is solved by Cramer's rule after scaling by a,
% with t = c/a in [0, 1/2]:
%
%    y1 = (b1 + t*b2) / (a*(1 - t^2)),   y2 = (t*b1 + b2) / (a*(1 - t^2)).
%
% This is algebraically identical to the previous formula but never
% divides by c and never squares a. The previous formula returned NaN when a
% singular value was zero or c underflowed. Entries with a == 0
% (s_i = s_j = 0) have no unique solution and are set to zero.
    [U, Sigma, V] = svd(R);
    E_mod = U'*E*U;
    F_mod = V'*F*V;
    b1 = E_mod(:);
    b2 = F_mod(:);

    r = size(Sigma, 1);
    sig = diag(Sigma);        % singular values as a column
    sig1 = sig*ones(1, r);    % sig1(i,j) = s_i
    sig1t = sig1';            % sig1t(i,j) = s_j
    s1 = sig1(:);
    s2 = sig1t(:);

    a = s1.^2 + s2.^2;
    c = s1.*s2;

    % Scaled Cramer's rule, entrywise.
    t = c ./ a;                   % NaN only where a == 0
    scaledDet = a .* (1 - t.^2);  % equals (a^2 - c^2)/a >= 3*a/4
    Y1_sol = (b1 + t.*b2) ./ scaledDet;
    Y2_sol = (t.*b1 + b2) ./ scaledDet;

    % Fully degenerate entries (both singular values zero).
    degenerate = (a == 0);
    Y1_sol(degenerate) = 0;
    Y2_sol(degenerate) = 0;

    % Back to the original bases.
    Omega1 = U*reshape(Y1_sol, r, r)*U';
    Omega2 = V*reshape(Y2_sol, r, r)*V';
end