0001 function M = complexcirclefactory(n, m, gpuflag)
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049 if ~exist('n', 'var') || isempty(n)
0050 n = 1;
0051 end
0052 if ~exist('m', 'var') || isempty(m)
0053 m = 1;
0054 end
0055 if ~exist('gpuflag', 'var') || isempty(gpuflag)
0056 gpuflag = false;
0057 end
0058
0059
0060
0061
0062 if gpuflag
0063 array_type = 'gpuArray';
0064 else
0065 array_type = 'double';
0066 end
0067
0068 if m == 1
0069 M.name = @() sprintf('Complex circle (S^1)^%d', n);
0070 else
0071 M.name = @() sprintf('Complex circle (S^1)^(%dx%d)', n, m);
0072 end
0073
0074 M.dim = @() n*m;
0075
0076 M.inner = @(z, v, w) real(v(:)'*w(:));
0077
0078 M.norm = @(x, v) norm(v, 'fro');
0079
0080 M.dist = @(x, y) norm(real(2*asin(.5*abs(x - y))), 'fro');
0081
0082 M.typicaldist = @() pi*sqrt(n*m);
0083
0084 M.proj = @(z, u) u - real( conj(u) .* z ) .* z;
0085
0086 M.tangent = M.proj;
0087
0088
0089
0090 M.egrad2rgrad = M.proj;
0091
0092 M.ehess2rhess = @ehess2rhess;
0093 function rhess = ehess2rhess(z, egrad, ehess, zdot)
0094 rhess = M.proj(z, ehess - real(z.*conj(egrad)).*zdot);
0095 end
0096
0097 M.exp = @exponential;
0098 function y = exponential(z, v, t)
0099
0100 if nargin == 2
0101
0102 tv = v;
0103 else
0104 tv = t*v;
0105 end
0106
0107 y = zeros(n, m, array_type);
0108
0109 nrm_tv = abs(tv);
0110
0111
0112 mask = (nrm_tv > 0);
0113 y(mask) = z(mask).*cos(nrm_tv(mask)) + ...
0114 tv(mask).*(sin(nrm_tv(mask))./nrm_tv(mask));
0115 y(~mask) = z(~mask);
0116
0117 end
0118
0119 M.retr = @retraction;
0120 function y = retraction(z, v, t)
0121 if nargin == 2
0122
0123 tv = v;
0124 else
0125 tv = t*v;
0126 end
0127 y = sign(z+tv);
0128 end
0129
0130 M.invretr = @inverse_retraction;
0131 function v = inverse_retraction(x, y)
0132 v = y ./ real(conj(x) .* y) - x;
0133 end
0134
0135 M.log = @logarithm;
0136 function v = logarithm(x1, x2)
0137 v = M.proj(x1, x2 - x1);
0138 di = real(2*asin(.5*abs(x1 - x2)));
0139 nv = abs(v);
0140 factors = di ./ nv;
0141 factors(di <= 1e-10) = 1;
0142 v = v .* factors;
0143 end
0144
0145 M.hash = @(z) ['z' hashmd5( [real(z(:)) ; imag(z(:))] ) ];
0146
0147 M.rand = @random;
0148 function z = random()
0149 z = sign(randn(n, m, array_type) + 1i*randn(n, m, array_type));
0150 end
0151
0152 M.randvec = @randomvec;
0153 function v = randomvec(z)
0154
0155 v = randn(n, m, array_type) .* (1i*z);
0156 v = v / norm(v, 'fro');
0157 end
0158
0159 M.lincomb = @matrixlincomb;
0160
0161 M.zerovec = @(x) zeros(n, m, array_type);
0162
0163 M.transp = @(x1, x2, d) M.proj(x2, d);
0164
0165 M.pairmean = @pairmean;
0166 function z = pairmean(z1, z2)
0167 z = sign(z1+z2);
0168 end
0169
0170 M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))];
0171 M.mat = @(x, u_vec) reshape(u_vec(1:(n*m)) + 1i*u_vec((n*m+1):end), [n, m]);
0172 M.vecmatareisometries = @() true;
0173
0174
0175 if gpuflag
0176 M = factorygpuhelper(M);
0177 end
0178
0179 M.lie_identity = @() ones(n, m, array_type);
0180
0181 end