
sympositivedefinitesimplexcomplexfactory

PURPOSE

Manifold of k-tuples of n-by-n Hermitian positive definite matrices

SYNOPSIS

function M = sympositivedefinitesimplexcomplexfactory(n, k)

DESCRIPTION

 Manifold of k-tuples of n-by-n Hermitian positive definite matrices whose sum
 is the identity matrix, endowed with the bi-invariant geometry.

 function M = sympositivedefinitesimplexcomplexfactory(n, k)

 Given X1, X2, ... Xk Hermitian positive definite matrices, the constraint
 tackled is
 X1 + X2 + ... = I.

 The Riemannian structure enforced on the manifold
 M := {(X1, X2, ..., Xk) : X1 + X2 + ... + Xk = I} is that of a Riemannian submanifold
 of the total space, defined as the k-fold Cartesian product of the Riemannian manifold
 of n-by-n Hermitian positive definite matrices endowed with the bi-invariant metric.
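As a rough illustration of the metric described above, the following sketch evaluates the inner product sum_k trace((Xk\eta_k)*(Xk\zeta_k)) in plain MATLAB, without Manopt. The names n, k, X, eta, zeta and the helper herm are illustrative choices, not part of the factory. The blocks of X are arbitrary Hermitian positive definite matrices that do not satisfy the sum constraint, since only the metric is shown.

```matlab
% Sketch only: bi-invariant inner product on k Hermitian positive definite blocks.
n = 3; k = 4;
herm = @(A) 0.5*(A + A');               % Hermitian part (hypothetical helper)

X = zeros(n, n, k); eta = zeros(n, n, k); zeta = zeros(n, n, k);
for kk = 1:k
    B = randn(n) + 1i*randn(n);
    X(:,:,kk)    = B*B' + n*eye(n);     % Hermitian positive definite block
    eta(:,:,kk)  = herm(randn(n) + 1i*randn(n));
    zeta(:,:,kk) = herm(randn(n) + 1i*randn(n));
end

ip = 0;
for kk = 1:k
    Xk = X(:,:,kk);
    % trace((Xk\eta_k)*(Xk\zeta_k)); real() discards round-off in the imaginary part.
    ip = ip + real(trace((Xk\eta(:,:,kk)) * (Xk\zeta(:,:,kk))));
end
disp(ip);
```

The factory itself avoids forming the matrix products just to take a trace; the explicit trace here is only meant to keep the formula readable.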

 A point X on the manifold is represented as a multidimensional array
 of size n-by-n-by-k. Each n-by-n matrix is Hermitian positive definite.
 Tangent vectors are represented as n-by-n-by-k multidimensional arrays, where
 each n-by-n matrix is Hermitian.
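The representation above can be sketched in plain MATLAB as follows. This is a minimal illustration that mirrors the normalization used by M.rand in the source listing; the names n, k, X, S, Shalf and herm are illustrative, and the random distribution is an arbitrary choice.

```matlab
% Sketch only: build a point on the manifold and check its representation.
n = 3; k = 4;
herm = @(A) 0.5*(A + A');               % Hermitian part (hypothetical helper)

X = zeros(n, n, k);
for kk = 1:k
    B = randn(n) + 1i*randn(n);
    X(:,:,kk) = B*B' + eye(n);          % Hermitian positive definite block
end

% Congruence by S^(-1/2), with S the sum of the blocks, makes the sum equal I.
S = herm(sum(X, 3));
Shalf = sqrtm(S);
for kk = 1:k
    X(:,:,kk) = herm((Shalf \ X(:,:,kk)) / Shalf);
end

sum_error = norm(sum(X, 3) - eye(n), 'fro');             % expected near zero
min_eig = inf;
for kk = 1:k
    min_eig = min(min_eig, min(real(eig(X(:,:,kk)))));   % expected positive
end
fprintf('sum error: %.2e, smallest eigenvalue: %.2e\n', sum_error, min_eig);
```

A congruence transformation preserves positive definiteness, which is why each block stays Hermitian positive definite after the normalization.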

 The embedding space is the k Cartesian product of complex matrices of size
 n-by-n (Hermitian not required). The Euclidean gradient and Hessian expressions 
 needed for egrad2rgrad and ehess2rhess are in the embedding space endowed with the 
 usual metric for the complex plane identified with R^2.
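The conversion from a Euclidean gradient to a Riemannian gradient can be sketched as below. The factory solves the underlying linear system with pcg on a matrix-free operator; this sketch swaps in a dense Kronecker-product solve, which is only reasonable for small n. All names (G, A, Lambda, rgrad, herm) are illustrative assumptions.

```matlab
% Sketch only: Euclidean-to-Riemannian gradient conversion for small n.
n = 3; k = 4;
herm = @(A) 0.5*(A + A');               % Hermitian part (hypothetical helper)

% A point on the manifold (blocks sum to the identity), built as in M.rand.
X = zeros(n, n, k);
for kk = 1:k
    B = randn(n) + 1i*randn(n);
    X(:,:,kk) = B*B' + eye(n);
end
Shalf = sqrtm(herm(sum(X, 3)));
for kk = 1:k
    X(:,:,kk) = herm((Shalf \ X(:,:,kk)) / Shalf);
end

% An arbitrary Euclidean gradient in the embedding space (not Hermitian).
egrad = randn(n, n, k) + 1i*randn(n, n, k);

% Scale each block by the metric: Xk * herm(egrad_k) * Xk.
G = zeros(n, n, k);
A = zeros(n^2);
for kk = 1:k
    Xk = X(:,:,kk);
    G(:,:,kk) = Xk * herm(egrad(:,:,kk)) * Xk;
    A = A + kron(Xk.', Xk);             % vec(Xk*L*Xk) = kron(Xk.', Xk)*vec(L)
end

% Solve sum_k Xk*Lambda*Xk = -sum_k G_k for the Hermitian multiplier Lambda.
rhs = -sum(G, 3);
Lambda = herm(reshape(A \ rhs(:), n, n));

rgrad = zeros(n, n, k);
for kk = 1:k
    Xk = X(:,:,kk);
    rgrad(:,:,kk) = G(:,:,kk) + Xk * Lambda * Xk;
end
tangency_error = norm(sum(rgrad, 3), 'fro');   % expected near zero
disp(tangency_error);
```

The blocks of rgrad are Hermitian and sum to zero, which is the tangency condition obtained by differentiating X1 + ... + Xk = I.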

 E = (C^(nxn))^k is the embedding space: we have the obvious representation of points 
 there as 3D arrays of size nxnxk. It is equipped with the standard Euclidean metric.

 P = {X in C^(nxn) : X = X' and X positive definite} is a submanifold of C^(nxn). 
 We turn it into a Riemannian manifold (but not a Riemannian submanifold) by equipping
 it with the bi-invariant metric.

 M = {X in P^k : X_1 + ... + X_k = I} is the manifold we care about here: it is 
 a Riemannian submanifold of P^k, hence it is also a submanifold (but not a Riemannian
 submanifold) of E -- our embedding space.


 Please cite the Manopt paper as well as the research paper:

     @techreport{mishra2019riemannian,
       title={Riemannian optimization on the simplex of positive definite matrices},
       author={Mishra, B. and Kasai, H. and Jawanpuria, P.},
       institution={arXiv preprint arXiv:1906.10436},
       year={2019}
     }

 See also sympositivedefinitesimplexfactory multinomialfactory sympositivedefinitefactory
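A possible way to use the factory with a Manopt solver is sketched below. It assumes Manopt is installed and on the MATLAB path. The cost function, a least-squares fit to a hypothetical target A, is made up for illustration and is not taken from the paper cited above.

```matlab
% Sketch only: assumes Manopt is installed and on the MATLAB path.
n = 3; k = 4;
M = sympositivedefinitesimplexcomplexfactory(n, k);

% Hypothetical target: another point on the manifold.
A = M.rand();

problem.M = M;
problem.cost  = @(X) 0.5 * norm(X(:) - A(:))^2;
problem.egrad = @(X) X - A;             % Euclidean gradient in the embedding space E
problem.ehess = @(X, U) U;              % Euclidean Hessian applied to U

% Optional numerical sanity checks of the gradient and Hessian:
% checkgradient(problem);
% checkhessian(problem);

options.verbosity = 0;
[Xopt, cost_opt] = trustregions(problem, [], options);
fprintf('final cost: %.3e\n', cost_opt);
```

Passing an empty initial point lets the solver draw one from M.rand. The Euclidean derivatives are supplied in the embedding space and converted by the factory's egrad2rgrad and ehess2rhess.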

SOURCE CODE

0001 function M = sympositivedefinitesimplexcomplexfactory(n, k)
0002 % Manifold of k-tuples of n-by-n Hermitian positive definite matrices whose sum
0003 % is the identity matrix, endowed with the bi-invariant geometry.
0004 %
0005 % function M = sympositivedefinitesimplexcomplexfactory(n, k)
0006 %
0007 % Given X1, X2, ... Xk Hermitian positive definite matrices, the constraint
0008 % tackled is
0009 % X1 + X2 + ... = I.
0010 %
0011 % The Riemannian structure enforced on the manifold
0012 % M := {(X1, X2, ..., Xk) : X1 + X2 + ... + Xk = I} is that of a Riemannian submanifold
0013 % of the total space, defined as the k-fold Cartesian product of the Riemannian manifold
0014 % of n-by-n Hermitian positive definite matrices endowed with the bi-invariant metric.
0015 %
0016 % A point X on the manifold is represented as a multidimensional array
0017 % of size n-by-n-by-k. Each n-by-n matrix is Hermitian positive definite.
0018 % Tangent vectors are represented as n-by-n-by-k multidimensional arrays, where
0019 % each n-by-n matrix is Hermitian.
0020 %
0021 % The embedding space is the k Cartesian product of complex matrices of size
0022 % n-by-n (Hermitian not required). The Euclidean gradient and Hessian expressions
0023 % needed for egrad2rgrad and ehess2rhess are in the embedding space endowed with the
0024 % usual metric for the complex plane identified with R^2.
0025 %
0026 % E = (C^(nxn))^k is the embedding space: we have the obvious representation of points
0027 % there as 3D arrays of size nxnxk. It is equipped with the standard Euclidean metric.
0028 %
0029 % P = {X in C^(nxn) : X = X' and X positive definite} is a submanifold of C^(nxn).
0030 % We turn it into a Riemannian manifold (but not a Riemannian submanifold) by equipping
0031 % it with the bi-invariant metric.
0032 %
0033 % M = {X in P^k : X_1 + ... + X_k = I} is the manifold we care about here: it is
0034 % a Riemannian submanifold of P^k, hence it is also a submanifold (but not a Riemannian
0035 % submanifold) of E -- our embedding space.
0036 %
0037 %
0038 % Please cite the Manopt paper as well as the research paper:
0039 %
0040 %     @techreport{mishra2019riemannian,
0041 %       title={Riemannian optimization on the simplex of positive definite matrices},
0042 %       author={Mishra, B. and Kasai, H. and Jawanpuria, P.},
0043 %       institution={arXiv preprint arXiv:1906.10436},
0044 %       year={2019}
0045 %     }
0046 %
0047 % See also sympositivedefinitesimplexfactory multinomialfactory sympositivedefinitefactory
0048     
0049     % This file is part of Manopt: www.manopt.org.
0050     % Original author: Bamdev Mishra, September 18, 2019.
0051     % Contributors: NB
0052     % Change log:   Comments updated, 16 Dec 2019
0053     %               Removed typos in Hessian expression, 01 Nov, 2021
0054     
0055     symm = @(X) .5*(X+X');
0056     
0057     M.name = @() sprintf('%d complex Hermitian positive definite matrices of size %dx%d such that their sum is the identity matrix.', k, n, n);
0058     
0059     M.dim = @() (k-1)*n^2; % Hermitian n-by-n matrices have real dimension n^2
0060     
0061     % Helpers to avoid computing full matrices simply to extract their trace
0062     vec     = @(A) A(:);
0063     trinner = @(A, B) real(vec(A')'*vec(B));  % = trace(A*B)
0064     trnorm  = @(A) sqrt((trinner(A, A))); % = sqrt(trace(A^2))
0065     
0066     
0067     % Choice of the metric on the orthonormal space is motivated by the
0068     % symmetry present in the space. The metric on the positive definite
0069     % cone is its natural bi-invariant metric.
0070     % The result is equal to: trace( (X\eta) * (X\zeta) )
0071     M.inner = @innerproduct;
0072     function iproduct = innerproduct(X, eta, zeta)
0073         iproduct = 0;
0074         for kk = 1 : k
0075             iproduct = iproduct + (trinner(X(:,:,kk)\eta(:,:,kk), X(:,:,kk)\zeta(:,:,kk))); % BM okay
0076         end
0077     end
0078     
0079     % Notice that X\eta is *not* symmetric in general.
0080     % The result is equal to: sqrt(trace((X\eta)^2))
0081     % There should be no need to take the real part, but rounding errors
0082     % may cause a small imaginary part to appear, so we discard it.
0083     M.norm = @innernorm;
0084     function inorm = innernorm(X, eta)
0085         inorm = 0;
0086         for kk = 1:k
0087             inorm = inorm + (trnorm(X(:,:,kk)\eta(:,:,kk)))^2; % BM okay
0088         end
0089         inorm = sqrt(inorm);
0090     end
0091     
0092     %     % Same here: X\Y is not symmetric in general.
0093     %     % Same remark about taking the real part.
0094     %     M.dist = @innerdistance;
0095     %     function idistance = innerdistance(X, Y)
0096     %       idistance = 0;
0097     %       for kk = 1:k
0098     %           idistance = idistance + real(trnorm(real(logm(X(:,:,kk)\Y(:,:,kk))))); % BM okay, but need not be correct.
0099     %       end
0100     %     end
0101     
0102     M.typicaldist = @() sqrt(k*n*(n+1)); % BM: to be looked into.
0103     
0104     
0105     M.egrad2rgrad = @egrad2rgrad;
0106     function rgrad = egrad2rgrad(X, egrad)
0107         egradscaled = nan(size(egrad));
0108         for kk = 1:k
0109             egradscaled(:,:,kk) = X(:,:,kk)*symm(egrad(:,:,kk))*X(:,:,kk);
0110         end
0111         
0112         % Project onto the set X1dot + X2dot + ... = 0.
0113         % That is rgrad = Xk*egradk*Xk + Xk*Lambdasol*Xk
0114         rgrad = M.proj(X, egradscaled);
0115         
0116         %   % Debug
0117         %   norm(sum(rgrad,3), 'fro') % BM: this should be zero.
0118     end
0119     
0120     
0121     M.ehess2rhess = @ehess2rhess;
0122     function Hess = ehess2rhess(X, egrad, ehess, eta)
0123 
0124         Hess = nan(size(X));
0125         
0126         egradscaled = nan(size(egrad));
0127         egradscaleddot = nan(size(egrad));
0128         for kk = 1:k
0129             egradk = symm(egrad(:,:,kk));
0130             ehessk = symm(ehess(:,:,kk));
0131             Xk = X(:,:,kk);
0132             etak = eta(:,:,kk);
0133 
0134             egradscaled(:,:,kk) = Xk*egradk*Xk;
0135             egradscaleddot(:,:,kk) = Xk*ehessk*Xk + 2*symm(etak*egradk*Xk);
0136         end
0137 
0138         % Compute Lambdasol
0139         RHS = - sum(egradscaled,3);
0140         [Lambdasol] = mylinearsolve(X, RHS);
0141 
0142 
0143         % Compute Lambdasoldot
0144         temp = nan(size(egrad));
0145         for kk = 1:k
0146             Xk = X(:,:,kk);
0147             etak = eta(:,:,kk);
0148 
0149             temp(:,:,kk) = 2*symm(etak*Lambdasol*Xk);
0150         end
0151         RHSdot = - sum(egradscaleddot,3) - sum(temp,3);
0152         [Lambdasoldot] = mylinearsolve(X, RHSdot);
0153 
0154 
0155         for kk = 1:k
0156             egradk = symm(egrad(:,:,kk));
0157             ehessk = symm(ehess(:,:,kk));
0158             Xk = X(:,:,kk);
0159             etak = eta(:,:,kk);
0160 
0161             % Directional derivatives of the Riemannian gradient
0162             % Note that the Riemannian gradient is Xk*egradk*Xk + Xk*Lambdasol*Xk.
0163             % rhessk = Xk*(ehessk + Lambdasoldot)*Xk + 2*symm(etak*(egradk + Lambdasol)*Xk);
0164             % rhessk = rhessk - symm(etak*(egradk + Lambdasol)*Xk);
0165             rhessk = Xk*(ehessk + Lambdasoldot)*Xk + symm(etak*(egradk + Lambdasol)*Xk);
0166 
0167             Hess(:,:,kk) = rhessk;
0168         end
0169         
0170         % Project onto the set X1dot + X2dot + ... = 0.
0171         Hess = M.proj(X, Hess);
0172 
0173 
0174         % Hess = nan(size(X));
0175         % for kk = 1 : k
0176         %     % % Directional derivatives of the Riemannian gradient
0177         %     % Hess(:,:,kk) = symm(X(:,:,kk)*symm(ehess(:,:,kk))*X(:,:,kk)) + 2*symm(eta(:,:,kk)*symm(egrad(:,:,kk))*X(:,:,kk));
0178             
0179         %     % % Correction factor for the non-constant metric
0180         %     % Hess(:,:,kk) = Hess(:,:,kk) - symm(eta(:,:,kk)*symm(egrad(:,:,kk))*X(:,:,kk));
0181             
0182         %     Hess(:,:,kk) = symm(X(:,:,kk)*symm(ehess(:,:,kk))*X(:,:,kk)) + symm(eta(:,:,kk)*symm(egrad(:,:,kk))*X(:,:,kk));
0183         % end
0184         
0185         % % Project onto the set X1dot + X2dot + ... = 0.
0186         % Hess = M.proj(X, Hess);
0187         
0188     end
0189     
0190     
0191     % Project onto the set X1dot + X2dot + ... = 0.
0192     M.proj = @innerprojection;
0193     function zeta = innerprojection(X, eta)
0194         % etareal = real(eta);
0195      %    etaimag = imag(eta);
0196      %    sumetareal = sum(etareal,3);
0197      %    sumetaimag = sum(etaimag,3);
0198 
0199         RHS = -sum(eta,3);
0200         
0201         Lambdasol = mylinearsolve(X, RHS);
0202         
0203         zeta = zeros(size(eta));
0204         for jj = 1 : k
0205             zeta(:,:,jj) = eta(:,:,jj) + (X(:,:,jj)*Lambdasol*X(:,:,jj));
0206         end
0207         
0208         % % Debug
0209         % eta;
0210         % sum(real(zeta),3)
0211         % sum(imag(zeta),3)
0212         % neta = eta - zeta;
0213         % innerproduct(X, zeta, neta) % This should be zero
0214     end
0215 
0216 
0217     function Lambdasol = mylinearsolve(X, RHS)
0218         % Solve the linear system.
0219         tol_omegax_pcg = 1e-8;
0220         max_iterations_pcg = 100;
0221         
0222         sumetareal = real(RHS);
0223         sumetaimag = imag(RHS);
0224         
0225         rhs = [sumetareal(:); sumetaimag(:)];
0226         
0227         [lambdasol, ~, ~, ~] = pcg(@compute_matrix_system, rhs, tol_omegax_pcg, max_iterations_pcg);
0228         
0229         lambdasolreal = lambdasol(1:n^2);
0230         lambdasolimag = lambdasol(n^2  + 1 : end);
0231         
0232         Lambdasol = symm(reshape(lambdasolreal, [n n])) + 1i*reshape(lambdasolimag,n,n);
0233         
0234         function lhslambda = compute_matrix_system(lambda)
0235             lambdareal = lambda(1:n^2);
0236             lambdaimag = lambda(n^2 + 1 : end);
0237             Lambda = symm(reshape(lambdareal, [n n])) + 1i*reshape(lambdaimag, n, n);
0238             lhsLambda = zeros(n,n);
0239             for kk = 1 : k
0240                 lhsLambda = lhsLambda + ((X(:,:,kk)*Lambda*X(:,:,kk)));
0241             end
0242             lhsLambdareal = real(lhsLambda);
0243             lhsLambdaimag = imag(lhsLambda);
0244             lhslambda = [lhsLambdareal(:); lhsLambdaimag(:)];
0245         end
0246     end
0247 
0248     
0249     M.tangent = M.proj;
0250     M.tangent2ambient = @(X, eta) eta;
0251     
0252     myeps = eps;
0253     
0254     M.retr = @retraction;
0255     function Y = retraction(X, eta, t) % BM okay
0256         if nargin < 3
0257             teta = eta;
0258         else
0259             teta = t*eta;
0260         end
0261         % The symm() call is mathematically unnecessary but numerically
0262         % necessary.
0263         Y = zeros(size(X));
0264         for kk=1:k
0265             % Second-order approximation of expm
0266             Y(:,:,kk) = symm(X(:,:,kk) + teta(:,:,kk) + .5*teta(:,:,kk)*((X(:,:,kk) + myeps*eye(n) )\teta(:,:,kk)));
0267         end
0268         Ysum = sum(Y, 3);
0269         Ysumsqrt = sqrtm(Ysum);
0270         for kk = 1:k
0271             Y(:,:,kk) = symm((Ysumsqrt\Y(:,:,kk))/Ysumsqrt);
0272         end
0273         % % Debug
0274         % norm(sum(Y, 3) - eye(n), 'fro') % This should be zero
0275     end
0276     
0277     M.exp = @exponential;
0278     function Y = exponential(X, eta, t)
0279         if nargin < 3
0280             t = 1.0;
0281         end
0282         Y = retraction(X, eta, t);
0283         warning('manopt:sympositivedefinitesimplexcomplexfactory:exp', ...
0284             ['Exponential for the Simplex ' ...
0285             'manifold not implemented yet. Used retraction instead.']);
0286     end
0287     
0288     M.hash = @(X) ['z' hashmd5([real(X(:)); imag(X(:))])];% BM okay
0289     
0290     % Generate a random point: k Hermitian positive definite matrices
0291     % summing to the identity, following a certain distribution. The
0292     % particular choice of a distribution is of course arbitrary, and
0293     % specific applications might require different ones.
0294     M.rand = @random;
0295     function X = random()
0296         X = nan(n,n,k);
0297         for kk = 1:k
0298             D = diag(1+rand(n, 1));
0299             [Q, ~] = qr(randn(n) +1i*randn(n)); % BM okay
0300             X(:,:,kk) = Q*D*Q';
0301         end
0302         Xsum = sum(X, 3);
0303         Xsumsqrt = sqrtm(Xsum);
0304         for kk = 1 : k
0305             X(:,:,kk) = symm((Xsumsqrt\X(:,:,kk))/Xsumsqrt); % To do
0306         end
0307     end
0308     
0309     % Generate a uniformly random unit-norm tangent vector at X.
0310     M.randvec = @randomvec;
0311     function eta = randomvec(X)
0312         eta = nan(size(X));
0313         for kk = 1:k
0314             eta(:,:,kk) = symm(randn(n,n) + 1i*randn(n, n)); % BM okay
0315         end
0316         eta = M.proj(X, eta); % To do
0317         nrm = M.norm(X, eta);
0318         eta = eta / nrm;
0319     end
0320     
0321     M.lincomb = @matrixlincomb; % BM okay
0322     
0323     M.zerovec = @(X) zeros(n,n,k); % BM okay
0324     
0325     % Poor man's vector transport: exploit the fact that all tangent spaces
0326     % are the set of symmetric matrices, so that the identity is a sort of
0327     % vector transport. It may perform poorly if the origin and target (X1
0328     % and X2) are far apart though. This should not be the case for typical
0329     % optimization algorithms, which perform small steps.
0330     M.transp = @(X1, X2, eta) M.proj(X2, eta);% To do
0331     
0332     % vec and mat are not isometries, because of the unusual inner metric.
0333     M.vec = @(X, U) [real(U(:)); imag(U(:))] ; % BM okay
0334     M.mat = @(X, u) reshape(u(1:(n*n*k)) + 1i*u((n*n*k+1):end), n, n, k); % BM okay
0335     M.vecmatareisometries = @() false;
0336     
0337 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005