function M = fixedrankfactory_tucker_preconditioned(tensor_size, tensor_rank)
% Builds the manifold structure for fixed-rank third-order Tucker tensors.
%
% function M = fixedrankfactory_tucker_preconditioned(tensor_size, tensor_rank)
%
% Inputs:
%   tensor_size  vector whose first three entries [n1 n2 n3] are the
%                dimensions of the tensor.
%   tensor_rank  vector [r1 r2 r3] with exactly three entries: the
%                multilinear rank.
%
% Output:
%   M  structure of function handles (name, dim, inner, norm, egrad2rgrad,
%      ehess2rhess, proj, retr, rand, randvec, ...) describing the
%      'G x U1 x U2 x U3' quotient Tucker manifold.
%
% A point is a structure with fields U1 (n1-by-r1), U2 (n2-by-r2),
% U3 (n3-by-r3) and G (r1-by-r2-by-r3); tangent vectors use the same
% field names.
%
% Raises an error if tensor_rank does not have exactly three entries or
% if tensor_size has fewer than three entries.

    % Reject anything that is not a third-order problem up front. The
    % previous test (length(tensor_rank) > 3) let short vectors through,
    % which then failed below with an unrelated indexing error.
    if numel(tensor_rank) ~= 3 || numel(tensor_size) < 3
        error('Bad usage of fixedrankfactory_tucker_preconditioned. Currently, only handles 3-order tensors.');
    end

    % Tensor dimensions.
    n1 = tensor_size(1);
    n2 = tensor_size(2);
    n3 = tensor_size(3);

    % Multilinear rank.
    r1 = tensor_rank(1);
    r2 = tensor_rank(2);
    r3 = tensor_rank(3);

    % Sparse identities, built once and shared by the Kronecker products
    % formed in the projection operator.
    speyer1 = speye(r1);
    speyer2 = speye(r2);
    speyer3 = speye(r3);

    M = fixedrankfactory_tucker_preconditioned_helper(...
        tensor_size, tensor_rank, ...
        n1, n2, n3, r1, r2, r3, speyer1, speyer2, speyer3);

end
0081
0082
0083
0084 function M = fixedrankfactory_tucker_preconditioned_helper(...
0085 tensor_size, tensor_rank, ...
0086 n1, n2, n3, r1, r2, r3, speyer1, speyer2, speyer3)
0087
% Human-readable description of the manifold.
M.name = @() sprintf('G x U1 x U2 x U3 quotient Tucker manifold of %d-by-%d-by-%d tensor of rank %d-by-%d-by-%d.', n1, n2, n3, r1, r2, r3);

% Dimension: each factor Uk contributes rk*(nk - rk), the core contributes
% r1*r2*r3.
M.dim = @() r1*(n1 - r1) + r2*(n2 - r2) + r3*(n3 - r3) + r1*r2*r3;
0091
0092
0093
function X = prepare(X)
% Caches on the point X the three unfoldings of the core X.G
% (G1: r1-by-r2*r3, G2: r2-by-r1*r3, G3: r3-by-r1*r2) and their Gram
% matrices GkGkt = Gk*Gk'. If all six fields are already present, X is
% returned untouched.

    cachedFields = {'G1G1t', 'G1', 'G2G2t', 'G2', 'G3G3t', 'G3'};
    if all(isfield(X, cachedFields))
        return;
    end

    % Mode-1 unfolding and its Gram matrix.
    X.G1 = reshape(X.G, r1, r2*r3);
    X.G1G1t = X.G1*X.G1';

    % Mode-2 unfolding and its Gram matrix.
    X.G2 = reshape(permute(X.G, [2 1 3]), r2, r1*r3);
    X.G2G2t = X.G2*X.G2';

    % Mode-3 unfolding and its Gram matrix.
    X.G3 = reshape(permute(X.G, [3 1 2]), r3, r1*r2);
    X.G3G3t = X.G3*X.G3';

end
0113
0114
0115
% Metric: the U-blocks are weighted by the Gram matrices of the core
% unfoldings; the G-block uses the plain Euclidean inner product.
M.inner = @iproduct;
function ip = iproduct(X, eta, zeta)
% Inner product of the tangent vectors eta and zeta at the point X.
    X = prepare(X);
    partU1 = trace(X.G1G1t*(eta.U1'*zeta.U1));
    partU2 = trace(X.G2G2t*(eta.U2'*zeta.U2));
    partU3 = trace(X.G3G3t*(eta.U3'*zeta.U3));
    partG = eta.G(:)'*zeta.G(:);
    ip = partU1 + partU2 + partU3 + partG;
end

% Norm induced by the metric above.
M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));

% The distance is not available for this geometry: calling it raises.
M.dist = @(x, y) error('fixedrankfactory_tucker_preconditioned.dist not implemented yet.');

M.typicaldist = @() 10*n1*r1;
0129
M.egrad2rgrad = @egrad2rgrad;
function rgrad = egrad2rgrad(X, egrad)
% Converts the Euclidean gradient egrad (fields U1, U2, U3, G) at X into
% the gradient with respect to the metric of this factory.
%
% NOTE(review): lyapunov_symmetric is defined outside this file; it is
% presumably a solver for B in S*B + B*S = A with S, A symmetric -- confirm.

    X = prepare(X);

    % Gram matrices of the core unfoldings.
    gram1 = X.G1G1t;
    gram2 = X.G2G2t;
    gram3 = X.G3G3t;

    % Symmetric right-hand sides of the three Lyapunov equations.
    rhs1 = 2*symm(gram1*(X.U1' * egrad.U1));
    rhs2 = 2*symm(gram2*(X.U2' * egrad.U2));
    rhs3 = 2*symm(gram3*(X.U3' * egrad.U3));

    corr1 = lyapunov_symmetric(gram1, rhs1);
    corr2 = lyapunov_symmetric(gram2, rhs2);
    corr3 = lyapunov_symmetric(gram3, rhs3);

    % Subtract the X.Uk*corrk component, then scale on the right by the
    % inverse Gram matrix. The core block is taken over unchanged.
    rgrad.U1 = (egrad.U1 - X.U1*corr1)/gram1;
    rgrad.U2 = (egrad.U2 - X.U2*corr2)/gram2;
    rgrad.U3 = (egrad.U3 - X.U3*corr3)/gram3;
    rgrad.G = egrad.G;

end
0160
0161
0162
M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
% Converts the Euclidean gradient egrad and Euclidean Hessian ehess
% (applied along the tangent vector eta) at X into the Hessian of this
% geometry along eta. The result is passed through M.proj before being
% returned.
%
% NOTE(review): lyapunov_symmetric is defined outside this file; it is
% presumably a solver for B in S*B + B*S = A with S, A symmetric -- confirm.

    X = prepare(X);

    % Unfoldings of an r1-by-r2-by-r3 array along each mode.
    unfold1 = @(T) reshape(T, r1, r2*r3);
    unfold2 = @(T) reshape(permute(T, [2 1 3]), r2, r1*r3);
    unfold3 = @(T) reshape(permute(T, [3 1 2]), r3, r1*r2);

    S1 = X.G1G1t;
    S2 = X.G2G2t;
    S3 = X.G3G3t;

    % ---- Gradient of the geometry (same computation as egrad2rgrad) ----
    BU1 = lyapunov_symmetric(S1, 2*symm(S1*(X.U1' * egrad.U1)));
    BU2 = lyapunov_symmetric(S2, 2*symm(S2*(X.U2' * egrad.U2)));
    BU3 = lyapunov_symmetric(S3, 2*symm(S3*(X.U3' * egrad.U3)));

    rgrad.U1 = (egrad.U1 - X.U1*BU1)/S1;
    rgrad.U2 = (egrad.U2 - X.U2*BU2)/S2;
    rgrad.U3 = (egrad.U3 - X.U3*BU3)/S3;
    rgrad.G = egrad.G;

    % ---- Unfoldings of every core-shaped quantity ----
    eta_G1 = unfold1(eta.G);
    eta_G2 = unfold2(eta.G);
    eta_G3 = unfold3(eta.G);
    egrad_G1 = unfold1(egrad.G);
    egrad_G2 = unfold2(egrad.G);
    egrad_G3 = unfold3(egrad.G);
    ehess_G1 = unfold1(ehess.G);
    ehess_G2 = unfold2(ehess.G);
    ehess_G3 = unfold3(ehess.G);
    rgrad_G1 = unfold1(rgrad.G);
    rgrad_G2 = unfold2(rgrad.G);
    rgrad_G3 = unfold3(rgrad.G);

    % symm(eta_Gk*Gk') appears in several places below; compute it once.
    symEtaG1 = symm(eta_G1*X.G1');
    symEtaG2 = symm(eta_G2*X.G2');
    symEtaG3 = symm(eta_G3*X.G3');

    % ---- Right-hand sides of the differentiated Lyapunov equations ----
    lead1 = 2*symm(X.G1*eta_G1')*(egrad_G1*X.G1');
    lead2 = 2*symm(X.G2*eta_G2')*(egrad_G2*X.G2');
    lead3 = 2*symm(X.G3*eta_G3')*(egrad_G3*X.G3');

    tail1 = S1*(ehess_G1*X.G1' + egrad_G1*eta_G1');
    tail2 = S2*(ehess_G2*X.G2' + egrad_G2*eta_G2');
    tail3 = S3*(ehess_G3*X.G3' + egrad_G3*eta_G3');

    ASU1dot = 2*symm(lead1 + tail1) - 4*symm(symEtaG1*BU1);
    ASU2dot = 2*symm(lead2 + tail2) - 4*symm(symEtaG2*BU2);
    ASU3dot = 2*symm(lead3 + tail3) - 4*symm(symEtaG3*BU3);

    BU1dot = lyapunov_symmetric(S1, ASU1dot);
    BU2dot = lyapunov_symmetric(S2, ASU2dot);
    BU3dot = lyapunov_symmetric(S3, ASU3dot);

    % ---- Directional derivative of the gradient ----
    Hess.U1 = (ehess.U1 - eta.U1*BU1 - X.U1*BU1dot - 2*rgrad.U1*symEtaG1)/S1;
    Hess.U2 = (ehess.U2 - eta.U2*BU2 - X.U2*BU2dot - 2*rgrad.U2*symEtaG2)/S2;
    Hess.U3 = (ehess.U3 - eta.U3*BU3 - X.U3*BU3dot - 2*rgrad.U3*symEtaG3)/S3;
    Hess.G = ehess.G;

    % ---- Extra terms added on top of the directional derivative ----
    % NOTE(review): presumably the connection correction arising because
    % the metric depends on the point X -- confirm against the reference
    % derivation.
    Hess.U1 = Hess.U1 + (eta.U1*symm(rgrad_G1*X.G1') + rgrad.U1*symEtaG1)/S1;
    Hess.U2 = Hess.U2 + (eta.U2*symm(rgrad_G2*X.G2') + rgrad.U2*symEtaG2)/S2;
    Hess.U3 = Hess.U3 + (eta.U3*symm(rgrad_G3*X.G3') + rgrad.U3*symEtaG3)/S3;

    % Fold the three mode-wise corrections back to r1-by-r2-by-r3 arrays.
    foldG1 = reshape(symm(rgrad.U1'*eta.U1)*X.G1, r1, r2, r3);
    foldG2 = permute(reshape(symm(rgrad.U2'*eta.U2)*X.G2, r2, r1, r3), [2 1 3]);
    foldG3 = permute(reshape(symm(rgrad.U3'*eta.U3)*X.G3, r3, r1, r2), [2 3 1]);
    Hess.G = Hess.G - foldG1 - foldG2 - foldG3;

    % Final projection with the factory's own projector.
    Hess = M.proj(X, Hess);

end
0240
0241
0242
0243
M.proj = @projection;
function etaproj = projection(X, eta)
% Projects the ambient vector eta (fields U1, U2, U3, G) at the point X.
%
% Two steps are performed:
%   1. For each factor Uk, a symmetric matrix BUk is obtained from
%      lyapunov_symmetric and X.Uk*(BUk/GkGkt) is removed from eta.Uk.
%   2. Three matrices Omega1, Omega2, Omega3 are obtained from one coupled
%      linear system, solved iteratively with pcg; X.Uk*Omegak is removed
%      from eta.Uk and the folded products Omegak*Gk are added to eta.G.
%
% NOTE(review): lyapunov_symmetric is defined outside this file; it is
% presumably a solver for B in S*B + B*S = A with S, A symmetric -- confirm.

    X = prepare(X);

    % ---- Step 1: per-factor correction ----
    SSU1 = X.G1G1t;
    ASU1 = 2*symm(X.G1G1t*(X.U1'*eta.U1)*X.G1G1t);
    BU1 = lyapunov_symmetric(SSU1, ASU1);
    eta.U1 = eta.U1 - X.U1*(BU1/X.G1G1t);

    SSU2 = X.G2G2t;
    ASU2 = 2*symm(X.G2G2t*(X.U2'*eta.U2)*X.G2G2t);
    BU2 = lyapunov_symmetric(SSU2, ASU2);
    eta.U2 = eta.U2 - X.U2*(BU2/X.G2G2t);

    SSU3 = X.G3G3t;
    ASU3 = 2*symm(X.G3G3t*(X.U3'*eta.U3)*X.G3G3t);
    BU3 = lyapunov_symmetric(SSU3, ASU3);
    eta.U3 = eta.U3 - X.U3*(BU3/X.G3G3t);

    % Unfoldings of the core component of eta.
    eta_G1 = reshape(eta.G, r1, r2*r3);
    eta_G2 = reshape(permute(eta.G, [2 1 3]), r2, r1*r3);
    eta_G3 = reshape(permute(eta.G, [3 1 2]), r3, r1*r2);

    % ---- Step 2: coupled system for Omega1, Omega2, Omega3 ----
    % Skew-symmetric right-hand sides.
    PU1 = skew((X.U1'*eta.U1)*X.G1G1t) + skew(X.G1*eta_G1');
    PU2 = skew((X.U2'*eta.U2)*X.G2G2t) + skew(X.G2*eta_G2');
    PU3 = skew((X.U3'*eta.U3)*X.G3G3t) + skew(X.G3*eta_G3');

    % pcg stopping criteria.
    tol_omegax_pcg = 1e-6;
    max_iterations_pcg = 15;

    % Preconditioner: block diagonal, one block per mode, each being the
    % Kronecker form of Omega -> Omega*SSUk + SSUk*Omega (the part of the
    % operator in compute_residual without the coupling Offset terms).
    M1 = kron(speyer1,SSU1) + kron(SSU1, speyer1);
    M2 = kron(speyer2,SSU2) + kron(SSU2, speyer2);
    M3 = kron(speyer3,SSU3) + kron(SSU3, speyer3);

    % blkdiag assembles the blocks directly. The earlier construction,
    % sparse(zeros(N)) followed by indexed assignment, first allocated a
    % dense N-by-N matrix with N = r1^2 + r2^2 + r3^2.
    Mprecon_pcg = blkdiag(M1, M2, M3);

    % The second output is requested but not used; asking for it keeps
    % pcg from printing its own convergence message.
    [Omegaxsol, unused] = pcg(@compute_residual, [PU1(:); PU2(:); PU3(:)], tol_omegax_pcg, max_iterations_pcg, Mprecon_pcg); %#ok<NASGU>

    % Split the stacked solution back into the three square matrices.
    Omega1 = reshape(Omegaxsol(1:r1^2), r1, r1);
    Omega2 = reshape(Omegaxsol(1 + r1^2 : r1^2 + r2^2), r2, r2);
    Omega3 = reshape(Omegaxsol(1 + r1^2 + r2^2 : end), r3, r3);

    function AOmegax = compute_residual(Omegax)
    % Applies the linear operator of the coupled system to the stacked
    % vector Omegax = [Omegax1(:); Omegax2(:); Omegax3(:)].
        Omegax1 = reshape(Omegax(1:r1^2), r1, r1);
        Omegax2 = reshape(Omegax(1 + r1^2 : r1^2 + r2^2), r2, r2);
        Omegax3 = reshape(Omegax(1 + r1^2 + r2^2 : end), r3, r3);

        % Coupling terms: each mode sees the other two Omegas through the
        % corresponding unfolding of the core.
        OffsetU1 = X.G1*((kron(speyer3,Omegax2) + kron(Omegax3, speyer2))*X.G1');
        OffsetU2 = X.G2*((kron(speyer3,Omegax1) + kron(Omegax3, speyer1))*X.G2');
        OffsetU3 = X.G3*((kron(speyer2,Omegax1) + kron(Omegax2, speyer1))*X.G3');

        residual1 = Omegax1*SSU1 + SSU1*Omegax1 - OffsetU1;
        residual2 = Omegax2*SSU2 + SSU2*Omegax2 - OffsetU2;
        residual3 = Omegax3*SSU3 + SSU3*Omegax3 - OffsetU3;

        AOmegax = [residual1(:); residual2(:); residual3(:)];
    end

    % Remove the X.Uk*Omegak components from the factor blocks.
    etaproj.U1 = eta.U1 - (X.U1*Omega1);
    etaproj.U2 = eta.U2 - (X.U2*Omega2);
    etaproj.U3 = eta.U3 - (X.U3*Omega3);

    % Matching update of the core block: fold each Omegak*Gk back to an
    % r1-by-r2-by-r3 array and add the sum to eta.G.
    GOmega1 = reshape(Omega1*X.G1, r1, r2, r3);
    GOmega2 = permute(reshape(Omega2*X.G2, r2, r1, r3), [2 1 3]);
    GOmega3 = permute(reshape(Omega3*X.G3, r3, r1, r2), [2 3 1]);
    etaproj.G = eta.G -(-(GOmega1+GOmega2+GOmega3));

end
0331
0332
0333
% Tangentialization is the same operation as the projection.
M.tangent = M.proj;

% Tangent vectors are already stored in ambient format.
M.tangent2ambient = @(X, eta) eta;

M.retr = @retraction;
function Y = retraction(X, eta, t)
% Retraction at X along eta with step t (t defaults to 1).
% The core moves linearly; each factor X.Uk + t*eta.Uk is replaced by the
% product of its left and right singular vectors (helper uf).

    if nargin < 3
        t = 1.0;
    end

    Y.G = X.G + t*eta.G;
    Y.U1 = uf(X.U1 + t*eta.U1);
    Y.U2 = uf(X.U2 + t*eta.U2);
    Y.U3 = uf(X.U3 + t*eta.U3);

    % Attach the cached unfoldings and Gram matrices to the new point.
    Y = prepare(Y);

end
0350
0351
% Hash of a point: 'z' followed by the MD5 hash (hashmd5, defined outside
% this file) of the four component sums.
M.hash = @(X) ['z' hashmd5([sum(X.U1(:)) ; sum(X.U2(:)); sum(X.U3(:)); sum(X.G(:)) ])];

M.rand = @random;
function X = random()
% Generates a random point: thin QR factorizations of uniform random
% n_k-by-r_k matrices give the factors Uk, and the triangular factors are
% multiplied into a uniform random r1-by-r2-by-r3 core along each mode.

    [Q1, T1] = qr(rand(n1, r1), 0);
    [Q2, T2] = qr(rand(n2, r2), 0);
    [Q3, T3] = qr(rand(n3, r3), 0);
    core = rand(r1, r2, r3);

    % Multiply along mode 1.
    unfolded = reshape(core, r1, r2*r3);
    core = reshape(T1*unfolded, r1, r2, r3);

    % Multiply along mode 2.
    unfolded = reshape(permute(core, [2 1 3]), r2, r1*r3);
    core = permute(reshape(T2*unfolded, r2, r1, r3), [2 1 3]);

    % Multiply along mode 3.
    unfolded = reshape(permute(core, [3 1 2]), r3, r1*r2);
    core = permute(reshape(T3*unfolded, r3, r1, r2), [2 3 1]);

    X.U1 = Q1;
    X.U2 = Q2;
    X.U3 = Q3;
    X.G = core;

    % Attach the cached unfoldings and Gram matrices.
    X = prepare(X);

end
0398
M.randvec = @randomvec;
function eta = randomvec(X)
% Draws a Gaussian ambient vector, projects it at X with projection(),
% and rescales it to unit norm with respect to M.norm.

    eta.U1 = randn(n1, r1);
    eta.U2 = randn(n2, r2);
    eta.U3 = randn(n3, r3);
    eta.G = randn(r1, r2, r3);

    eta = projection(X, eta);

    scale = M.norm(X, eta);
    eta.U1 = eta.U1 / scale;
    eta.U2 = eta.U2 / scale;
    eta.U3 = eta.U3 / scale;
    eta.G = eta.G / scale;

end
0413
% Linear combinations are handled by the file-level function lincomb.
M.lincomb = @lincomb;

% Zero tangent vector: all four components filled with zeros.
M.zerovec = @(X) struct( ...
    'U1', zeros(n1, r1), ...
    'U2', zeros(n2, r2), ...
    'U3', zeros(n3, r3), ...
    'G', zeros(r1, r2, r3));

M.transp = @transp;
function v = transp(x1, x2, d) %#ok<INUSL>
% Transports d to the point x2 by projecting it there; x1 is not used.
    v = projection(x2, d);
end
0423
0424
% vec stacks the components of a tangent vector in the order U1, U2, U3, G.
M.vec = @(X, eta) [eta.U1(:); eta.U2(:); eta.U3(:); eta.G(:)];

% mat undoes vec. The offsets below mark the last entry of each U-block
% in the stacked column; everything after the U3 block belongs to G.
lastU1 = n1*r1;
lastU2 = n1*r1 + n2*r2;
lastU3 = n1*r1 + n2*r2 + n3*r3;
M.mat = @(X, u) struct( ...
    'U1', reshape(u(1 : lastU1), n1, r1), ...
    'U2', reshape(u(lastU1 + 1 : lastU2), n2, r2), ...
    'U3', reshape(u(lastU2 + 1 : lastU3), n3, r3), ...
    'G', reshape(u(lastU3 + 1 : end), r1, r2, r3));

% vec and mat are not flagged as isometries for this metric.
M.vecmatareisometries = @() false;
0432
0433 end
0434
0435
function d = lincomb(X, a1, d1, a2, d2) %#ok<INUSL>
% Linear combination of tangent vectors with fields U1, U2, U3 and G.
%   lincomb(X, a1, d1)          returns a1*d1
%   lincomb(X, a1, d1, a2, d2)  returns a1*d1 + a2*d2
% Any other number of arguments raises an error. X is not used.

    switch nargin
        case 3
            d.U1 = a1*d1.U1;
            d.U2 = a1*d1.U2;
            d.U3 = a1*d1.U3;
            d.G = a1*d1.G;
        case 5
            d.U1 = a1*d1.U1 + a2*d2.U1;
            d.U2 = a1*d1.U2 + a2*d2.U2;
            d.U3 = a1*d1.U3 + a2*d2.U3;
            d.G = a1*d1.G + a2*d2.G;
        otherwise
            error('Bad use of fixedrankfactory_tucker_preconditioned.lincomb.');
    end

end
0453
0454
function U = uf(A)
% Returns the product of the left and (transposed) right singular vectors
% of A, taken from its economy-size SVD; the singular values are dropped.
    [leftVecs, unusedSigma, rightVecs] = svd(A, 0); %#ok<ASGLU>
    U = leftVecs*rightVecs';
end
0459
function A = symm(Z)
% Symmetric part of the square matrix Z: the average of Z and Z'.
    A = (Z + Z') / 2;
end
0463
function A = skew(Z)
% Skew-symmetric part of the square matrix Z: half of Z minus Z'.
    A = (Z - Z') / 2;
end