0001
0002
0003
0004
0005
function X = construct_initial_guess(L, F, r, n, tol, maxit)
%CONSTRUCT_INITIAL_GUESS Random TT start point with a locally solved first core.
%   X = CONSTRUCT_INITIAL_GUESS(L, F, r, n) does the following:
%     1. Draws X = TTeMPS_rand(r, n) and scales it by 1/norm(X).
%     2. Calls orthogonalize(X, 1).
%     3. Replaces the first core X.U{1} by the PCG solution of the local
%        system. Its matrix-free operator is Afun(L, ., 1, ...), its
%        right-hand side is contract(X, F, 1), it is preconditioned by
%        apply_local_precond, and it starts from the current X.U{1}.
%     4. Returns orth_at(X, 1, 'left', true).
%   L must provide L.apply(...) and L.order (used by Afun/Afun_prepare).
%   NOTE(review): presumably this is one ALS-type half-step on core 1 for
%   the system L*X = F -- confirm against the caller.
%
%   X = CONSTRUCT_INITIAL_GUESS(L, F, r, n, tol, maxit) also sets the PCG
%   relative tolerance and maximum iteration count. If either is omitted or
%   empty, the previous hard-coded values are used (tol = 1e-6,
%   maxit = 1000).

if nargin < 5 || isempty(tol)
    tol = 1e-6;
end
if nargin < 6 || isempty(maxit)
    maxit = 1000;
end

% Random, normalized starting tensor.
X = TTeMPS_rand( r, n );
X = 1/norm(X) * X;

X = orthogonalize(X, 1);
% Local right-hand side and the shape of the first core.
Fi = contract( X, F, 1 );
sz = [X.rank(1), X.size(1), X.rank(2)];

% Interface contractions and preconditioner data for the local solve.
[left, right] = Afun_prepare( L, X, 1 );
expB = constr_precond_inner( L, X, 1 );

% pcg(A, b, tol, maxit, M1, M2, x0): warm-start from the current core.
% No flag output is requested, so pcg prints its convergence message as before.
Ui = pcg( @(y) Afun( L, y, 1, sz, left, right ), ...
          Fi(:), ...
          tol, maxit, ...
          @(y) apply_local_precond( L, y, sz, expB ), [], ...
          X.U{1}(:) );

X.U{1} = reshape( Ui, size(X.U{1}) );

X = orth_at( X, 1, 'left', true );

end
0037
0038
0039
function [left, right] = Afun_prepare( A, x, idx )
%AFUN_PREPARE Precompute the interface contractions that Afun needs for core idx.
%   Applies the operator once to the whole tensor (Ax = A.apply(x)) and
%   contracts x against Ax with innerprod:
%     left  = innerprod(x, Ax, 'LR', idx-1)  -- computed whenever idx ~= 1
%     right = innerprod(x, Ax, 'RL', idx+1)  -- computed when idx == 1 or
%                                               idx ~= x.order
%   A side that is not computed is returned as [].

Ax = A.apply( x );

needLeft  = idx ~= 1;
needRight = idx == 1 || idx ~= x.order;

left = [];
right = [];

if needLeft
    left = innerprod( x, Ax, 'LR', idx - 1 );
end
if needRight
    right = innerprod( x, Ax, 'RL', idx + 1 );
end
end
0053
function res = Afun( A, U, idx, sz, left, right )
%AFUN Matrix-free local operator for core idx (used as the pcg operator).
%   Reshapes the vector U to the core size sz, applies A.apply(., idx), and
%   then contracts with the interfaces from Afun_prepare:
%     - `right` along mode 3, unless idx is the last core (and idx ~= 1);
%     - then `left` along mode 1, unless idx == 1.
%   Returns the result as a column vector.

core = A.apply( reshape( U, sz ), idx );

if idx == 1 || idx ~= A.order
    core = tensorprod_ttemps( core, right, 3 );
end
if idx ~= 1
    core = tensorprod_ttemps( core, left, 1 );
end

res = core(:);
end
0070
function res = apply_local_precond( A, U, sz, expB )
%APPLY_LOCAL_PRECOND Apply a sum of Kronecker-structured terms to the columns of U.
%   RES = APPLY_LOCAL_PRECOND(A, U, SZ, EXPB) treats each column of U as a
%   3-D array of size SZ = [n1 n2 n3]. For every term i it forms the mode
%   products with EXPB{1,i} (mode 1), EXPB{2,i} (mode 2) and EXPB{3,i}
%   (mode 3), and sums the results over all terms. In vectorized form:
%       RES = sum_i kron(EXPB{3,i}, kron(EXPB{2,i}, EXPB{1,i})) * U
%   EXPB is a cell array whose row is the mode (1..3) and whose column is
%   the term index. Each EXPB{k,i} is assumed to be sz(k)-by-sz(k).
%   A is accepted for interface compatibility and is not used.
%   RES has size [prod(SZ), size(U, 2)].

p = size( U, 2 );

x = reshape( U, [sz, p] );
res = zeros( [sz, p] );

% Terms are stored in the columns of expB (rows are the three modes).
% Iterating over size(expB, 1) would always run exactly 3 terms.
for i = 1:size( expB, 2 )
    tmp = precond_mode_product( x,   expB{1,i}, 1, sz, p );
    tmp = precond_mode_product( tmp, expB{2,i}, 2, sz, p );
    tmp = precond_mode_product( tmp, expB{3,i}, 3, sz, p );
    res = res + tmp;
end
res = reshape( res, [prod(sz), p] );

end

function T = precond_mode_product( T, M, k, sz, p )
% Multiply the [sz, p] array T by the matrix M along mode k (k in 1..3).
perm = [k, setdiff( 1:4, k )];
rest = sz(perm(2:3));
Tk = reshape( permute( T, perm ), [sz(k), rest(1)*rest(2)*p] );
T = ipermute( reshape( M*Tk, [sz(k), rest, p] ), perm );
end
0093
0094
0095