Home > manopt > tools > productmanifold.m

productmanifold

PURPOSE ^

Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.

SYNOPSIS ^

function M = productmanifold(elements)

DESCRIPTION ^

 Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.

 function M = productmanifold(elements)

 Input: an elements structure such that each field contains a manifold
 structure.
 
 Output: a manifold structure M representing the manifold obtained by
 taking the Cartesian product of the manifolds described in the elements
 structure, with the metric obtained by element-wise extension. Points
 and vectors are stored as structures with the same fieldnames as in
 elements.

 Example:
 M = productmanifold(struct('X', spherefactory(3), 'Y', spherefactory(4)))
 disp(M.name());
 x = M.rand()

 Points of M = S^2 x S^3 are represented as structures with two fields, X
 and Y. The values associated to X are points of S^2, and likewise points
 of S^3 for the field Y. Tangent vectors are also represented as
 structures with two corresponding fields X and Y.
 
 See also: powermanifold

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function M = productmanifold(elements)
0002 % Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.
0003 %
0004 % function M = productmanifold(elements)
0005 %
0006 % Input: an elements structure such that each field contains a manifold
0007 % structure.
0008 %
0009 % Output: a manifold structure M representing the manifold obtained by
0010 % taking the Cartesian product of the manifolds described in the elements
0011 % structure, with the metric obtained by element-wise extension. Points
0012 % and vectors are stored as structures with the same fieldnames as in
0013 % elements.
0014 %
0015 % Example:
0016 % M = productmanifold(struct('X', spherefactory(3), 'Y', spherefactory(4)))
0017 % disp(M.name());
0018 % x = M.rand()
0019 %
0020 % Points of M = S^2 x S^3 are represented as structures with two fields, X
0021 % and Y. The values associated to X are points of S^2, and likewise points
0022 % of S^3 for the field Y. Tangent vectors are also represented as
0023 % structures with two corresponding fields X and Y.
0024 %
0025 % See also: powermanifold
0026 
0027 % This file is part of Manopt: www.manopt.org.
0028 % Original author: Nicolas Boumal, Dec. 30, 2012.
0029 % Contributors:
0030 % Change log:
0031 %   NB, July 4, 2013: Added support for vec, mat, tangent.
0032 %                     Added support for egrad2rgrad and ehess2rhess.
0033 %                     Modified hash function to make hash strings shorter.
0034 
0035 
0036     elems = fieldnames(elements);  % field names of elements, one per factor
0037     nelems = numel(elems);  % number of factors in the product
0038     % Stop right away if the elements structure has no field at all.
0039     assert(nelems >= 1, ...
0040            'elements must be a structure with at least one field.');
0041     
0042     M.name = @name;
0043     function str = name()  % Label listing each field with its factor's name.
0044         key = elems{1};
0045         str = ['Product manifold: ' ...
0046                sprintf('[%s: %s]', key, elements.(key).name())];
0047         for k = 2 : nelems
0048             key = elems{k};
0049             str = [str sprintf(' x [%s: %s]', key, elements.(key).name())]; %#ok<AGROW>
0050         end
0051     end
0052     
0053     M.dim = @dim;
0054     function d = dim()  % Sum of the values returned by the factors' dim().
0055         d = 0;
0056         for key = reshape(elems, 1, [])  % key is a 1-by-1 cell with a field name
0057             d = d + elements.(key{1}).dim();
0058         end
0059     end
0060     
0061     M.inner = @inner;
0062     function val = inner(x, u, v)  % Sum of the factors' inner products at x.
0063         val = 0;
0064         for k = 1 : nelems
0065             key = elems{k};
0066             val = val + elements.(key).inner(x.(key), u.(key), v.(key));
0067         end
0068     end
0069     % Norm of u at x: square root of the product inner product of u with itself.
0070     M.norm = @(x, u) sqrt(M.inner(x, u, u));
0071 
0072     M.dist = @dist;
0073     function d = dist(x, y)  % Root of the summed squared factor distances.
0074         total = 0;
0075         for k = 1 : nelems
0076             key = elems{k};
0077             total = total + elements.(key).dist(x.(key), y.(key))^2;
0078         end
0079         d = sqrt(total);
0080     end
0081     
0082     M.typicaldist = @typicaldist;
0083     function d = typicaldist()  % Root of the summed squared factor typicaldist().
0084         total = 0;
0085         for key = reshape(elems, 1, [])  % key is a 1-by-1 cell with a field name
0086             total = total + elements.(key{1}).typicaldist()^2;
0087         end
0088         d = sqrt(total);
0089     end
0090 
0091     M.proj = @proj;
0092     function v = proj(x, u)  % Each factor's proj applied to matching fields.
0093         for k = 1 : nelems
0094             key = elems{k};
0095             v.(key) = elements.(key).proj(x.(key), u.(key));
0096         end
0097     end
0098 
0099     M.tangent = @tangent;
0100     function v = tangent(x, u)  % Each factor's tangent applied to matching fields.
0101         for k = 1 : nelems
0102             key = elems{k};
0103             v.(key) = elements.(key).tangent(x.(key), u.(key));
0104         end
0105     end
0106 
0107     M.tangent2ambient = @tangent2ambient;
0108     function v = tangent2ambient(x, u)
0109         % Factors that provide tangent2ambient convert their own component;
0110         % for the others, the component of u is copied over unchanged.
0111         for k = 1 : nelems
0112             key = elems{k};
0113             v.(key) = u.(key);
0114             if isfield(elements.(key), 'tangent2ambient')
0115                 v.(key) = elements.(key).tangent2ambient(x.(key), u.(key));
0116             end
0117         end
0118     end
0119 
0120     M.egrad2rgrad = @egrad2rgrad;
0121     function g = egrad2rgrad(x, g)  % Overwrite g field by field via the factors.
0122         for k = 1 : nelems
0123             key = elems{k};
0124             g.(key) = elements.(key).egrad2rgrad(x.(key), g.(key));
0125         end
0126     end
0127 
0128     M.ehess2rhess = @ehess2rhess;
0129     function h = ehess2rhess(x, eg, eh, h)  % Overwrite h field by field.
0130         for k = 1 : nelems
0131             key = elems{k};
0132             h.(key) = elements.(key).ehess2rhess(x.(key), eg.(key), eh.(key), h.(key));
0133         end
0134     end
0135     
0136     M.exp = @exp;
0137     function y = exp(x, u, t)
0138         % Calls each factor's exp on its own fields of x and u, passing the
0139         % same t to every factor; t is set to 1 when the caller omits it.
0140         if nargin < 3, t = 1.0; end
0141         for k = 1 : nelems
0142             key = elems{k};
0143             y.(key) = elements.(key).exp(x.(key), u.(key), t);
0144         end
0145     end
0146     
0147     M.retr = @retr;
0148     function y = retr(x, u, t)
0149         % Calls each factor's retr on its own fields of x and u, passing the
0150         % same t to every factor; t is set to 1 when the caller omits it.
0151         if nargin < 3, t = 1.0; end
0152         for k = 1 : nelems
0153             key = elems{k};
0154             y.(key) = elements.(key).retr(x.(key), u.(key), t);
0155         end
0156     end
0157     
0158     M.log = @log;
0159     function u = log(x1, x2)  % Each factor's log applied to matching fields.
0160         for k = 1 : nelems
0161             key = elems{k};
0162             u.(key) = elements.(key).log(x1.(key), x2.(key));
0163         end
0164     end
0165 
0166     M.hash = @hash;
0167     function str = hash(x)  % 'z' followed by hashmd5 of the joined factor hashes.
0168         joined = '';
0169         for k = 1 : nelems
0170             joined = [joined elements.(elems{k}).hash(x.(elems{k}))]; %#ok<AGROW>
0171         end
0172         str = ['z' hashmd5(joined)];
0173     end
0174 
0175     M.lincomb = @lincomb;
0176     function v = lincomb(x, a1, u1, a2, u2)
0177         % Factor-wise call to lincomb: with three inputs the factors receive
0178         % (a1, u1), with five inputs (a1, u1, a2, u2); other counts are errors.
0179         if nargin ~= 3 && nargin ~= 5
0180             error('Bad usage of productmanifold.lincomb');
0181         end
0182         for k = 1 : nelems
0183             key = elems{k};
0184             if nargin == 3
0185                 v.(key) = elements.(key).lincomb(x.(key), a1, u1.(key));
0186             else
0187                 v.(key) = elements.(key).lincomb(x.(key), a1, u1.(key), a2, u2.(key));
0188             end
0189         end
0190     end
0191 
0192     M.rand = @rand;
0193     function x = rand()  % One rand() call per factor, stored under its field.
0194         for key = reshape(elems, 1, [])  % key is a 1-by-1 cell with a field name
0195             x.(key{1}) = elements.(key{1}).rand();
0196         end
0197     end
0198 
0199     M.randvec = @randvec;
0200     function u = randvec(x)  % Factor randvecs, then scaled by 1/sqrt(nelems).
0201         for key = reshape(elems, 1, [])  % key is a 1-by-1 cell with a field name
0202             u.(key{1}) = elements.(key{1}).randvec(x.(key{1}));
0203         end
0204         u = M.lincomb(x, 1/sqrt(nelems), u);
0205     end
0206 
0207     M.zerovec = @zerovec;
0208     function u = zerovec(x)  % One zerovec call per factor, stored under its field.
0209         for key = reshape(elems, 1, [])  % key is a 1-by-1 cell with a field name
0210             u.(key{1}) = elements.(key{1}).zerovec(x.(key{1}));
0211         end
0212     end
0213 
0214     M.transp = @transp;
0215     function v = transp(x1, x2, u)  % Each factor's transp on matching fields.
0216         for k = 1 : nelems
0217             key = elems{k};
0218             v.(key) = elements.(key).transp(x1.(key), x2.(key), u.(key));
0219         end
0220     end
0221 
0222     M.pairmean = @pairmean;
0223     function y = pairmean(x1, x2)  % Each factor's pairmean on matching fields.
0224         for k = 1 : nelems
0225             key = elems{k};
0226             y.(key) = elements.(key).pairmean(x1.(key), x2.(key));
0227         end
0228     end
0229 
0230 
0231     % Record, for every factor, how long its column-vector representation of
0232     % a tangent vector is: vec is applied to the output of zerovec at a point
0233     % drawn with rand. A factor without a vec field disables vec/mat for M.
0234     vec_available = true;
0235     vec_lens = zeros(nelems, 1);
0236     for ii = 1 : nelems
0237         Mi = elements.(elems{ii});
0238         if ~isfield(Mi, 'vec')
0239             vec_available = false;
0240             break;
0241         end
0242         rand_x = Mi.rand();
0243         zero_u = Mi.zerovec(rand_x);
0244         vec_lens(ii) = length(Mi.vec(rand_x, zero_u));
0245     end
0246     % vec_pos(j) is the first index of factor j's slice in the stacked vector.
0247     vec_pos = cumsum([1 ; vec_lens]);
0248     
0249     if vec_available
0250         M.vec = @vec;
0251         M.mat = @mat;
0252     end
0253     
0254     function u_vec = vec(x, u_mat)  % Stack the factors' vec outputs in one column.
0255         u_vec = zeros(vec_pos(end)-1, 1);
0256         for k = 1 : nelems
0257             key = elems{k};
0258             slots = vec_pos(k) : (vec_pos(k+1)-1);
0259             u_vec(slots) = elements.(key).vec(x.(key), u_mat.(key));
0260         end
0261     end
0262 
0263     function u_mat = mat(x, u_vec)  % Slice u_vec per factor, apply each factor's mat.
0264         u_mat = struct();
0265         for k = 1 : nelems
0266             key = elems{k};
0267             slots = vec_pos(k) : (vec_pos(k+1)-1);
0268             u_mat.(key) = elements.(key).mat(x.(key), u_vec(slots));
0269         end
0270     end
0271 
0272     vecmatareisometries = true;  % stays true only if every factor reports true
0273     for ii = 1 : nelems
0274         Mi = elements.(elems{ii});
0275         if ~(isfield(Mi, 'vecmatareisometries') && Mi.vecmatareisometries())
0276             vecmatareisometries = false;
0277             break;
0278         end
0279     end
0280     M.vecmatareisometries = @() vecmatareisometries;
0281 
0282 end

Generated on Sat 12-Nov-2016 14:11:22 by m2html © 2005