function M = productmanifold(elements)
% Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.
%
% function M = productmanifold(elements)
%
% Input: elements is a structure whose fields are manifold structures.
% It must have at least one field.
%
% Output: a manifold structure M for the Cartesian product of the factors.
% A point x on M is a structure with the same field names as elements:
% x.(f) is a point on the factor elements.(f). Tangent vectors on M are
% represented in the same way, one field per factor.
%
% The inner product on M is the sum of the factors' inner products.
% Optional tools (dist, typicaldist, exp, log, transp, pairmean,
% lie_identity) are only available on M if every factor provides them.
%
% M.vec and M.mat are only available if every factor provides vec.
% M.vecmatareisometries() is true only if, in addition, every factor
% provides vecmatareisometries and it returns true for each of them.
%
% Note: to find the length of each factor's vectorized tangent vectors,
% this constructor calls rand(), zerovec() and vec() on the factors. It
% therefore consumes random numbers from the global stream.

    field_names = fieldnames(elements);
    num_factors = numel(field_names);

    assert(num_factors >= 1, ...
           'elements must be a structure with at least one field.');

    % Probe the factors in order to find out whether all of them provide
    % vec and, if so, how long each factor's vector representation is.
    % Probing stops at the first factor that lacks vec; the remaining
    % lengths stay at zero since vec/mat will not be offered anyway.
    vec_available = true;
    vec_lens = zeros(num_factors, 1);
    k = 1;
    while vec_available && k <= num_factors
        Mk = elements.(field_names{k});
        if isfield(Mk, 'vec')
            x_sample = Mk.rand();
            vec_lens(k) = length(Mk.vec(x_sample, Mk.zerovec(x_sample)));
            k = k + 1;
        else
            vec_available = false;
        end
    end

    % Slice k of the product's vector representation is
    % vec_pos(k) : vec_pos(k+1)-1.
    vec_pos = cumsum([1 ; vec_lens]);

    % vec/mat on the product are isometries only if they exist at all and
    % every factor declares its own vec/mat pair to be isometries.
    vecmatareisometries = vec_available;
    for k = 1 : num_factors
        Mk = elements.(field_names{k});
        factor_is_iso = isfield(Mk, 'vecmatareisometries') && ...
                        Mk.vecmatareisometries();
        if ~factor_is_iso
            vecmatareisometries = false;
            break;
        end
    end

    % The actual tools live in a separate function so that their nested
    % functions share a workspace containing only the data they need.
    M = productmanifoldhelper(elements, field_names, num_factors, ...
                              vec_available, vec_pos, vecmatareisometries);

end
0098
0099
function M = productmanifoldhelper(elements, elems, nelems, ...
                                   vec_available, vec_pos, ...
                                   vecmatareisometries)
% Builds the product manifold structure M from precomputed data.
%
% function M = productmanifoldhelper(elements, elems, nelems, ...
%                                    vec_available, vec_pos, ...
%                                    vecmatareisometries)
%
% Inputs:
%   elements            - structure whose fields are the factor manifolds.
%   elems               - cell array with the field names of elements.
%   nelems              - number of factors (numel(elems)).
%   vec_available       - if true, M.vec and M.mat are added. NOTE(review):
%                         presumably true only when every factor provides
%                         vec; factors are also assumed to provide mat.
%   vec_pos             - start indices of each factor's slice in the
%                         vectorized representation: slice i is
%                         vec_pos(i) : vec_pos(i+1)-1 (nelems+1 entries).
%   vecmatareisometries - value returned by M.vecmatareisometries().
%
% Output: a manifold structure M. A point x on M is a structure with one
% field per factor: x.(elems{i}) is a point on elements.(elems{i}).
% Tangent vectors follow the same convention. Every tool below applies the
% corresponding tool of each factor, field by field.
%
% Optional tools (dist, typicaldist, exp, log, transp, pairmean,
% lie_identity) are only added to M if every factor provides them.
%
% Implementation note: all tools are nested functions and so share this
% function's workspace. Loops at this level use ii and k as indices,
% whereas nested functions use i, so that no loop variable is
% accidentally shared between them.

    % True if every factor manifold has a field named method_name.
    function answer = all_elements_provide(method_name)
        answer = all(cellfun(@(f) isfield(elements.(f), method_name), ...
                             elems));
    end

    % Name of the form 'Product manifold: [A: ...] x [B: ...]'.
    M.name = @name;
    function str = name()
        str = 'Product manifold: ';
        str = [str sprintf('[%s: %s]', ...
                           elems{1}, elements.(elems{1}).name())];
        for i = 2 : nelems
            str = [str sprintf(' x [%s: %s]', ...
                               elems{i}, elements.(elems{i}).name())];
        end
    end

    % Dimension: sum of the factors' dimensions.
    M.dim = @dim;
    function d = dim()
        d = 0;
        for i = 1 : nelems
            d = d + elements.(elems{i}).dim();
        end
    end

    % Inner product: sum of the factors' inner products.
    M.inner = @inner;
    function val = inner(x, u, v)
        val = 0;
        for i = 1 : nelems
            val = val + elements.(elems{i}).inner(x.(elems{i}), ...
                                                  u.(elems{i}), v.(elems{i}));
        end
    end

    % Norm induced by the product inner product.
    M.norm = @(x, d) sqrt(M.inner(x, d, d));

    % Distance: square root of the sum of squared factor distances.
    if all_elements_provide('dist')
        M.dist = @dist;
    end
    function d = dist(x, y)
        sqd = 0;
        for i = 1 : nelems
            sqd = sqd + elements.(elems{i}).dist(x.(elems{i}), ...
                                                 y.(elems{i}))^2;
        end
        d = sqrt(sqd);
    end

    % Typical distance, combined like dist.
    if all_elements_provide('typicaldist')
        M.typicaldist = @typicaldist;
    end
    function d = typicaldist()
        sqd = 0;
        for i = 1 : nelems
            sqd = sqd + elements.(elems{i}).typicaldist()^2;
        end
        d = sqrt(sqd);
    end

    % Factor-wise projection.
    M.proj = @proj;
    function v = proj(x, u)
        for i = 1 : nelems
            v.(elems{i}) = elements.(elems{i}).proj(x.(elems{i}), ...
                                                    u.(elems{i}));
        end
    end

    % Factor-wise tangent.
    M.tangent = @tangent;
    function v = tangent(x, u)
        for i = 1 : nelems
            v.(elems{i}) = elements.(elems{i}).tangent(x.(elems{i}), ...
                                                       u.(elems{i}));
        end
    end

    % The product's tangent2ambient is the identity unless some factor
    % explicitly sets tangent2ambient_is_identity to false. Factors that
    % do not declare the flag are treated as identity.
    % NOTE(review): a factor that defines tangent2ambient but not the
    % flag is also treated as identity here -- confirm this is intended.
    M.tangent2ambient_is_identity = true;
    for k = 1 : nelems
        if isfield(elements.(elems{k}), 'tangent2ambient')
        end
        if isfield(elements.(elems{k}), 'tangent2ambient_is_identity') && ...
           ~elements.(elems{k}).tangent2ambient_is_identity
            M.tangent2ambient_is_identity = false;
            break;
        end
    end

    % Factor-wise tangent2ambient; factors without it pass u through.
    M.tangent2ambient = @tangent2ambient;
    function v = tangent2ambient(x, u)
        for i = 1 : nelems
            if isfield(elements.(elems{i}), 'tangent2ambient')
                v.(elems{i}) = ...
                    elements.(elems{i}).tangent2ambient( ...
                                               x.(elems{i}), u.(elems{i}));
            else
                v.(elems{i}) = u.(elems{i});
            end
        end
    end

    % Factor-wise conversion of Euclidean to Riemannian gradient. Each
    % factor is called with exactly two inputs (x, g).
    M.egrad2rgrad = @egrad2rgrad;
    function g = egrad2rgrad(x, g)
        for i = 1 : nelems
            g.(elems{i}) = elements.(elems{i}).egrad2rgrad(...
                                               x.(elems{i}), g.(elems{i}));
        end
    end
    % Warn once if some factor's egrad2rgrad accepts more than the two
    % inputs the product passes on.
    for ii = 1 : nelems
        if nargin(elements.(elems{ii}).egrad2rgrad) > 2
            warning('manopt:productmanifold:egrad2rgrad', ...
                    ['Product manifolds call M.egrad2rgrad with only two ', ...
                     'inputs:\nstoredb and key won''t be available.']);
            break;
        end
    end

    % Factor-wise conversion of Euclidean to Riemannian Hessian. Each
    % factor is called with exactly four inputs (x, eg, eh, h).
    M.ehess2rhess = @ehess2rhess;
    function h = ehess2rhess(x, eg, eh, h)
        for i = 1 : nelems
            h.(elems{i}) = elements.(elems{i}).ehess2rhess(...
                 x.(elems{i}), eg.(elems{i}), eh.(elems{i}), h.(elems{i}));
        end
    end
    % Warn once if some factor's ehess2rhess accepts more than the four
    % inputs the product passes on. (The message previously said "two",
    % which was wrong for this method.)
    for ii = 1 : nelems
        if nargin(elements.(elems{ii}).ehess2rhess) > 4
            warning('manopt:productmanifold:ehess2rhess', ...
                    ['Product manifolds call M.ehess2rhess with only four ', ...
                     'inputs:\nstoredb and key won''t be available.']);
            break;
        end
    end

    % Factor-wise exponential map; t defaults to 1.
    if all_elements_provide('exp')
        M.exp = @exp;
    end
    function y = exp(x, u, t)
        if nargin < 3
            t = 1.0;
        end
        for i = 1 : nelems
            y.(elems{i}) = elements.(elems{i}).exp(x.(elems{i}), ...
                                                   u.(elems{i}), t);
        end
    end

    % Factor-wise retraction; t defaults to 1.
    M.retr = @retr;
    function y = retr(x, u, t)
        if nargin < 3
            t = 1.0;
        end
        for i = 1 : nelems
            y.(elems{i}) = elements.(elems{i}).retr(x.(elems{i}), ...
                                                    u.(elems{i}), t);
        end
    end

    % Factor-wise logarithmic map.
    if all_elements_provide('log')
        M.log = @log;
    end
    function u = log(x1, x2)
        for i = 1 : nelems
            u.(elems{i}) = elements.(elems{i}).log(x1.(elems{i}), ...
                                                   x2.(elems{i}));
        end
    end

    % Hash: the factors' hashes are concatenated, hashed with hashmd5 and
    % prefixed with 'z' (presumably so the key starts with a letter and
    % can serve as a field name -- confirm).
    M.hash = @hash;
    function str = hash(x)
        str = '';
        for i = 1 : nelems
            str = [str elements.(elems{i}).hash(x.(elems{i}))]; %#ok<AGROW>
        end
        str = ['z' hashmd5(str)];
    end

    % Factor-wise linear combination, delegated to each factor's lincomb
    % with either (a1, u1) or (a1, u1, a2, u2).
    M.lincomb = @lincomb;
    function v = lincomb(x, a1, u1, a2, u2)
        if nargin == 3
            for i = 1 : nelems
                v.(elems{i}) = elements.(elems{i}).lincomb(x.(elems{i}), ...
                                                        a1, u1.(elems{i}));
            end
        elseif nargin == 5
            for i = 1 : nelems
                v.(elems{i}) = elements.(elems{i}).lincomb(x.(elems{i}), ...
                                     a1, u1.(elems{i}), a2, u2.(elems{i}));
            end
        else
            error('Bad usage of productmanifold.lincomb');
        end
    end

    % Random point: one random point per factor.
    M.rand = @rand;
    function x = rand()
        for i = 1 : nelems
            x.(elems{i}) = elements.(elems{i}).rand();
        end
    end

    % Random tangent vector: one per factor, then scaled by
    % 1/sqrt(nelems). If each factor returns a unit-norm vector, the
    % product vector then has unit norm as well.
    M.randvec = @randvec;
    function u = randvec(x)
        for i = 1 : nelems
            u.(elems{i}) = elements.(elems{i}).randvec(x.(elems{i}));
        end
        u = M.lincomb(x, 1/sqrt(nelems), u);
    end

    % Zero tangent vector: one per factor.
    M.zerovec = @zerovec;
    function u = zerovec(x)
        for i = 1 : nelems
            u.(elems{i}) = elements.(elems{i}).zerovec(x.(elems{i}));
        end
    end

    % Factor-wise vector transport from x1 to x2.
    if all_elements_provide('transp')
        M.transp = @transp;
    end
    function v = transp(x1, x2, u)
        for i = 1 : nelems
            v.(elems{i}) = elements.(elems{i}).transp(x1.(elems{i}), ...
                                                  x2.(elems{i}), u.(elems{i}));
        end
    end

    % Factor-wise pair mean.
    if all_elements_provide('pairmean')
        M.pairmean = @pairmean;
    end
    function y = pairmean(x1, x2)
        for i = 1 : nelems
            y.(elems{i}) = elements.(elems{i}).pairmean(x1.(elems{i}), ...
                                                        x2.(elems{i}));
        end
    end

    % vec stacks the factors' vectorized tangent vectors into one column,
    % factor i occupying vec_pos(i) : vec_pos(i+1)-1; mat undoes this
    % slicing. NOTE(review): factors are assumed to provide mat whenever
    % vec_available is true.
    if vec_available
        M.vec = @vec;
        M.mat = @mat;
    end

    function u_vec = vec(x, u_mat)
        u_vec = zeros(vec_pos(end)-1, 1);
        for i = 1 : nelems
            range = vec_pos(i) : (vec_pos(i+1)-1);
            u_vec(range) = elements.(elems{i}).vec(x.(elems{i}), ...
                                                   u_mat.(elems{i}));
        end
    end

    function u_mat = mat(x, u_vec)
        u_mat = struct();
        for i = 1 : nelems
            range = vec_pos(i) : (vec_pos(i+1)-1);
            u_mat.(elems{i}) = elements.(elems{i}).mat(x.(elems{i}), ...
                                                       u_vec(range));
        end
    end

    M.vecmatareisometries = @() vecmatareisometries;

    % Lie group identity: one identity element per factor.
    if all_elements_provide('lie_identity')
        M.lie_identity = @lie_identity;
    end

    function I = lie_identity()
        I = struct();
        for i = 1 : nelems
            Mi = elements.(elems{i});
            Ii = Mi.lie_identity();
            I.(elems{i}) = Ii;
        end
    end

end