Home > manopt > manifolds > stiefel > stiefelfactory.m

stiefelfactory

PURPOSE ^

Returns a manifold structure to optimize over orthonormal matrices.

SYNOPSIS ^

function M = stiefelfactory(n, p, k)

DESCRIPTION ^

 Returns a manifold structure to optimize over orthonormal matrices.

 function M = stiefelfactory(n, p)
 function M = stiefelfactory(n, p, k)

 The Stiefel manifold is the set of n x p matrices with orthonormal
 columns (this requires n >= p). If k is larger than 1, this is the
 Cartesian product of the Stiefel manifold taken k times. The metric is
 the usual one: the manifold is treated as a Riemannian submanifold of
 R^(n x p) equipped with the trace inner product <A, B> = trace(A'*B).

 Points are represented as arrays X of size n x p x k (or n x p if k = 1,
 which is the default) such that each n x p slice has orthonormal columns,
 i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
 i = 1 : k if k > 1. Tangent vectors are represented as arrays of the same
 size as points.

 By default, k = 1.

 See also: grassmannfactory rotationsfactory

CROSS-REFERENCE INFORMATION ^

This function calls:

This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function M = stiefelfactory(n, p, k)
% Returns a manifold structure to optimize over orthonormal matrices.
%
% function M = stiefelfactory(n, p)
% function M = stiefelfactory(n, p, k)
%
% The Stiefel manifold is the set of n x p matrices with orthonormal
% columns (this requires n >= p). If k is larger than 1, this is the
% Cartesian product of the Stiefel manifold taken k times. The metric is
% the usual one: the manifold is treated as a Riemannian submanifold of
% R^(n x p) equipped with the trace inner product <A, B> = trace(A'*B).
%
% Points are represented as arrays X of size n x p x k (or n x p if k = 1,
% which is the default) such that each n x p slice has orthonormal
% columns, i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p)
% for i = 1 : k if k > 1. Tangent vectors are represented as arrays of the
% same size as points.
%
% By default, k = 1. An error is raised if k is not a positive integer or
% if n < p.
%
% See also: grassmannfactory rotationsfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%  July  5, 2013 (NB) : Added ehess2rhess.
%  Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected.
%  June 24, 2014 (NB) : Added true exponential map and changed the randvec
%                       function so that it now returns a globally
%                       normalized vector, not a vector where each
%                       component is normalized (this only matters if k>1).

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    
    % k is used as an array size and a loop bound below, so it must be a
    % finite positive integer. Without this check, a non-integer k > 1
    % would be silently accepted and give a wrong dimension.
    if ~isscalar(k) || ~isreal(k) || ~isfinite(k) || k < 1 || k ~= round(k)
        error('k must be an integer no less than 1.');
    end
    
    % p orthonormal vectors in R^n only exist if n >= p. Without this
    % check, M.dim would be meaningless and M.rand would fail obscurely.
    if n < p
        error('The dimension n must be at least as large as p.');
    end
    
    if k == 1
        M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p);
    else
        M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k);
    end
    
    M.dim = @() k*(n*p - .5*p*(p+1));
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d(:));
    
    M.dist = @(x, y) error('stiefel.dist not implemented yet.');
    
    M.typicaldist = @() sqrt(p*k);
    
    % Orthogonal projection of U onto the tangent space at X, applied
    % slice by slice: U - X*sym(X'*U).
    M.proj = @projection;
    function Up = projection(X, U)
        
        XtU = multiprod(multitransp(X), U);
        symXtU = multisym(XtU);
        Up = U - multiprod(X, symXtU);
        
% The code above is equivalent to, but much faster than, the code below.
%
%     Up = zeros(size(U));
%     function A = sym(A), A = .5*(A+A'); end
%     for i = 1 : k
%         Xi = X(:, :, i);
%         Ui = U(:, :, i);
%         Up(:, :, i) = Ui - Xi*sym(Xi'*Ui);
%     end

    end
    
    M.tangent = M.proj;
    
    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XtG = multiprod(multitransp(X), egrad);
        symXtG = multisym(XtG);
        HsymXtG = multiprod(H, symXtG);
        rhess = projection(X, ehess - HsymXtG);
    end
    
    % Q factor of the thin QR decomposition of A, with column signs chosen
    % so that the matching R factor has a nonnegative diagonal. This makes
    % the result independent of the sign convention of the QR routine
    % (sign(sign(0) + .5) = 1, so zero diagonal entries keep their sign).
    function Q = qfactor(A)
        [Q, R] = qr(A, 0);
        Q = Q * diag(sign(sign(diag(R)) + .5));
    end
    
    % QR-based retraction: the sign-corrected Q factor of X + t*U.
    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin < 3
            t = 1.0;
        end
        Y = X + t*U;
        for i = 1 : k
            Y(:, :, i) = qfactor(Y(:, :, i));
        end
    end
    
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin < 3
            t = 1;
        end
        tU = t*U;
        Y = zeros(size(X));
        for i = 1 : k
            % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
            Xi = X(:, :, i);
            Ui = tU(:, :, i);
            XtU = Xi'*Ui;
            Y(:, :, i) = [Xi Ui] * ...
                         expm([XtU , -Ui'*Ui ; eye(p) , XtU]) * ...
                         [ expm(-XtU) ; zeros(p) ];
        end
    end

    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @random;
    function X = random()
        X = zeros(n, p, k);
        for i = 1 : k
            % The sign-corrected Q factor of a Gaussian matrix is uniformly
            % distributed on the Stiefel manifold; the raw Q factor
            % returned by qr is not guaranteed to be.
            X(:, :, i) = qfactor(randn(n, p));
        end
    end
    
    % Random tangent vector at X with unit norm over all k components.
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, p, k);
    
    % Vector transport by projection onto the tangent space at x2.
    M.transp = @(x1, x2, d) projection(x2, d);
    
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;

end

Generated on Fri 08-Sep-2017 12:43:19 by m2html © 2005