0001
0002
0003
0004
function [X, residuum, cost, times] = alsLinsolve_rankOne( L, F, X, opts )
%ALSLINSOLVE_RANKONE Alternating core-by-core solve of L*X = F.
%   [X, RESIDUUM, COST, TIMES] = ALSLINSOLVE_RANKONE( L, F, X, OPTS ) sweeps
%   over the cores X.U{idx} of the iterate X, first left-to-right
%   (idx = 1..d-1) and then right-to-left (idx = d..2), replacing one core
%   at a time by the solution of a local system and re-orthogonalizing
%   with orth_at afterwards.
%
%   Inputs:
%     L    - operator; used through apply(L, X), L.A{1} ('pcg' solver)
%            and L.L0 ('diag' solver).
%     F    - right-hand side; must support norm(F) and apply(L, X) - F.
%     X    - initial guess exposing .order, .rank, .size and .U
%            (presumably a tensor-train object -- confirm against caller).
%     OPTS - optional struct with fields
%              nSweeps  : number of full sweeps                (default 4)
%              solver   : 'direct', 'pcg' or 'diag'            (default 'pcg')
%              pcgTol   : tolerance passed to pcg              (default 1e-10)
%              pcgMaxit : iteration limit passed to pcg        (default 1000)
%
%   Outputs:
%     X        - iterate after the last sweep.
%     RESIDUUM - column vector of norm(L*X - F)/norm(F); first entry is
%                for the initial guess, then one entry per core update.
%     COST     - column vector of 0.5*innerprod(X, L*X - F), recorded at
%                the same points as RESIDUUM.
%     TIMES    - column vector of elapsed seconds since entry, recorded
%                at the same points as RESIDUUM.
%
%   Raises an error if OPTS.solver is not one of the three known names
%   (only detected once the first core update is attempted).

    t_start = tic();

    % Fill in defaults for everything the caller did not specify.
    if nargin < 4 || isempty( opts ); opts = struct(); end
    if ~isfield( opts, 'nSweeps' );  opts.nSweeps  = 4;     end
    if ~isfield( opts, 'solver' );   opts.solver   = 'pcg'; end
    if ~isfield( opts, 'pcgTol' );   opts.pcgTol   = 1e-10; end
    if ~isfield( opts, 'pcgMaxit' ); opts.pcgMaxit = 1000;  end

    d = X.order;

    % Record the state of the initial guess before touching it.
    normF = norm(F);
    g = apply(L, X) - F;
    cost = cost_function_res( X, g );
    residuum = norm( g ) / normF;
    times = toc(t_start);

    X = orthogonalize(X, 1);
    for sweep = 1:opts.nSweeps

        % ---- left-to-right half-sweep: cores 1 .. d-1 ----
        disp( ['STARTING SWEEP ', num2str(sweep), ' from left to right'] )
        disp( '===========================================================')
        for idx = 1:d-1
            disp( ['Current core: ', num2str(idx)] )

            X = als_update_core( L, F, X, idx, opts );
            X = orth_at( X, idx, 'left', true );

            g = apply(L, X) - F;
            residuum = [residuum; norm( g ) / normF];
            cost = [cost; cost_function_res( X, g )];
            times = [times; toc(t_start)];
        end

        % ---- right-to-left half-sweep: cores d .. 2 ----
        disp( 'Starting right-to-left half-sweep:')
        for idx = d:-1:2
            disp( ['Current core: ', num2str(idx)] )

            X = als_update_core( L, F, X, idx, opts );
            X = orth_at( X, idx, 'right', true );

            g = apply(L, X) - F;
            residuum = [residuum; norm( g ) / normF];
            cost = [cost; cost_function_res( X, g )];
            times = [times; toc(t_start)];
        end

    end

end

function X = als_update_core( L, F, X, idx, opts )
%ALS_UPDATE_CORE Replace core idx of X by the solution of the local system.
%   Shared by both half-sweeps; the caller is responsible for the
%   subsequent orthogonalization step.

    Fi = contract( X, F, idx );
    sz = [X.rank(idx), X.size(idx), X.rank(idx+1)];

    if strcmpi( opts.solver, 'direct' )

        Li = contract( X, apply(L, X), idx );
        Ui = Li \ Fi(:);
        X.U{idx} = reshape( Ui, sz );

    elseif strcmpi( opts.solver, 'pcg' )

        [left, right] = Afun_prepare( L, X, idx );
        B1 = prepare_precond( L.A{1}, X, idx );

        % Single-output call on purpose: pcg then prints its own
        % convergence message. The current core is the starting guess.
        Ui = pcg( @(y) Afun( L, y, idx, sz, left, right ), ...
                  Fi(:), ...
                  opts.pcgTol, opts.pcgMaxit, ...
                  @(y) apply_precond( L.A{1}, B1, y, sz ), [], ...
                  X.U{idx}(:) );

        X.U{idx} = reshape( Ui, sz );

    elseif strcmpi( opts.solver, 'diag' )
        X.U{idx} = solve_inner( L.L0, X, Fi, idx );

    else
        error( 'Unknown opts.solver type. Use either ''direct'', ''pcg'' (default) or ''diag''.' )
    end
end
0121
function res = cost_function( L, X, F )
%COST_FUNCTION Evaluate 0.5*innerprod(X, L*X) - innerprod(X, F).
%   L is applied to X via apply(L, X); both pairings use innerprod.

    LX = apply( L, X );
    quadraticPart = 0.5 * innerprod( X, LX );
    linearPart = innerprod( X, F );
    res = quadraticPart - linearPart;
end
0125
function res = cost_function_res( X, res )
%COST_FUNCTION_RES Return half the inner product of X with a given tensor.
%   On entry RES is the tensor to pair with X (the callers in this file
%   pass the residual apply(L, X) - F); on exit RES is the scalar
%   0.5*innerprod(X, RES).
%
%   NOTE(review): with RES = L*X - F this equals
%   0.5*<X, L*X> - 0.5*<X, F>, which differs from cost_function by
%   0.5*<X, F> -- confirm that this is the intended monitoring quantity.

    pairing = innerprod( X, res );
    res = 0.5 * pairing;
end
0129
0130
function [left, right] = Afun_prepare( A, x, idx )
%AFUN_PREPARE Precompute the partial contractions of x with A*x around core idx.
%   LEFT  = innerprod(x, A.apply(x), 'LR', idx-1), or [] when idx == 1.
%   RIGHT = innerprod(x, A.apply(x), 'RL', idx+1), or [] when idx is the
%           last core (idx == x.order and idx ~= 1).

    left = [];
    right = [];

    y = A.apply(x);

    % Same case distinction as a first/last/interior branch: the "last"
    % test is only evaluated when idx is not the first core.
    atFirst = (idx == 1);
    atLast = ~atFirst && (idx == x.order);

    if ~atFirst
        left = innerprod( x, y, 'LR', idx-1 );
    end
    if ~atLast
        right = innerprod( x, y, 'RL', idx+1 );
    end
end
0144
function res = Afun( A, U, idx, sz, left, right )
%AFUN Apply the local operator for core idx to a vectorized core.
%   U is reshaped to SZ, passed through A.apply(., idx), then multiplied
%   with RIGHT along mode 3 (skipped for the last core) and with LEFT
%   along mode 1 (skipped for the first core). The result is returned as
%   a column vector, as required for a pcg function handle.

    core = A.apply( reshape( U, sz ), idx );

    % The "last" test is only evaluated when idx is not the first core.
    atFirst = (idx == 1);
    atLast = ~atFirst && (idx == A.order);

    if ~atLast
        core = tensorprod_ttemps( core, right, 3 );
    end
    if ~atFirst
        core = tensorprod_ttemps( core, left, 1 );
    end

    res = core(:);
end
0161
0162
function B1 = prepare_precond( L0, X, idx )
%PREPARE_PRECOND Build the matrix used by apply_precond for core idx.
%   For idx == 1 no matrix is needed and B1 = [] is returned. Otherwise a
%   copy Y of X is formed whose first core is multiplied by L0 (through
%   its mode-2 matricization), and B1 = innerprod(X, Y, 'LR', idx-1).

    B1 = [];

    if idx ~= 1
        ranks = X.rank;
        nRows = size( L0, 1 );

        firstCoreMat = matricize( X.U{1}, 2 );

        Y = X;
        Y.U{1} = tensorize( L0 * firstCoreMat, 2, [ranks(1), nRows, ranks(2)] );

        B1 = innerprod( X, Y, 'LR', idx-1 );
    end
end
0178
function res = apply_precond( L0, B1, rhs, sz )
%APPLY_PRECOND Apply the preconditioner to a vectorized core.
%   RHS is reshaped to SZ. If B1 is empty (prepare_precond returns [] for
%   the first core), the result is L0 \ unfold(RHS, 'left'); otherwise it
%   is B1 \ unfold(RHS, 'right'). The result is returned as a column
%   vector, as required for a pcg preconditioner handle.
%
%   NOTE(review): presumably unfold(., 'left') / unfold(., 'right') yield
%   the left/right matricizations of a 3-way core -- confirm against the
%   helper's definition.

    rhs = reshape( rhs, sz );

    if isempty( B1 )
        res = L0 \ unfold( rhs, 'left' );
    else
        res = B1 \ unfold( rhs, 'right' );
    end

    % Reshaping back to sz doubles as a check that the solve preserved
    % the number of entries before flattening.
    res = reshape( res, sz );
    res = res(:);
end
0192
0193
0194