0001
0002
0003
0004
0005
function X = construct_initial_guess(L, F, r, n)
%CONSTRUCT_INITIAL_GUESS Random starting tensor whose first core solves a local system.
%   X = construct_initial_guess(L, F, r, n) draws a random tensor with
%   TTeMPS_rand(r, n), scales it by 1/norm(X), and then replaces its first
%   core by the result of a preconditioned CG solve. The system operator is
%   Afun (L applied at core 1, contracted with the interface returned by
%   Afun_prepare) and the right-hand side is contract(X, F, 1).
%
%   Inputs:
%     L - operator object; this function uses L.apply (via the helpers) and
%         L.A{1}, which is handed to the preconditioner.
%     F - right-hand side, passed unchanged to contract().
%     r - rank argument forwarded to TTeMPS_rand.
%     n - size argument forwarded to TTeMPS_rand.
%
%   Output:
%     X - the tensor after the first-core solve and a final orth_at call.

    % Random start, rescaled by the reciprocal of its norm.
    X = TTeMPS_rand( r, n );
    X = 1/norm(X) * X;

    % NOTE(review): presumably orthogonalize(X, 1) leaves core 1 as the only
    % non-orthogonal core so the local problem is well posed -- confirm
    % against the toolbox.
    X = orthogonalize( X, 1 );
    Fi = contract( X, F, 1 );

    % Shape of the first core; used to move between vector and tensor form.
    sz = [X.rank(1), X.size(1), X.rank(2)];

    [left, right] = Afun_prepare( L, X, 1 );
    % For idx == 1 prepare_precond returns [], so apply_precond falls back
    % to a direct solve with L.A{1}.
    B1 = prepare_precond( L.A{1}, X, 1 );

    % pcg(A, b, tol, maxit, M1, M2, x0): the current core is the start
    % vector. Only one output is requested, so pcg reports its convergence
    % status on the command line.
    Ui = pcg( @(y) Afun( L, y, 1, sz, left, right), ...
              Fi(:), ...
              1e-10, 1000, ...
              @(y) apply_precond( L.A{1}, B1, y, sz ), [], ...
              X.U{1}(:) );

    X.U{1} = reshape( Ui, sz );

    X = orth_at( X, 1, 'left', true );

end
0036
0037
0038
0039
function [left, right] = Afun_prepare( A, x, idx )
%AFUN_PREPARE Interface contractions of x with A.apply(x) around core idx.
%   left  is innerprod(x, A.apply(x), 'LR', idx-1), or [] when idx == 1.
%   right is innerprod(x, A.apply(x), 'RL', idx+1), or [] when idx == x.order
%   (and idx is not 1).

    Ax = A.apply(x);

    % Deferred contractions, so each branch evaluates only what it needs.
    fromLeft  = @() innerprod( x, Ax, 'LR', idx-1 );
    fromRight = @() innerprod( x, Ax, 'RL', idx+1 );

    left  = [];
    right = [];

    if idx == 1
        % First core: nothing to the left.
        right = fromRight();
    elseif idx == x.order
        % Last core: nothing to the right.
        left = fromLeft();
    else
        % Interior core: both sides.
        left  = fromLeft();
        right = fromRight();
    end
end
0053
function res = Afun( A, U, idx, sz, left, right )
%AFUN Local operator application used as the system matrix handle in pcg.
%   U is reshaped to sz, passed through A.apply(., idx), and then contracted
%   with the precomputed interfaces: right along mode 3 and/or left along
%   mode 1, depending on the position of idx. The result is returned as a
%   column vector.

    core = A.apply( reshape( U, sz ), idx );

    viaRight = @(T) tensorprod_ttemps( T, right, 3 );
    viaLeft  = @(T) tensorprod_ttemps( T, left, 1 );

    if idx == 1
        res = viaRight( core );
    elseif idx == A.order
        res = viaLeft( core );
    else
        % Interior core: right interface first, then left.
        res = viaLeft( viaRight( core ) );
    end

    res = res(:);
end
0070
0071
function B1 = prepare_precond( L0, X, idx )
%PREPARE_PRECOND Precompute the matrix used by apply_precond.
%   For idx == 1 no precomputation is done and B1 = [] is returned.
%   Otherwise a copy of X is made whose first core is replaced by L0 applied
%   to the mode-2 matricization of X.U{1}, and B1 is the 'LR' inner product
%   of X with that copy up to core idx-1.

    if idx == 1
        B1 = [];
    else
        rows  = size(L0, 1);
        ranks = X.rank;

        % Apply L0 along the second mode of the first core.
        core1 = matricize( X.U{1}, 2 );
        Y = X;
        Y.U{1} = tensorize( L0*core1, 2, [ranks(1), rows, ranks(2)] );

        B1 = innerprod( X, Y, 'LR', idx-1 );
    end
end
0087
function res = apply_precond( L0, B1, rhs, sz )
%APPLY_PRECOND Preconditioner solve used as the M1 handle in pcg.
%   rhs is reshaped to sz. If B1 is empty, the system L0 \ unfold(rhs,'left')
%   is solved; otherwise B1 \ unfold(rhs,'right') is solved. The result is
%   returned as a column vector.
%
%   Inputs:
%     L0  - matrix used for the direct solve when B1 is empty.
%     B1  - matrix from prepare_precond, or [] to select the L0 solve.
%     rhs - right-hand side (any shape with prod(sz) elements).
%     sz  - three-element size of the core that rhs represents.

    rhs = reshape( rhs, sz );

    if isempty(B1)
        res = L0 \ unfold( rhs, 'left' );
    else
        res = B1 \ unfold( rhs, 'right' );
    end

    % The reshape is kept on purpose: it errors if the solve returned a
    % different number of elements than the core holds.
    res = reshape( res, sz );
    res = res(:);
end
0101
0102
0103
0104
0105
0106