0001
0002
0003
0004
function [X, residuum, cost, times] = alsLinsolve_fast( L, F, X, opts )
%ALSLINSOLVE_FAST  Alternating least squares (ALS) for the TT linear system L(X) = F.
%
%   [X, residuum, cost, times] = alsLinsolve_fast( L, F, X, opts ) runs
%   opts.nSweeps ALS sweeps on X with its TT ranks kept fixed. Each sweep
%   updates cores 1..d-1 from left to right and then cores d..2 from right
%   to left. Every core is replaced by the solution of its local (projected)
%   system while all other cores are held fixed.
%
%   opts fields (all optional):
%     nSweeps  - number of full sweeps                          (default 4)
%     solver   - local solver, one of                           (default 'pcg')
%                  'direct' : backslash on the local matrix contract(L, X, idx)
%                  'pcg'    : pcg on the local operator, preconditioned with
%                             the local system built from L.L0
%                  'diag'   : solve only the local system built from L.L0
%     pcgTol   - tolerance passed to pcg                        (default 1e-10)
%     pcgMaxit - maximum number of pcg iterations               (default 1000)
%
%   Outputs. residuum, cost and times are column vectors with one entry for
%   the initial guess and one entry after every core update:
%     X        - final iterate
%     residuum - relative residual norm(L(X) - F) / norm(F)
%     cost     - energy 0.5*<X, L(X)> - <X, F>
%     times    - seconds elapsed since the start of the call

    t_start = tic();

    if nargin < 4; opts = struct(); end
    if ~isfield( opts, 'nSweeps' );  opts.nSweeps  = 4;     end
    if ~isfield( opts, 'solver' );   opts.solver   = 'pcg'; end
    if ~isfield( opts, 'pcgTol' );   opts.pcgTol   = 1e-10; end
    if ~isfield( opts, 'pcgMaxit' ); opts.pcgMaxit = 1000;  end

    % Reject a bad solver name before doing any expensive work.
    if ~any( strcmpi( opts.solver, {'direct', 'pcg', 'diag'} ) )
        error( 'Unknown opts.solver type. Use either ''direct'', ''pcg'' (default) or ''diag''.' )
    end

    d = X.order;
    normF = norm( F );   % NOTE(review): a zero F gives Inf/NaN relative residuals

    [residuum, cost] = als_progress( L, X, F, normF );
    times = toc( t_start );

    X = orthogonalize( X, 1 );
    for sweep = 1:opts.nSweeps

        disp( ['STARTING SWEEP ', num2str(sweep), ' from left to right'] )
        disp( '===========================================================')
        for idx = 1:d-1
            disp( ['Current core: ', num2str(idx)] )

            X.U{idx} = als_solve_core( L, F, X, idx, opts );
            % presumably left-orthogonalizes core idx before moving on to idx+1
            X = orth_at( X, idx, 'left', true );

            [res_k, cost_k] = als_progress( L, X, F, normF );
            residuum = [residuum; res_k];
            cost     = [cost; cost_k];
            times    = [times; toc(t_start)];
        end

        disp( 'Starting right-to-left half-sweep:')
        for idx = d:-1:2
            disp( ['Current core: ', num2str(idx)] )

            X.U{idx} = als_solve_core( L, F, X, idx, opts );
            % presumably right-orthogonalizes core idx before moving on to idx-1
            X = orth_at( X, idx, 'right', true );

            [res_k, cost_k] = als_progress( L, X, F, normF );
            residuum = [residuum; res_k];
            cost     = [cost; cost_k];
            times    = [times; toc(t_start)];
        end

    end

end

function U = als_solve_core( L, F, X, idx, opts )
% Solve the local system for core idx of X (all other cores fixed) with the
% solver selected in opts.solver. Returns the new core, shaped
% [X.rank(idx), X.size(idx), X.rank(idx+1)].
    Fi = contract( X, F, idx );
    sz = [X.rank(idx), X.size(idx), X.rank(idx+1)];

    if strcmpi( opts.solver, 'direct' )
        Li = contract( L, X, idx );
        U = reshape( Li \ Fi(:), sz );

    elseif strcmpi( opts.solver, 'pcg' )
        [left, right] = Afun_prepare( L, X, idx );
        [B2, V, E] = prepare_precond( L.L0, X, idx );

        % The current core is used as the initial guess.
        u = pcg( @(y) Afun( L, y, idx, sz, left, right ), ...
                 Fi(:), ...
                 opts.pcgTol, opts.pcgMaxit, ...
                 @(y) apply_precond( B2, V, E, y, sz ), [], ...
                 X.U{idx}(:) );
        U = reshape( u, sz );

    elseif strcmpi( opts.solver, 'diag' )
        U = solve_inner( L.L0, X, Fi, idx );

    else
        error( 'Unknown opts.solver type. Use either ''direct'', ''pcg'' (default) or ''diag''.' )
    end
end

function [res, J] = als_progress( L, X, F, normF )
% Relative residual and energy J(X) = 0.5*<X, L(X)> - <X, F> of the iterate X.
    g = apply( L, X ) - F;
    res = norm( g ) / normF;
    % 0.5*<X, g> = 0.5*<X, L(X)> - 0.5*<X, F>, so a further 0.5*<X, F> must be
    % subtracted to obtain J. 0.5*<X, g> on its own is ~0 right after an exact
    % Galerkin core update and is therefore not a useful convergence measure.
    J = 0.5*( innerprod( X, g ) - innerprod( X, F ) );
end
0121
function res = cost_function( L, X, F )
%COST_FUNCTION  Energy functional J(X) = 0.5*<X, L(X)> - <X, F>.
%   Applies L to X once and combines the two inner products.
    LX = apply( L, X );
    quad = innerprod( X, LX );
    res = quad/2 - innerprod( X, F );
end
0125
function res = cost_function_res( X, res )
%COST_FUNCTION_RES  Half the inner product of X with a given tensor.
%   Returns 0.5*<X, res>. If res = L(X) - F, this equals
%   0.5*<X, L(X)> - 0.5*<X, F>. That differs from the energy computed by
%   cost_function by 0.5*<X, F>.
    ip = innerprod( X, res );
    res = ip / 2;
end
0129
0130
function [left, right] = Afun_prepare( A, x, idx )
%AFUN_PREPARE  Left/right partial inner products of x with A.apply(x) around core idx.
%   A.apply(x) is evaluated once for the whole tensor. The two results are:
%     left  = innerprod( x, A.apply(x), 'LR', idx-1 )   ([] when idx == 1)
%     right = innerprod( x, A.apply(x), 'RL', idx+1 )   ([] when idx is the
%             last core and idx ~= 1)
%   These are consumed by Afun to apply the local operator for core idx.
    Ax = A.apply( x );

    isFirst = ( idx == 1 );
    isLast  = ~isFirst && ( idx == x.order );   % the first core takes precedence

    left  = [];
    right = [];
    if ~isFirst
        left = innerprod( x, Ax, 'LR', idx-1 );
    end
    if ~isLast
        right = innerprod( x, Ax, 'RL', idx+1 );
    end
end
0144
function res = Afun( A, U, idx, sz, left, right )
%AFUN  Apply the local operator for core idx to a vectorized core U.
%   U is reshaped to sz and A.apply( ., idx ) is applied to it. The result is
%   contracted with `right` in mode 3 (unless idx is the last core and not
%   the first) and then with `left` in mode 1 (unless idx is the first core).
%   The result is returned as a column vector, as required by pcg.
    useRight = ( idx == 1 ) || ( idx ~= A.order );
    useLeft  = ( idx ~= 1 );

    core = A.apply( reshape( U, sz ), idx );
    if useRight
        core = tensorprod_ttemps( core, right, 3 );
    end
    if useLeft
        core = tensorprod_ttemps( core, left, 1 );
    end

    res = core(:);
end
0161
0162
0163
0164
0165
0166
0167
0168
0169
0170
0171
0172
0173
0174
0175
0176
0177
0178
0179
0180
0181
0182
0183
0184
function res = solve_inner( L0, X, Fi, idx )
%SOLVE_INNER  Solve the local system for core idx built from L0 only.
%   The local matrix is assembled from L0 in the same way as the 'pcg'
%   preconditioner: L0 acts on the mode-2 (physical) index, and the left
%   and right interface sums of L0 act on the rank indices. It is then
%   solved via the eigendecomposition of the rank part.
%
%   Inputs:
%     L0  - n-by-n matrix (n = size(L0, 1))
%     X   - TT tensor whose cores other than idx define the projection
%     Fi  - projected right-hand side with X.rank(idx)*n*X.rank(idx+1) entries
%     idx - core index
%   Output:
%     res - new core of size [X.rank(idx), n, X.rank(idx+1)]
%
%   This previously duplicated prepare_precond + apply_precond line for line.
%   It now delegates to them so that 'diag' and the 'pcg' preconditioner
%   cannot drift apart.
    sz = [X.rank(idx), size(L0, 1), X.rank(idx+1)];

    [B2, V, E] = prepare_precond( L0, X, idx );
    res = reshape( apply_precond( B2, V, E, Fi(:), sz ), sz );
end
0220
function [B2, V, E] = prepare_precond( L0, X, idx )
%PREPARE_PRECOND  Factors of the L0-based local system for core idx.
%   B1 sums, over every core i < idx, the 'LR' partial inner product (up to
%   core idx-1) of X with a copy of X whose core i has L0 applied in mode 2.
%   B3 is the analogous 'RL' sum over the cores i > idx, starting at idx+1.
%   The rank-space matrix
%       K = kron( eye(rr), B1 ) + kron( B3, eye(rl) )
%   is eigendecomposed. The outputs are:
%     B2 - L0 itself (acts on the mode-2 index)
%     V  - eigenvectors of K
%     E  - eigenvalues of K as a column vector
    rl = X.rank(idx);
    rr = X.rank(idx+1);

    B1 = precond_interface_sum( L0, X, 1:idx-1,         'LR', idx-1, rl );
    B3 = precond_interface_sum( L0, X, idx+1:X.order,   'RL', idx+1, rr );
    B2 = L0;

    K = kron( eye(rr), B1 ) + kron( B3, eye(rl) );
    [V, D] = eig( K );
    E = diag( D );
end

function B = precond_interface_sum( L0, X, modes, direction, pos, r )
% Sum of innerprod( X, X-with-L0-applied-to-core-k, direction, pos ) over k in modes, starting from zeros(r).
    B = zeros( r );
    for k = modes
        Y = X;
        Y.U{k} = tensorprod_ttemps( Y.U{k}, L0, 2 );
        B = B + innerprod( X, Y, direction, pos );
    end
end
0249
function res = apply_precond( B2, V, E, rhs, sz )
%APPLY_PRECOND  Apply the L0-based local solve to a vectorized core.
%   rhs is reshaped to sz and unfolded in mode 2 (R). Each column of R*V is
%   then solved with ( B2 + E(j)*I ), and the result is mapped back with V'.
%   Assuming V is orthogonal (e.g. symmetric K = V*diag(E)*V'), this solves
%   B2*Z + Z*K = R. The result is returned as a column vector.
    I = speye( size(B2, 1) );
    R = matricize( reshape( rhs, sz ), 2 ) * V;

    Y = zeros( size(R) );
    for j = 1:numel( E )
        Y(:, j) = ( B2 + E(j)*I ) \ R(:, j);
    end

    res = reshape( tensorize( Y*V', 2, sz ), [], 1 );
end
0261
0262
0263