0001
0002
0003
0004
0005
function [eta] = precond_laplace_noSaddle( L, xi, xL, xR )
%PRECOND_LAPLACE_NOSADDLE  Core-wise preconditioner for a Laplace-like operator on a TT tangent vector.
%
%   ETA = PRECOND_LAPLACE_NOSADDLE( L, XI ) replaces every variation core
%   XI.dU{k} by the solution of a shifted linear system assembled from the
%   matrices L{1},...,L{d} and the cores XI.U / XI.V. It returns the result
%   passed through TTeMPS_tangent_orth( XL, XR, ETA ).
%
%   ETA = PRECOND_LAPLACE_NOSADDLE( L, XI, XL, XR ) uses the supplied XL / XR
%   instead of TTeMPS(XI.U) / TTeMPS(XI.V). Either may be omitted on its own.
%
%   The local systems are solved without a saddle-point constraint: the core
%   argument handed to SOLVE_SADDLE / SOLVE_SADDLE_FAST is not used by them.
%
%   Inputs
%     L       1-by-d cell. L{k} is combined with speye(n(k)) through KRON, so it
%             must be n(k)-by-n(k), where n = XI.size. The eigenvector matrices
%             computed below are inverted by transposition, so L{k} is presumably
%             required to be symmetric -- confirm for new callers.
%     XI      tangent-vector object. This function reads its fields U, V, dU,
%             rank, size and order.
%     XL, XR  optional; default to TTeMPS(XI.U) and TTeMPS(XI.V).
%
%   Output
%     ETA     copy of XI with every dU core replaced, then passed through
%             TTeMPS_tangent_orth.

eta = xi;

r = xi.rank;
n = xi.size;
d = xi.order;

% Default left/right representations are built from xi's own cores.
% Each default is handled separately, so a caller may pass only XL.
if nargin < 3
    xL = TTeMPS(xi.U);
end
if nargin < 4
    xR = TTeMPS(xi.V);
end

% Left interface matrices:
%   A{1} = L{1}
%   M{k} = Ul' * A{k-1} * Ul,  where Ul is the left unfolding of xi.U{k-1}   (r(k)-by-r(k))
%   A{k} = kron(I_{n(k)}, M{k}) + kron(L{k}, I_{r(k)})
A = cell(1, d);
M = cell(1, d);
A{1} = L{1};
for k = 2:d
    Ul   = unfold( xi.U{k-1}, 'left' );
    M{k} = Ul' * A{k-1} * Ul;
    A{k} = kron( speye(n(k)), M{k} ) + kron( L{k}, eye(r(k)) );
end

% Right interface matrices, built backwards from the last core (B{k} is r(k+1)-by-r(k+1)):
%   B{d-1} = Vr * kron(I_{r(d+1)}, L{d}) * Vr',                                Vr from xi.V{d}
%   B{k}   = Vr * ( kron(I_{r(k+2)}, L{k+1}) + kron(B{k+1}, I_{n(k+1)}) ) * Vr', Vr from xi.V{k+1}
% where Vr is the right unfolding of the indicated core.
B = cell(1, d-1);
Vr = unfold( xi.V{d}, 'right' );
B{d-1} = Vr * kron( speye(r(d+1)), L{d} ) * Vr';
for k = d-2:-1:1
    Vr = unfold( xi.V{k+1}, 'right' );
    B{k} = Vr * ( kron( speye(r(k+2)), L{k+1} ) + kron( B{k+1}, speye(n(k+1)) ) ) * Vr';
end

% Every eigenvector matrix Q below is undone with Q', which is only correct
% when Q is orthogonal. EIG guarantees that only for (exactly) symmetric
% input, while the U'*X*U products above are symmetric only up to rounding.
% Hence (X + X')/2 is passed to EIG. This leaves an exactly symmetric X unchanged.

% First core: diagonalize B{1} and solve one shifted system
% (A{1} + lam(j)*I) per eigenvalue in the eigenbasis.
[Q, lam] = eig( (B{1} + B{1}')/2 );
lam  = diag(lam);
U1Q  = unfold( xi.U{1}, 'left' ) * Q;
rhs1 = unfold( xi.dU{1}, 'left' ) * Q;
sol1 = zeros( [n(1), r(2)] );
for j = 1:r(2)
    sol1(:,j) = solve_saddle( A{1}, lam(j), U1Q, rhs1(:,j) );
end
eta.dU{1} = reshape( sol1*Q', size(xi.dU{1}) );

% Middle cores: same approach. The remaining Kronecker structure
% (L{k} and M{k}) is handled by SOLVE_SADDLE_FAST. M{k} is symmetrized
% here so that the eigenvectors that helper computes from it are orthogonal too.
for k = 2:d-1
    [Q, lam] = eig( (B{k} + B{k}')/2 );
    lam = diag(lam);
    Mk  = (M{k} + M{k}')/2;
    UQ  = reshape( unfold( xi.U{k}, 'left' ) * Q, size(xi.U{k}) );
    rhs = reshape( unfold( xi.dU{k}, 'left' ) * Q, size(xi.dU{k}) );
    sol = zeros( [r(k)*n(k), r(k+1)] );
    for j = 1:r(k+1)
        sol(:,j) = solve_saddle_fast( L{k}, Mk, lam(j), UQ, rhs(:,:,j) );
    end
    eta.dU{k} = reshape( sol*Q', size(xi.dU{k}) );
end

% Last core: there is no right interface, so only M{d} is diagonalized (shift 0).
[Q, gam] = eig( (M{d} + M{d}')/2 );
gam = diag(gam);
eta.dU{d} = solve_kron( L{d}, 0, Q, gam, xi.dU{d} );

eta = TTeMPS_tangent_orth( xL, xR, eta );

end
0076
0077
function res = solve_saddle( A, lam, Ul, rhs )
%SOLVE_SADDLE  Solve the shifted system (A + lam*I) * res = rhs.
%   RES = SOLVE_SADDLE( A, LAM, UL, RHS ) returns (A + LAM*speye(size(A))) \ RHS.
%   UL is not used. In this no-saddle variant it appears to be kept only
%   for signature compatibility.
res = mldivide( A + lam*speye(size(A)), rhs );
end
0085
function res = solve_saddle_fast( A, M, lam, U, rhs )
%SOLVE_SADDLE_FAST  Solve the shifted Kronecker system for one middle-core column.
%   RES = SOLVE_SADDLE_FAST( A, M, LAM, U, RHS ) diagonalizes M, calls
%   SOLVE_KRON( A, LAM, Q, GAM, RHS ) and returns the result reshaped by
%   UNFOLD(..., 'left').
%   U is not used. In this no-saddle variant it appears to be kept only for
%   signature compatibility.
%
%   SOLVE_KRON undoes the change of basis with Q', i.e. it treats Q as
%   orthogonal. EIG only guarantees that for (exactly) symmetric input. The
%   caller in this file forms M as a U'*X*U product, which is symmetric only
%   up to rounding, so M is symmetrized first. An exactly symmetric M is left
%   unchanged by this step.
[Q, gam] = eig( (M + M')/2 );
gam = diag(gam);

res = unfold( solve_kron( A, lam, Q, gam, rhs ), 'left' );
end
0098
0099
function sol = solve_kron( A, lam, Q, gam, rhs )
%SOLVE_KRON  Solve a shifted Kronecker-sum system whose first-index factor is diagonalized.
%   SOL = SOLVE_KRON( A, LAM, Q, GAM, RHS ) proceeds in three steps:
%     1. Transform the first index of RHS with Q.
%     2. For every eigenvalue GAM(k), solve the sparse system
%        (A + (LAM + GAM(k))*I) on the columns that carry that eigenvalue.
%     3. Transform back with Q'.
%   RHS is either a matrix (first index of length numel(GAM), second index of
%   length size(A,1)) or a 3-D array. A 3-D array is flattened with
%   MATRICIZE(RHS, 2) and restored with TENSORIZE.
%   In the matrix case, with real orthogonal Q and M = Q*diag(GAM)*Q', the
%   result X satisfies M*X + X*A.' + LAM*X = RHS.
K    = length(gam);
Id   = speye(size(A));
is2d = ( size(rhs, 3) == 1 );

% Change basis in the first index of rhs.
if is2d
    rhsT = rhs.' * Q;
else
    rhsT = matricize(rhs, 2) * kron( eye(size(rhs,3)), Q );
end

% Columns k, k+K, k+2K, ... all share the shift gam(k).
solT = zeros(size(rhsT));
for k = 1:K
    cols = k:K:size(rhsT, 2);
    solT(:, cols) = ( A + (lam + gam(k))*Id ) \ rhsT(:, cols);
end

% Undo the change of basis and restore the original shape of rhs.
if is2d
    sol = ( solT * Q' ).';
else
    sol = tensorize( solT * kron( eye(size(rhs,3)), Q' ), 2, size(rhs) );
end

end
0123