Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > algorithms > linearsystem > alsLinsolve_fast.m

alsLinsolve_fast

PURPOSE ^

TTeMPS Toolbox.

SYNOPSIS ^

function [X, residuum, cost, times] = alsLinsolve_fast( L, F, X, opts )

DESCRIPTION ^

   TTeMPS Toolbox. 
   Michael Steinlechner, 2013-2016
   Questions and contact: michael.steinlechner@epfl.ch
   BSD 2-clause license, see LICENSE.txt

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 %   TTeMPS Toolbox.
0002 %   Michael Steinlechner, 2013-2016
0003 %   Questions and contact: michael.steinlechner@epfl.ch
0004 %   BSD 2-clause license, see LICENSE.txt
function [X, residuum, cost, times] = alsLinsolve_fast( L, F, X, opts )
%ALSLINSOLVE_FAST Core-by-core sweeps over X for the linear system L*X = F.
%
%   [X, residuum, cost, times] = alsLinsolve_fast( L, F, X, opts )
%
%   Each sweep first updates the cores 1, ..., d-1 from left to right and
%   then the cores d, ..., 2 from right to left, where d = X.order. After
%   every core update the core is orthogonalized with orth_at and the full
%   residual apply(L, X) - F is re-evaluated to record the progress.
%
%   Input:
%     L     operator; used through apply(L, X) and contract(L, X, idx).
%           The 'pcg' and 'diag' solvers additionally read L.L0.
%     F     right-hand side; used through norm(F), contract(X, F, idx)
%           and apply(L, X) - F.
%     X     initial guess; the fields order, rank, size and U are used.
%     opts  (optional) struct with the fields
%             nSweeps  number of full sweeps (default: 4)
%             solver   local solver, compared case-insensitively:
%                        'direct'  backslash on contract(L, X, idx)
%                        'pcg'     preconditioned CG with tolerance 1e-10,
%                                  at most 1000 iterations, started from
%                                  the current core (default)
%                        'diag'    solve_inner with L.L0
%
%   Output:
%     X         iterate after the last sweep
%     residuum  column vector of norm(apply(L,X) - F) / norm(F); the first
%               entry belongs to the initial guess, followed by one entry
%               per core update
%     cost      column vector of cost_function_res values, same layout
%     times     column vector of seconds elapsed since the function was
%               entered, same layout

    t_start = tic();

    % set default opts
    if ~exist( 'opts', 'var');       opts = struct();       end
    if ~isfield( opts, 'nSweeps');   opts.nSweeps = 4;      end
    if ~isfield( opts, 'solver');    opts.solver = 'pcg';   end

    d = X.order;

    % Quantities for the initial guess (evaluated before X is orthogonalized).
    normF = norm(F);
    g = apply(L, X) - F;
    cost = cost_function_res( X, g );
    residuum = norm( g ) / normF;
    times = toc(t_start);

    X = orthogonalize(X, 1);
    for sweep = 1:opts.nSweeps
        % ================================================================
        % LEFT-TO-RIGHT SWEEP
        % ================================================================
        disp( ['STARTING SWEEP ', num2str(sweep), ' from left to right'] )
        disp( '===========================================================')
        for idx = 1:d-1
            disp( ['Current core: ', num2str(idx)] )

            X = als_update_core( L, F, X, idx, opts );
            X = orth_at( X, idx, 'left', true );

            [residuum, cost, times] = als_record_progress( L, F, X, normF, t_start, ...
                                                           residuum, cost, times );
        end

        % ================================================================
        % RIGHT-TO-LEFT
        % ================================================================
        disp( 'Starting right-to-left half-sweep:')
        for idx = d:-1:2
            disp( ['Current core: ', num2str(idx)] )

            X = als_update_core( L, F, X, idx, opts );
            X = orth_at( X, idx, 'right', true );

            [residuum, cost, times] = als_record_progress( L, F, X, normF, t_start, ...
                                                           residuum, cost, times );
        end
    end

end

function X = als_update_core( L, F, X, idx, opts )
% Replace core idx of X by the solution of the local system, using the
% solver selected by opts.solver. Shared by both half-sweeps.

    Fi = contract( X, F, idx );
    sz = [X.rank(idx), X.size(idx), X.rank(idx+1)];

    if strcmpi( opts.solver, 'direct' )
        % if system very small
        Li = contract( L, X, idx );
        Ui = Li \ Fi(:);
        X.U{idx} = reshape( Ui, sz );

    elseif strcmpi( opts.solver, 'pcg' )
        [left, right] = Afun_prepare( L, X, idx );
        [B2, V, E] = prepare_precond( L.L0, X, idx );

        % The current core serves as the initial guess for pcg.
        Ui = pcg( @(y) Afun( L, y, idx, sz, left, right), ...
                  Fi(:), ...
                  1e-10, 1000, ...
                  @(y) apply_precond( B2, V, E, y, sz ), [],...
                  X.U{idx}(:) );

        X.U{idx} = reshape( Ui, sz );

    elseif strcmpi( opts.solver, 'diag' )
        X.U{idx} = solve_inner( L.L0, X, Fi, idx );

    else
        error( 'Unknown opts.solver type. Use either ''direct'', ''pcg'' (default) or ''diag''.' )
    end
end

function [residuum, cost, times] = als_record_progress( L, F, X, normF, t_start, residuum, cost, times )
% Append relative residual, cost value and elapsed time of the current
% iterate X to the history vectors.

    g = apply(L, X) - F;
    residuum = [residuum; norm( g ) / normF];
    cost = [cost; cost_function_res( X, g )];
    times = [times; toc(t_start)];
end
0121 
function res = cost_function( L, X, F )
% Evaluates 0.5*innerprod(X, apply(L, X)) - innerprod(X, F).
% Not called anywhere in the visible part of this file.
    LX = apply( L, X );
    quadratic_part = 0.5*innerprod( X, LX );
    linear_part = innerprod( X, F );
    res = quadratic_part - linear_part;
end
0125 
function res = cost_function_res( X, res )
% Returns half of innerprod(X, res), where the input res is a residual
% that the caller has already computed (here: apply(L, X) - F).
% NOTE(review): for res = apply(L, X) - F this equals
% 0.5*innerprod(X, apply(L, X)) - 0.5*innerprod(X, F), which differs from
% cost_function by 0.5*innerprod(X, F) - confirm that this is intended.
    pairing = innerprod( X, res );
    res = 0.5*pairing;
end
0129 
0130 
function [left, right] = Afun_prepare( A, x, idx )
% Precompute the interface quantities that Afun needs for core idx.
%
%   left   innerprod(x, A.apply(x), 'LR', idx-1); empty when idx == 1
%   right  innerprod(x, A.apply(x), 'RL', idx+1); empty when idx is the
%          last core (and idx ~= 1)
    Ax = A.apply(x);

    left = [];
    right = [];

    is_first = (idx == 1);
    % The first-core case takes precedence, exactly as in an if/elseif chain.
    is_last = ~is_first && (idx == x.order);

    if ~is_first
        left = innerprod( x, Ax, 'LR', idx-1 );
    end
    if ~is_last
        right = innerprod( x, Ax, 'RL', idx+1 );
    end
end
0144 
function res = Afun( A, U, idx, sz, left, right )
% Matrix-vector product handed to pcg for core idx.
%
% The vector U is reshaped to the core size sz, passed through
% A.apply(., idx) and then multiplied with the precomputed quantities from
% Afun_prepare: right in mode 3 (skipped for the last core) and left in
% mode 1 (skipped for the first core). The result is returned as a column.

    core = A.apply( reshape( U, sz ), idx );

    is_first = (idx == 1);
    % The first-core case takes precedence, exactly as in an if/elseif chain.
    is_last = ~is_first && (idx == A.order);

    if ~is_last
        core = tensorprod_ttemps( core, right, 3 );
    end
    if ~is_first
        core = tensorprod_ttemps( core, left, 1 );
    end

    res = core(:);
end
0161 
0162 %function res = apply_local_precond( A, U, sz, expB)
0163 %
0164 %    p = size(U, 2);
0165 %
0166 %    x = reshape( U, [sz, p] );
0167 %    res = zeros( [sz, p] );
0168 %
0169 %    for i = 1:size( expB, 1)
0170 %        tmp = reshape( x, [sz(1), sz(2)*sz(3)*p] );
0171 %        tmp = reshape( expB{1,i}*tmp, [sz(1), sz(2), sz(3), p] );
0172 %
0173 %        tmp = reshape( permute( tmp, [2 1 3 4] ), [sz(2), sz(1)*sz(3)*p] );
0174 %        tmp = ipermute( reshape( expB{2,i}*tmp, [sz(2), sz(1), sz(3), p] ), [2 1 3 4] );
0175 %
0176 %        tmp = reshape( permute( tmp, [3 1 2 4] ), [sz(3), sz(1)*sz(2)*p] );
0177 %        tmp = ipermute( reshape( expB{3,i}*tmp, [sz(3), sz(1), sz(2), p] ), [3 1 2 4] );
0178 %
0179 %        res = res + tmp;
0180 %    end
0181 %    res = reshape( res, [prod(sz), p] );
0182 %
0183 %end
0184 
function res = solve_inner( L0, X, Fi, idx )
% Local solve for core idx used by the 'diag' solver.
%
%   L0   square matrix applied in mode 2 of the cores
%   X    current iterate (fields rank, order and U are used)
%   Fi   local right-hand side for core idx
%   idx  index of the core to be computed
%
% Returns the new core as an array of size [rl, n, rr] with
% rl = X.rank(idx), n = size(L0, 1) and rr = X.rank(idx+1).
%
% B1 and B3 collect the contributions of L0 acting on the cores to the
% left and to the right of idx; B2 = L0 acts on core idx itself. The
% (B1, B3) part is diagonalized once, which leaves one shifted system with
% B2 per eigenvalue.
    n = size(L0, 1);
    rl = X.rank(idx);
    rr = X.rank(idx+1);

    B1 = zeros( rl );
    % calculate B1 part:
    for i = 1:idx-1
        % apply L to the i'th core
        tmp = X;
        tmp.U{i} = tensorprod_ttemps( tmp.U{i}, L0, 2 );
        B1 = B1 + innerprod( X, tmp, 'LR', idx-1);
    end

    % calculate B2 part:
    B2 = L0;

    B3 = zeros( rr );
    % calculate B3 part:
    for i = idx+1:X.order
        tmp = X;
        tmp.U{i} = tensorprod_ttemps( tmp.U{i}, L0, 2 );
        B3 = B3 + innerprod( X, tmp, 'RL', idx+1);
    end

    % V' is used as the inverse of V below, which requires orthonormal
    % eigenvectors. eig only guarantees them for exactly symmetric input,
    % whereas S is symmetric at best up to round-off, so the symmetric
    % part is taken explicitly before the decomposition.
    S = kron( eye(rr), B1 ) + kron( B3, eye(rl) );
    S = 0.5*(S + S');
    [V,E] = eig( S );
    E = diag(E);

    rhs = matricize( Fi, 2 ) * V;
    Y = zeros(size(rhs));
    In = speye(n);      % loop-invariant identity for the shifts
    for i=1:length(E)
        Y(:,i) = (B2 + E(i)*In) \ rhs(:,i);
    end
    res = tensorize( Y*V', 2, [rl, n, rr] );
end
0220 
function [B2, V, E] = prepare_precond( L0, X, idx )
% Set up the data consumed by apply_precond for core idx.
%
%   L0   square matrix applied in mode 2 of the cores
%   X    current iterate (fields rank, order and U are used)
%   idx  index of the core that is about to be updated
%
%   B2   L0 itself (the part acting on core idx)
%   V    eigenvectors of kron(eye(rr), B1) + kron(B3, eye(rl)) after
%        symmetrization, with rl = X.rank(idx) and rr = X.rank(idx+1)
%   E    column vector of the corresponding eigenvalues
%
% B1 and B3 collect the contributions of L0 acting on the cores to the
% left and to the right of idx.
    rl = X.rank(idx);
    rr = X.rank(idx+1);

    B1 = zeros( rl );
    % calculate B1 part:
    for i = 1:idx-1
        % apply L to the i'th core
        tmp = X;
        tmp.U{i} = tensorprod_ttemps( tmp.U{i}, L0, 2 );
        B1 = B1 + innerprod( X, tmp, 'LR', idx-1);
    end

    % calculate B2 part:
    B2 = L0;

    B3 = zeros( rr );
    % calculate B3 part:
    for i = idx+1:X.order
        tmp = X;
        tmp.U{i} = tensorprod_ttemps( tmp.U{i}, L0, 2 );
        B3 = B3 + innerprod( X, tmp, 'RL', idx+1);
    end

    % apply_precond uses V' as the inverse of V, which requires orthonormal
    % eigenvectors. eig only guarantees them for exactly symmetric input,
    % whereas S is symmetric at best up to round-off, so the symmetric
    % part is taken explicitly before the decomposition.
    S = kron( eye(rr), B1 ) + kron( B3, eye(rl) );
    S = 0.5*(S + S');
    [V,E] = eig( S );
    E = diag(E);
end
0249 
function res = apply_precond( B2, V, E, rhs, sz )
% Apply the preconditioner prepared by prepare_precond to the vector rhs.
%
% rhs is reshaped to the core size sz, unfolded along mode 2 and multiplied
% by V. Every resulting column i is then solved against B2 + E(i)*I, the
% columns are multiplied by V', and the result is folded back to size sz
% and returned as a column vector.
    shift_identity = speye( size(B2, 1) );

    core = reshape( rhs, sz );
    transformed = matricize( core, 2 ) * V;

    solved = zeros( size(transformed) );
    for k = 1:length(E)
        solved(:,k) = (B2 + E(k)*shift_identity) \ transformed(:,k);
    end

    folded = tensorize( solved*V', 2, sz );
    res = folded(:);
end
0261 
0262 
0263

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005