grassmannfactory

PURPOSE

Returns a manifold struct to optimize over the space of vector subspaces.

SYNOPSIS

function M = grassmannfactory(n, p, k, gpuflag)

DESCRIPTION

 Returns a manifold struct to optimize over the space of vector subspaces.

 function M = grassmannfactory(n, p)
 function M = grassmannfactory(n, p, k)
 function M = grassmannfactory(n, p, k, gpuflag)

 Grassmann manifold: each point on this manifold is a collection of k
 vector subspaces of dimension p embedded in R^n.
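As a quick plausibility check of the dimension that the source reports through M.dim (k*p*(n-p)), the sketch below uses plain MATLAB, without Manopt, for k = 1. It assumes, as in this factory, that tangent vectors at X are represented by horizontal n x p matrices U (X'*U = 0), and it counts them as the rank of the horizontal projection U -> (I - X*X')*U, written as the matrix kron(eye(p), I - X*X') acting on U(:). The sizes n = 6, p = 2 are arbitrary choices for illustration.

```matlab
% Count the dimension of the horizontal space at X for Gr(n, p), k = 1.
n = 6; p = 2;                         % arbitrary example sizes
[X, ~] = qr(randn(n, p), 0);          % random n x p matrix with orthonormal columns
P = eye(n) - X*X';                    % orthogonal projector onto the complement of span(X)
% The map U -> P*U satisfies vec(P*U) = kron(eye(p), P)*vec(U)
Hmat = kron(eye(p), P);
horizontal_dim = rank(Hmat);          % dimension of {U : X'*U = 0}
fprintf('horizontal dimension = %d, p*(n-p) = %d\n', horizontal_dim, p*(n - p));
```

For k > 1 the slices are independent, so the count is multiplied by k.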

 The metric is obtained by making the Grassmannian a Riemannian quotient
 manifold of the Stiefel manifold, i.e., the manifold of n x p matrices with
 orthonormal columns, which is itself a Riemannian submanifold of the
 Euclidean space R^(n x p) endowed with the usual inner product
 <A, B> = trace(A'*B). In short: this is the standard metric on the
 Grassmannian, used in most applications.
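The quotient construction relies on the inner product of horizontal vectors not depending on the chosen representative: if X is replaced by X*Q for an orthogonal p x p matrix Q, the horizontal lifts U and V become U*Q and V*Q, and trace((U*Q)'*(V*Q)) = trace(U'*V). The plain-MATLAB sketch below (k = 1, arbitrary sizes) checks this numerically with the same projection and inner product formulas as M.proj and M.inner in the source; the handle names proj and inner are local to the example.

```matlab
n = 7; p = 3;                                          % arbitrary example sizes
[X, ~] = qr(randn(n, p), 0);                           % a representative of a point
proj  = @(X, U) U - X*(X'*U);                          % projection onto the horizontal space at X
inner = @(A, B) A(:)'*B(:);                            % equals trace(A'*B) for real matrices
U = proj(X, randn(n, p));                              % two horizontal vectors at X
V = proj(X, randn(n, p));
[Q, ~] = qr(randn(p));                                 % arbitrary p x p orthogonal matrix
horiz_err      = norm(X'*U, 'fro');                    % U is horizontal at X
horizQ_err     = norm((X*Q)'*(U*Q), 'fro');            % U*Q is horizontal at X*Q
trace_err      = abs(inner(U, V) - trace(U'*V));
invariance_err = abs(inner(U*Q, V*Q) - inner(U, V));  % metric does not depend on Q
```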
 
 This structure deals with matrices X of size n x p x k (or n x p if
 k = 1, which is the default) such that each n x p slice has orthonormal
 columns, i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p)
 for i = 1 : k if k > 1. Each n x p matrix is a numerical representation of
 the vector subspace spanned by its columns.
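Since only the column space matters, X and X*Q (Q orthogonal, p x p) represent the same point; a representation-independent way to compare two points is through the orthogonal projectors X*X'. The plain-MATLAB sketch below, with arbitrary sizes, builds a random n x p x k array in the format just described, checks the orthonormality condition slice by slice, and compares the projectors of two representatives of the first slice.

```matlab
n = 5; p = 2; k = 3;                             % arbitrary example sizes
X = zeros(n, p, k);
for kk = 1 : k
    [Qk, ~] = qr(randn(n, p), 0);                % thin QR: orthonormal columns
    X(:, :, kk) = Qk;
end
% Check X(:, :, i)' * X(:, :, i) = eye(p) for every slice
orth_err = 0;
for kk = 1 : k
    Xk = X(:, :, kk);
    orth_err = max(orth_err, norm(Xk'*Xk - eye(p), 'fro'));
end
% X1 and X1*Q are different matrices but span the same subspace
[Q, ~] = qr(randn(p));
X1  = X(:, :, 1);
X1Q = X1*Q;
matrix_gap    = norm(X1 - X1Q, 'fro');           % generally nonzero
projector_gap = norm(X1*X1' - X1Q*X1Q', 'fro');  % zero up to round-off
```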

 The retraction is based on a polar factorization (computed with a thin
 SVD) and is second order.
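The sketch below illustrates, in plain MATLAB for k = 1, the polar retraction used by M.retr (the polar factor u*v' of X + t*U, obtained from a thin SVD), the QR-based alternative mentioned in the source comments, and the inverse retraction formula of M.invretr, U = Y/(X'*Y) - X. Sizes and the step t are arbitrary. For a horizontal U one has X'*(X + t*U) = eye(p), from which Y/(X'*Y) = X + t*U follows, so the inverse retraction should recover t*U up to round-off.

```matlab
n = 8; p = 3; t = 0.7;                           % arbitrary example values
[X, ~] = qr(randn(n, p), 0);
U = randn(n, p);
U = U - X*(X'*U);                                % horizontal vector at X
% Polar retraction: polar factor of X + t*U via a thin SVD
[u, ~, v] = svd(X + t*U, 0);
Y = u*v';
% QR-based alternative: a different matrix with the same column space
[Yqr, ~] = qr(X + t*U, 0);
% Inverse retraction formula from the source
tU_back = Y / (X'*Y) - X;
orth_err = norm(Y'*Y - eye(p), 'fro');           % Y has orthonormal columns
span_err = norm(Y*Y' - Yqr*Yqr', 'fro');         % both retractions give the same subspace
inv_err  = norm(tU_back - t*U, 'fro');           % round trip recovers t*U
```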

 Set gpuflag = true to have points, tangent vectors and ambient vectors
 stored on the GPU. If so, computations can be done on the GPU directly.

 By default, k = 1 and gpuflag = false.

 See also: stiefelfactory grassmanncomplexfactory grassmanngeneralizedfactory
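The distance subfunction in the source computes, for each slice, the principal angles between span(X) and span(Y) as the arccosines of the singular values of X'*Y, and returns the 2-norm of all these angles. A small hand-checkable case in plain MATLAB (k = 1): in R^3, the planes span{e1, e2} and span{e1, cos(theta)*e2 + sin(theta)*e3} have principal angles 0 and theta, so their distance should be theta regardless of which orthonormal basis represents the second plane. The handle name grdist is local to this example.

```matlab
% Geodesic distance from principal angles, single slice (k = 1)
grdist = @(X, Y) sqrt(sum(real(acos(svd(X'*Y))).^2));
theta = 0.4;
X = [1 0; 0 1; 0 0];                       % span{e1, e2} in R^3
Y = [1 0; 0 cos(theta); 0 sin(theta)];     % span{e1, cos(theta)*e2 + sin(theta)*e3}
d = grdist(X, Y);                          % principal angles are 0 and theta
[Q, ~] = qr(randn(2));
d_rotated = grdist(X, Y*Q);                % another basis of the same plane
```

The real(...) guard mirrors the source: round-off can push a singular value slightly above 1, where acos returns a complex number.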

SOURCE CODE

function M = grassmannfactory(n, p, k, gpuflag)
% Returns a manifold struct to optimize over the space of vector subspaces.
%
% function M = grassmannfactory(n, p)
% function M = grassmannfactory(n, p, k)
% function M = grassmannfactory(n, p, k, gpuflag)
%
% Grassmann manifold: each point on this manifold is a collection of k
% vector subspaces of dimension p embedded in R^n.
%
% The metric is obtained by making the Grassmannian a Riemannian quotient
% manifold of the Stiefel manifold, i.e., the manifold of n x p matrices with
% orthonormal columns, which is itself a Riemannian submanifold of the
% Euclidean space R^(n x p) endowed with the usual inner product
% <A, B> = trace(A'*B). In short: this is the standard metric on the
% Grassmannian, used in most applications.
%
% This structure deals with matrices X of size n x p x k (or n x p if
% k = 1, which is the default) such that each n x p slice has orthonormal
% columns, i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p)
% for i = 1 : k if k > 1. Each n x p matrix is a numerical representation of
% the vector subspace spanned by its columns.
%
% The retraction is based on a polar factorization (computed with a thin
% SVD) and is second order.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, k = 1 and gpuflag = false.
%
% See also: stiefelfactory grassmanncomplexfactory grassmanngeneralizedfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%   March 22, 2013 (NB):
%       Implemented geodesic distance.
%
%   April 17, 2013 (NB):
%       Retraction changed to the polar decomposition, so that the vector
%       transport is now correct, in the sense that it is compatible with
%       the retraction, i.e., transporting a tangent vector G from U to V
%       where V = Retr(U, H) will give Z, and transporting GQ from UQ to VQ
%       will give ZQ: there is no dependence on the representation, which
%       is as it should be. Notice that the polar factorization requires an
%       SVD whereas the qfactor retraction requires a QR decomposition,
%       which is cheaper. Hence, if the retraction happens to be a
%       bottleneck in your application and you are not using vector
%       transports, you may want to replace the retraction with a qfactor.
%
%   July  4, 2013 (NB):
%       Added support for the logarithmic map 'log'.
%
%   July  5, 2013 (NB):
%       Added support for ehess2rhess.
%
%   June 24, 2014 (NB):
%       Small bug fix in the retraction, and added final
%       re-orthonormalization at the end of the exponential map. This
%       follows discussions on the forum where it appeared there is a
%       significant loss in orthonormality without that extra step. Also
%       changed the randvec function so that it now returns a globally
%       normalized vector, not a vector where each component is normalized
%       (this only matters if k>1).
%
%   July 8, 2018 (NB):
%       Inverse retraction implemented.
%
%   Aug. 3, 2018 (NB):
%       Added GPU support: just set gpuflag = true.
%
%   Apr. 19, 2019 (NB):
%       ehess2rhess: to ensure horizontality, it makes sense to project
%       last, same as in stiefelfactory.
%
%   May 3, 2019 (NB):
%       Added explanation about vector transport relation to retraction.
%
%   Nov. 13, 2019 (NB):
%       Added pairmean function.
%
%   Jan. 8, 2021 (NB)
%       Added tangent2ambient/tangent2ambient_is_identity pair.
%       Here, 'ambient' refers to the total space.

    assert(n >= p, ...
           ['The dimension n of the ambient space must be at least ' ...
            'the dimension p of the subspaces.']);
    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end
    
    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end
    
    if k == 1
        M.name = @() sprintf('Grassmann manifold Gr(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf('Multi Grassmann manifold Gr(%d, %d)^%d', ...
                             n, p, k);
    else
        error('k must be an integer no less than 1.');
    end
    
    M.dim = @() k*p*(n-p);
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d(:));
    
    M.dist = @distance;
    function d = distance(x, y)
        square_d = 0;
        XtY = multiprod(multitransp(x), y);
        for kk = 1 : k
            cos_princ_angle = svd(XtY(:, :, kk));
            % For x and y closer than ~sqrt(eps), this function is
            % inaccurate, and typically returns values close to ~sqrt(eps).
            square_d = square_d + sum(real(acos(cos_princ_angle)).^2);
        end
        d = sqrt(square_d);
    end
    
    M.typicaldist = @() sqrt(p*k);
    
    % Orthogonal projection of an ambient vector U onto the horizontal
    % space at X.
    M.proj = @projection;
    function Up = projection(X, U)
        
        XtU = multiprod(multitransp(X), U);
        Up = U - multiprod(X, XtU);

    end
    
    M.tangent = M.proj;
    
    M.tangent2ambient_is_identity = true;
    M.tangent2ambient = @(X, U) U;
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XtG = multiprod(multitransp(X), egrad);
        HXtG = multiprod(H, XtG);
        rhess = projection(X, ehess - HXtG);
    end
    
    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin < 3
            Y = X + U;
        else
            Y = X + t*U;
        end
        for kk = 1 : k
        
            % Compute the polar factorization of Y = X+tU
            [u, s, v] = svd(Y(:, :, kk), 'econ'); %#ok
            Y(:, :, kk) = u*v';
            
            % Another way to compute this retraction uses QR instead of SVD.
            % As compared with the Stiefel factory, we do not need to
            % worry about flipping signs of columns here, since only
            % the column space is important, not the actual columns.
            % We prefer the polar factor to the Q-factor computation for
            % reasons explained below: see M.transp.
            %
            % [Q, unused] = qr(Y(:, :, kk), 0); %#ok
            % Y(:, :, kk) = Q;
            
        end
    end
    
    % This inverse retraction is valid for both the QR retraction and the
    % polar retraction.
    M.invretr = @invretr;
    function U = invretr(X, Y)
        XtY = multiprod(multitransp(X), Y);
        U = zeros(n, p, k, array_type);
        for kk = 1 : k
            U(:, :, kk) = Y(:, :, kk) / XtY(:, :, kk);
        end
        U = U - X;
    end
    
    % See Eq. (2.65) in Edelman, Arias and Smith 1998.
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = zeros(size(X), array_type);
        for kk = 1 : k
            [u, s, v] = svd(tU(:, :, kk), 0);
            cos_s = diag(cos(diag(s)));
            sin_s = diag(sin(diag(s)));
            Y(:, :, kk) = X(:, :, kk)*v*cos_s*v' + u*sin_s*v';
            % From numerical experiments, it seems necessary to
            % re-orthonormalize. This is overall quite expensive.
            [q, unused] = qr(Y(:, :, kk), 0); %#ok
            Y(:, :, kk) = q;
        end
    end

    % Test code for the logarithm:
    % Gr = grassmannfactory(5, 2, 3);
    % x = Gr.rand()
    % y = Gr.rand()
    % u = Gr.log(x, y)
    % Gr.dist(x, y) % These two numbers should
    % Gr.norm(x, u) % be the same.
    % z = Gr.exp(x, u) % z need not be the same matrix as y, but it should
    % v = Gr.log(x, z) % be the same point as y on Grassmann: dist almost 0.
    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = zeros(n, p, k, array_type);
        for kk = 1 : k
            x = X(:, :, kk);
            y = Y(:, :, kk);
            ytx = y.'*x;
            At = y.'-ytx*x.';
            Bt = ytx\At;
            [u, s, v] = svd(Bt.', 'econ');

            u = u(:, 1:p);
            s = diag(s);
            s = s(1:p);
            v = v(:, 1:p);

            U(:, :, kk) = u*diag(atan(s))*v.';
        end
    end

    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @random;
    function X = random()
        X = randn(n, p, k, array_type);
        for kk = 1 : k
            [Q, unused] = qr(X(:, :, kk), 0); %#ok
            X(:, :, kk) = Q;
        end
    end
    
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k, array_type));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, p, k, array_type);
    
    % This transport is compatible with the polar retraction, in the
    % following sense:
    %
    % n = 7; p = 3;
    % Gr = grassmannfactory(n, p);
    % X = Gr.rand();
    % U = Gr.randvec(X);
    % V = Gr.randvec(X);
    % [Q, ~] = qr(randn(p));
    % Gr.transp(X*Q, Gr.retr(X*Q, V*Q), U*Q) % these two
    % Gr.transp(X, Gr.retr(X, V), U)*Q       % are equal (up to eps)
    %
    % That is, if we transport U, the horizontal lift of some tangent
    % vector at X, to Y, and Y = Retr_X(V) with V the horizontal lift of
    % some tangent vector at X, we get the horizontal lift of some tangent
    % vector at Y. If we displace X, U, V to XQ, UQ, VQ for some arbitrary
    % orthogonal matrix Q, we get a horizontal lift of some vector at YQ.
    % Importantly, these two vectors are the lifts of the same tangent
    % vector, only lifted at Y and YQ.
    %
    % However, this vector transport is /not/ fully invariant, in the sense
    % that transporting U from X to some arbitrary Y may well yield the
    % lift of a different vector when compared to transporting U from X
    % to YQ, where Q is an arbitrary orthogonal matrix, even though YQ is
    % equivalent to Y. Specifically:
    %
    % Y = Gr.rand();
    % Gr.transp(X, Y*Q, U) - Gr.transp(X, Y, U)*Q   % this is not zero.
    %
    % On the other hand, the following vectors are equal:
    %
    % Gr.transp(X, Y*Q, U) - Gr.transp(X, Y, U)     % this *is* zero.
    %
    % For this to be a proper vector transport from [X] to [Y] in general,
    % assuming X'Y is invertible, one should multiply the output of this
    % function on the right with the polar factor of X'*Y, that is,
    % multiply by u*v' where [u, s, v] = svd(X'*Y), for each slice.
    M.transp = @(X, Y, U) projection(Y, U);
    
    % The mean of two points is here defined as the midpoint of a
    % minimizing geodesic connecting the two points. If the log of (X1, X2)
    % is not uniquely defined, then the returned object may not be
    % meaningful; in other words: this works best if (X1, X2) are close.
    M.pairmean = @pairmean;
    function Y = pairmean(X1, X2)
        Y = M.exp(X1, .5*M.log(X1, X2));
    end
    
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;

    
    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end

end
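The commented test code for the logarithm in the listing can be mirrored without Manopt. This plain-MATLAB sketch (k = 1, arbitrary sizes) reimplements the single-slice formulas of M.log and of M.exp (Eq. (2.65) in Edelman, Arias and Smith 1998) and checks three expected properties for generic random points: the norm of log(X, Y) equals the geodesic distance, exp(X, log(X, Y)) spans the same subspace as Y, and the midpoint exp(X, 0.5*log(X, Y)) computed by M.pairmean is equidistant from X and Y. The geodesic handle geo is a hypothetical helper, and unlike M.exp this sketch skips the final QR re-orthonormalization.

```matlab
n = 6; p = 2;                                   % arbitrary example sizes
[X, ~] = qr(randn(n, p), 0);
[Y, ~] = qr(randn(n, p), 0);
grdist = @(A, B) sqrt(sum(real(acos(svd(A'*B))).^2));

% Logarithm, following the logarithm subfunction (single slice)
ytx = Y'*X;
At  = Y' - ytx*X';
Bt  = ytx \ At;                                 % Bt' = (I - X*X')*Y / (X'*Y)
[u, s, v] = svd(Bt', 0);
theta = atan(diag(s));                          % principal angles
U = u*diag(theta)*v';                           % U = log(X, Y)

% Geodesic t -> exp(X, t*U), following Eq. (2.65). Since atan is
% increasing, u*diag(theta)*v' is already an SVD of U.
geo = @(t) X*v*diag(cos(t*theta))*v' + u*diag(sin(t*theta))*v';

d   = grdist(X, Y);
Z   = geo(1);                                   % exp(X, log(X, Y))
Mid = geo(0.5);                                 % midpoint, as in pairmean
log_norm_gap = abs(norm(U, 'fro') - d);
span_gap     = norm(Z*Z' - Y*Y', 'fro');
mid_gaps     = [abs(grdist(X, Mid) - d/2), abs(grdist(Mid, Y) - d/2)];
```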

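The comments on M.transp make three claims that can be checked numerically in plain MATLAB (k = 1, arbitrary sizes): projecting U onto the horizontal space at Y gives the same result for Y and Y*Q; it does not, in general, give that result multiplied by Q; and right-multiplying by the polar factor of X'*Y restores the equivariance T(X, Y*Q, U) = T(X, Y, U)*Q. The corrected transport below only illustrates that comment, it is not what M.transp implements; the handle and variable names are local to this sketch, which assumes X'*Y is invertible (true for generic random points).

```matlab
n = 7; p = 3;                                   % arbitrary example sizes
proj = @(Y, U) U - Y*(Y'*U);                    % same formula as M.transp(X, Y, U)
[X, ~] = qr(randn(n, p), 0);
[Y, ~] = qr(randn(n, p), 0);
U = proj(X, randn(n, p));                       % horizontal vector at X
[Q, ~] = qr(randn(p));                          % arbitrary p x p orthogonal matrix
YQ = Y*Q;                                       % another representative of [Y]

same_gap    = norm(proj(YQ, U) - proj(Y, U), 'fro');     % zero up to round-off
rotated_gap = norm(proj(YQ, U) - proj(Y, U)*Q, 'fro');   % generally not zero

% Polar factors of X'*Y and X'*(Y*Q)
[a1, ~, b1] = svd(X'*Y);   P1 = a1*b1';
[a2, ~, b2] = svd(X'*YQ);  P2 = a2*b2';
T1 = proj(Y, U)*P1;
T2 = proj(YQ, U)*P2;
equivariance_gap = norm(T2 - T1*Q, 'fro');      % zero up to round-off
```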
Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005