Home > manopt > manifolds > fixedrank > fixedrankfactory_3factors_preconditioned.m

fixedrankfactory_3factors_preconditioned

PURPOSE ^

Manifold of m-by-n matrices of rank k with three factor quotient geometry.

SYNOPSIS ^

function M = fixedrankfactory_3factors_preconditioned(m, n, k)

DESCRIPTION ^

 Manifold of m-by-n matrices of rank k with three factor quotient geometry.

 function M = fixedrankfactory_3factors_preconditioned(m, n, k)

 This geometry is tuned to least squares problems such as low-rank matrix
 completion with ell-2 loss.

 A point X on the manifold is represented as a structure with three
 fields: L, S and R. The matrices L (mxk) and R (nxk) are orthonormal,
 while the matrix S (kxk) is a full rank matrix such that X = L*S*R'.

 Tangent vectors are represented as a structure with three fields: L, S
 and R.

 Please cite the Manopt paper as well as the research paper:
     @InProceedings{mishra2014r3mc,
       Title        = {{R3MC}: A {R}iemannian three-factor algorithm for low-rank matrix completion},
       Author       = {Mishra, B. and Sepulchre, R.},
       Booktitle    = {{53rd IEEE Conference on Decision and Control}},
       Year         = {2014},
       Organization = {{IEEE CDC}}
     }


 See also: fixedrankfactory_3factors fixedrankfactory_2factors_preconditioned

CROSS-REFERENCE INFORMATION ^

This function calls:

This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function M = fixedrankfactory_3factors_preconditioned(m, n, k)
0002 % Manifold of m-by-n matrices of rank k with three factor quotient geometry.
0003 %
0004 % function M = fixedrankfactory_3factors_preconditioned(m, n, k)
0005 %
0006 % This geometry is tuned to least squares problems such as low-rank matrix
0007 % completion with ell-2 loss.
0008 %
0009 % A point X on the manifold is represented as a structure with three
0010 % fields: L, S and R. The matrices L (mxk) and R (nxk) have orthonormal
0011 % columns, while the matrix S (kxk) is full rank, such that X = L*S*R'.
0012 %
0013 % Tangent vectors are represented as a structure with three fields: L, S
0014 % and R.
0015 %
0016 % Please cite the Manopt paper as well as the research paper:
0017 %     @InProceedings{mishra2014r3mc,
0018 %       Title        = {{R3MC}: A {R}iemannian three-factor algorithm for low-rank matrix completion},
0019 %       Author       = {Mishra, B. and Sepulchre, R.},
0020 %       Booktitle    = {{53rd IEEE Conference on Decision and Control}},
0021 %       Year         = {2014},
0022 %       Organization = {{IEEE CDC}}
0023 %     }
0024 %
0025 % The output M is a structure of function handles (M.inner, M.proj, M.retr, ...).
0026 % See also: fixedrankfactory_3factors fixedrankfactory_2factors_preconditioned
0027 
0028 % This file is part of Manopt: www.manopt.org.
0029 % Original author: Bamdev Mishra, Dec. 30, 2012.
0030 % Contributors:
0031 % Change log:
0032 %
0033 %    Apr.  4, 2015 (BM):
0034 %        Cosmetic changes including avoiding storing the inverse of a kxk matrix.
0035 %
0036 %    Apr. 18, 2018 (NB):
0037 %        Removed lyap dependency.
0038 %
0039 %    Sep. 6, 2018 (NB):
0040 %        Removed M.exp() as it was not implemented.
0041 
0042     % Human-readable description of the manifold, with the sizes m, n and the rank k filled in.
0043     M.name = @() sprintf('LSR'' (tuned for least square problems) quotient manifold of %dx%d matrices of rank %d', m, n, k);
0044     
0045     M.dim = @() (m+n-k)*k; % Dimension reported for the manifold.
0046     
0047     % Cache the products S*S' (field SSt) and S'*S (field StS) in the point X. They are used in
0048     % the inner product and pretty much everywhere else.
0049     function X = prepare(X)
0050         if ~all(isfield(X,{'StS','SSt'}) == 1) % Recompute unless both cached fields already exist.
0051             X.SSt = X.S*X.S';
0052             X.StS = X.S'*X.S;
0053         end
0054     end
0055     
0056     % The choice of metric is motivated by symmetry and tuned to least-squares
0057     % objective functions.
0058     M.inner = @iproduct;
0059     function ip = iproduct(X, eta, zeta)
0060         X = prepare(X); % Make sure X.SSt and X.StS are available.
0061         % The L and R terms are weighted by S*S' and S'*S; the S term is unweighted.
0062         ip = trace(X.SSt*(eta.L'*zeta.L)) + trace(X.StS*(eta.R'*zeta.R)) ...
0063             + trace(eta.S'*zeta.S);
0064     end
0065     
0066     M.norm = @(X, eta) sqrt(M.inner(X, eta, eta)); % Norm induced by M.inner.
0067     % The distance is not implemented: calling M.dist raises an error.
0068     M.dist = @(x, y) error('fixedrankfactory_3factors_preconditioned.dist not implemented yet.');
0069     
0070     M.typicaldist = @() 10*k;
0071     % Skew-symmetric and symmetric parts of a square matrix.
0072     skew = @(X) .5*(X-X');
0073     symm = @(X) .5*(X+X');
0074     
0075     M.egrad2rgrad = @egrad2rgrad;
0076     function rgrad = egrad2rgrad(X, egrad)
0077         X = prepare(X);
0078         % egrad is a structure with fields L, S and R; so is the returned rgrad.
0079         SSL = X.SSt;
0080         ASL = 2*symm(SSL*(egrad.S*X.S'));
0081         
0082         SSR = X.StS;
0083         ASR = 2*symm(SSR*(egrad.S'*X.S));
0084         % Solve SSL*BL + BL*SSL = ASL and SSR*BR + BR*SSR = ASR.
0085         [BL, BR] = tangent_space_lyap(X.S, ASL, ASR); % Computes the solution without calling Matlab's lyap.
0086         % The L and R components are right-divided by S*S' and S'*S; the S component is unchanged.
0087         rgrad.L = (egrad.L - X.L*BL)/X.SSt;
0088         rgrad.R = (egrad.R - X.R*BR)/X.StS;
0089         rgrad.S = egrad.S;
0090         
0091         % Debug (disabled): alternative solve with lyap, and residual norms.
0092         %         BL1 = lyap(SSL, -ASL); % Alternate way
0093         %         BR1 = lyap(SSR, -ASR);
0094         %         norm(skew(X.SSt*(rgrad.L'*X.L) + rgrad.S*X.S'), 'fro')
0095         %         norm(skew(X.StS*(rgrad.R'*X.R) - X.S'*rgrad.S), 'fro')
0096         
0097     end
0098     
0099     
0100     
0101     M.ehess2rhess = @ehess2rhess;
0102     function Hess = ehess2rhess(X, egrad, ehess, eta)
0103         X = prepare(X);
0104         % egrad, ehess and eta are structures with fields L, S and R; so is the returned Hess.
0105         % Riemannian gradient (same computation as in egrad2rgrad).
0106         SSL = X.SSt;
0107         ASL = 2*symm(SSL*(egrad.S*X.S'));
0108         SSR = X.StS;
0109         ASR = 2*symm(SSR*(egrad.S'*X.S));
0110         [BL, BR] = tangent_space_lyap(X.S, ASL, ASR);
0111         
0112         rgrad.L = (egrad.L - X.L*BL)/X.SSt;
0113         rgrad.R = (egrad.R - X.R*BR)/X.StS;
0114         rgrad.S = egrad.S;
0115         
0116         % Directional derivative of the Riemannian gradient along eta.
0117         ASLdot = 2*symm((2*symm(X.S*eta.S')*(egrad.S*X.S')) + X.SSt*(ehess.S*X.S' + egrad.S*eta.S')) - 4*symm(symm(eta.S*X.S')*BL);
0118         ASRdot = 2*symm((2*symm(X.S'*eta.S)*(egrad.S'*X.S)) + X.StS*(ehess.S'*X.S + egrad.S'*eta.S)) - 4*symm(symm(eta.S'*X.S)*BR);
0119         % Alternative (disabled) computation of BLdot and BRdot with Matlab's lyap:
0120         %         SSLdot = X.SSt;
0121         %         SSRdot = X.StS;
0122         %         BLdot = lyap(SSLdot, -ASLdot);
0123         %         BRdot = lyap(SSRdot, -ASRdot);
0124         
0125         [BLdot, BRdot] = tangent_space_lyap(X.S, ASLdot, ASRdot);
0126         
0127         Hess.L = (ehess.L - eta.L*BL - X.L*BLdot - 2*rgrad.L*symm(eta.S*X.S'))/X.SSt;
0128         Hess.R = (ehess.R - eta.R*BR - X.R*BRdot - 2*rgrad.R*symm(eta.S'*X.S))/X.StS;
0129         Hess.S = ehess.S;
0130         
0131         
0132         
0133         % BM: Up to this point, everything seems correct.
0134         % We still need a correction factor for the non-constant metric
0135         % that is imposed.
0136         % The computation of the correction factor follows from the Koszul formula.
0137         % This corresponds to the Riemannian connection in the Euclidean space with the
0138         % scaled metric.
0139         Hess.L = Hess.L + (eta.L*symm(rgrad.S*X.S') + rgrad.L*symm(eta.S*X.S'))/X.SSt;
0140         Hess.R = Hess.R + (eta.R*symm(rgrad.S'*X.S) + rgrad.R*symm(eta.S'*X.S))/X.StS;
0141         Hess.S = Hess.S - symm(rgrad.L'*eta.L)*X.S - X.S*symm(rgrad.R'*eta.R);
0142         
0143         % The Riemannian connection on the quotient space is the
0144         % projection of the Riemannian connection in the ambient space onto the tangent space of the total space and
0145         % then onto the horizontal space.
0146         % This is accomplished by the following operation.
0147         Hess = M.proj(X, Hess);
0148         
0149         % Debug (disabled): residual norms of the result.
0150         %         norm(skew(X.SSt*(Hess.L'*X.L) + Hess.S*X.S'))
0151         %         norm(skew(X.StS*(Hess.R'*X.R) - X.S'*Hess.S))
0152         
0153     end
0154     
0155     
0156     
0157     
0158     M.proj = @projection;
0159     function etaproj = projection(X, eta)
0160         X = prepare(X);
0161         % eta is a structure with fields L, S and R; so is the returned etaproj.
0162         % First, projection onto the tangent space of the total space.
0163         SSL = X.SSt;
0164         ASL = 2*symm(X.SSt*(X.L'*eta.L)*X.SSt);
0165         BL = lyapunov_symmetric(SSL, ASL);
0166         eta.L = eta.L - X.L*(BL/X.SSt);
0167         
0168         SSR = X.StS;
0169         ASR = 2*symm(X.StS*(X.R'*eta.R)*X.StS);
0170         BR = lyapunov_symmetric(SSR, ASR);
0171         eta.R = eta.R - X.R*(BR/X.StS);
0172         
0173         % Then, projection onto the horizontal space.
0174         PU = skew((X.L'*eta.L)*X.SSt) + skew(X.S*eta.S');
0175         PV = skew((X.R'*eta.R)*X.StS)  + skew(X.S'*eta.S);
0176         [Omega1, Omega2] = coupled_lyap(X.S, PU, PV); % Coupled Lyapunov solve for Omega1 and Omega2.
0177         %         norm(2*skew(Omega1*X.SSt) - PU -(X.S*Omega2*X.S'),'fro' )
0178         %         norm(2*skew(Omega2*X.StS) - PV -(X.S'*Omega1*X.S),'fro' )
0179         %
0180         % Remove the component built from Omega1 and Omega2 from each factor.
0181         etaproj.L = eta.L - (X.L*Omega1);
0182         etaproj.S = eta.S - (X.S*Omega2 - Omega1*X.S) ;
0183         etaproj.R = eta.R - (X.R*Omega2);
0184         
0185         
0186         % Debug (disabled): residual norms of the projected vector.
0187         %         norm(skew(X.SSt*(etaproj.L'*X.L) + etaproj.S*X.S'))
0188         %         norm(skew(X.StS*(etaproj.R'*X.R) - X.S'*etaproj.S))
0189         %
0190         %         norm(skew(X.SSt*(etaproj.L'*X.L) - X.S*etaproj.S'))
0191         %         norm(skew(X.StS*(etaproj.R'*X.R) + etaproj.S'*X.S))
0192         
0193     end
0194     
0195     
0196     M.tangent = M.proj; % Tangentialization uses the same operation as the projection.
0197     M.tangent2ambient = @(X, eta) eta; % Tangent vectors are returned unchanged.
0198     
0199     M.retr = @retraction;
0200     function Y = retraction(X, eta, t)
0201         if nargin < 3 % The step size t is optional and defaults to 1.
0202             t = 1.0;
0203         end
0204         % S takes a plain step along eta.S; the L and R steps are passed through uf.
0205         Y.S = (X.S + t*eta.S);
0206         Y.L = uf((X.L + t*eta.L));
0207         Y.R = uf((X.R + t*eta.R));
0208         % Cache S*S' and S'*S for the new point.
0209         Y = prepare(Y);
0210     end
0211     
0212     
0213     M.hash = @(X) ['z' hashmd5([X.L(:) ; X.S(:) ; X.R(:)])]; % 'z' followed by the hashmd5 of the stacked entries of L, S and R.
0214     
0215     M.rand = @random;
0216     % Factors L and R live on Stiefel manifolds, hence we reuse
0217     % the random generators of the corresponding Stiefel factories.
0218     stiefelm = stiefelfactory(m, k);
0219     stiefeln = stiefelfactory(n, k);
0220     function X = random()
0221         X.L = stiefelm.rand();
0222         X.R = stiefeln.rand();
0223         X.S = diag(1+rand(k, 1)); % Diagonal S with entries between 1 and 2.
0224         
0225         X = prepare(X);
0226     end
0227     
0228     M.randvec = @randomvec;
0229     function eta = randomvec(X)
0230         % A random vector at X: Gaussian factors, projected, then scaled to unit M.norm.
0231         eta.L = randn(m, k);
0232         eta.R = randn(n, k);
0233         eta.S = randn(k, k);
0234         eta = projection(X, eta);
0235         nrm = M.norm(X, eta);
0236         eta.L = eta.L / nrm;
0237         eta.R = eta.R / nrm;
0238         eta.S = eta.S / nrm;
0239     end
0240     
0241     M.lincomb = @lincomb;
0242     % Zero tangent vector: all three factors are zero matrices of matching sizes.
0243     M.zerovec = @(X) struct('L', zeros(m, k), 'S', zeros(k, k), ...
0244         'R', zeros(n, k));
0245     % Transport of d to x2 is done by applying the projection at x2 (x1 is not used).
0246     M.transp = @(x1, x2, d) projection(x2, d);
0247     % vec stacks the entries of L, S and R into one column vector; mat splits it back.
0248     % vec and mat are not isometries, because of the unusual inner metric.
0249     M.vec = @(X, U) [U.L(:) ; U.S(:); U.R(:)];
0250     M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ...
0251         'S', reshape(u((m*k+1): m*k + k*k), k, k), ...
0252         'R', reshape(u((m*k+ k*k + 1):end), n, k));
0253     M.vecmatareisometries = @() false;
0254     
0255 end
0256 
0257 % Linear combination of tangent vectors: a1*d1, or a1*d1 + a2*d2.
0258 function d = lincomb(x, a1, d1, a2, d2) %#ok<INUSL>
0259     % The first argument x is accepted but not used; dispatch on the argument count.
0260     switch nargin
0261         case 3
0262             d.L = a1*d1.L;
0263             d.R = a1*d1.R;
0264             d.S = a1*d1.S;
0265         case 5
0266             d.L = a1*d1.L + a2*d2.L;
0267             d.R = a1*d1.R + a2*d2.R;
0268             d.S = a1*d1.S + a2*d2.S;
0269         otherwise
0270             error('Bad use of fixedrankfactory_3factors_preconditioned.lincomb.');
0271     end
0272 end
0273 
0274 function A = uf(A) % Replace A by the product of the left and right singular-vector factors of its economy-size SVD.
0275     [leftVecs, singVals, rightVecs] = svd(A, 0); %#ok<ASGLU> singVals is intentionally unused
0276     A = leftVecs*rightVecs';
0277 end
0278 
0279 function[BU, BV] = tangent_space_lyap(R, E, F)
0280     % We intend to solve the two linear systems  RR^T  BU + BU RR^T  = E
0281     %                                            R^T R BV + BV R^T R = F
0282     % for BU and BV.
0283     % With the SVD R = U*Sigma*V', both systems decouple entrywise in the bases U and V.
0284     % This can be solved using two calls to Matlab's lyap.
0285     % However, we can still have a more efficient implementation
0286     % that does not require the full functionality of Matlab's lyap.
0287     
0288     [U, Sigma, V] = svd(R);
0289     E_mod = U'*E*U; % E expressed in the basis U.
0290     F_mod = V'*F*V; % F expressed in the basis V.
0291     b1 = E_mod(:);
0292     b2 = F_mod(:);
0293     
0294     r = size(Sigma, 1);
0295     sig = diag(Sigma); % all the singular values in a vector
0296     sig1 = sig*ones(1, r); % columns repeat
0297     sig1t = sig1'; % rows repeat
0298     s1 = sig1(:);
0299     s2 = sig1t(:);
0300     
0301     % Entrywise coefficients sig(i)^2 + sig(j)^2 of the decoupled systems
0302     a =  s1.^2 + s2.^2; % a column vector
0303     
0304     % Solve the linear system of equations (entrywise division)
0305     cu = b1./a; %a.\b1;
0306     cv = b2./a; %a.\b2;
0307     
0308     % Matricize
0309     CU = reshape(cu, r, r);
0310     CV = reshape(cv, r, r);
0311     
0312     % Do the similarity transforms back to the original bases
0313     BU = U*CU*U';
0314     BV = V*CV*V';
0315     
0316     % %% Debug
0317     %
0318     % norm(R*R'*BU + BU*R*R' - E, 'fro');
0319     % norm((Sigma.^2)*CU + CU*(Sigma.^2) - E_mod, 'fro');
0320     % norm(a.*cu - b1, 'fro');
0321     %
0322     % norm(R'*R*BV + BV*R'*R - F, 'fro');
0323     %
0324     % BU1 = lyap(R*R', - E);
0325     % norm(R*R'*BU1 + BU1*R*R' - E, 'fro');
0326     %
0327     % BV1 = lyap(R'*R, - F);
0328     % norm(R'*R*BV1 + BV1*R'*R - F, 'fro');
0329     %
0330     % % as accurate as the lyap
0331     % norm(BU - BU1, 'fro')
0332     % norm(BV - BV1, 'fro')
0333 end
0334 
0335 
0336 
0337 function[Omega1, Omega2] = coupled_lyap(R, E, F)
0338     % We intend to solve the coupled system of Lyapunov equations
0339     %
0340     % RR^T Omega1 + Omega1 RR^T  - R Omega2 R^T = E
0341     % R^T R Omega2 + Omega2 R^T R  - R^T Omega1 R = F,
0342     %
0343     % for Omega1 and Omega2, both are skew symmetric matrices.
0344     %
0345     % Below is an efficient implementation based on the SVD R = U*Sigma*V'.
0346     
0347     [U, Sigma, V] = svd(R);
0348     E_mod = U'*E*U; % E expressed in the basis U.
0349     F_mod = V'*F*V; % F expressed in the basis V.
0350     b1 = E_mod(:);
0351     b2 = F_mod(:);
0352     
0353     r = size(Sigma, 1);
0354     sig = diag(Sigma); % All the singular values in a vector
0355     sig1 = sig*ones(1, r); % Columns repeat
0356     sig1t = sig1'; % Rows repeat
0357     s1 = sig1(:);
0358     s2 = sig1t(:);
0359     
0360     % Entrywise coefficients: a = sig(i)^2 + sig(j)^2 and c = sig(i)*sig(j)
0361     a =  s1.^2 + s2.^2; % A column vector
0362     c = s1.*s2;
0363     
0364     % Solve directly using the formula (entrywise form of the commented matrix version)
0365     % A = diag(a);
0366     % C = diag(c);
0367     % Y1_sol = (A*(C\A) - C) \ (b2 + A*(C\b1));
0368     % Y2_sol = A\(b2 + C*Y1_sol);
0369     
0370     Y1_sol = (b2 + (a./c).*b1) ./ ((a.^2)./c - c);
0371     Y2_sol = (b2 + c.*Y1_sol)./a;
0372     
0373     % Matricize
0374     Omega1 = reshape(Y1_sol, r, r);
0375     Omega2 = reshape(Y2_sol, r, r);
0376     
0377     % Do the similarity transforms back to the original bases
0378     Omega1 = U*Omega1*U';
0379     Omega2 = V*Omega2*V';
0380     
0381     % %% Debug: whether we have the right solution.
0382     % norm(R*R'*Omega1 + Omega1*R*R'  - R*Omega2*R' - E, 'fro')
0383     % norm(R'*R*Omega2 + Omega2*R'*R  - R'*Omega1*R - F, 'fro')
0384 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005