Helper for trs_tCG_cached: mimics the latter's behavior, exploiting cache

function trsoutput = tCG_rejectedstep(problem, trsinput, options, store)

This function is a companion to trs_tCG_cached: it is not meant to be called directly by Manopt users.

When running trustregions, the tCG subproblem solver issues a number of Hessian-vector calls. If the step is rejected, the trust-region radius is decreased, then tCG is called anew at the same point. This would trigger the same Hessian-vector calls again. Instead of actually making those calls (which tend to be computationally expensive), trs_tCG_cached calls this function, which exploits information cached in the previous call to avoid redundant computations. The output is exactly the same as what one would have obtained by calling tCG without caching.

There can be two situations:

1. The same eta and Heta as in the previous tCG loop are returned and trustregions decreases Delta. (Either d_Hd <= 0 or store_last is used.)

2. A new eta and Heta are returned when some previously calculated eta_new from store_iters satisfies normsq := <eta_new,eta_new>_x >= Delta^2 at the current Delta (exceeding the trust region). The returned point is then the Steihaug–Toint point calculated using the eta computed before eta_new.

Refer to trs_tCG_cached for a description of the inputs and outputs.

See also: trustregions trs_tCG_cached trs_tCG
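The boundary point in situation 2 can be written out explicitly. The following is a sketch, assuming (as the field names in the listing suggest, but the source does not state) that e_Pe, e_Pd and d_Pd hold the inner products <eta,eta>_P, <eta,delta>_P and <delta,delta>_P for the current search direction delta = -mdelta, with P standing for whichever inner product tCG uses to measure the step. Requiring the step eta + tau*delta to land exactly on the trust-region boundary gives a quadratic in tau:

```latex
\begin{align*}
\|\eta + \tau\delta\|_P^2 &= \Delta^2 \\
\langle\eta,\eta\rangle_P + 2\tau\,\langle\eta,\delta\rangle_P + \tau^2\,\langle\delta,\delta\rangle_P &= \Delta^2 \\
\tau &= \frac{-\langle\eta,\delta\rangle_P + \sqrt{\langle\eta,\delta\rangle_P^2 + \langle\delta,\delta\rangle_P\,\bigl(\Delta^2 - \langle\eta,\eta\rangle_P\bigr)}}{\langle\delta,\delta\rangle_P}
\end{align*}
```

When the stored eta lies strictly inside the region, Delta^2 - <eta,eta>_P is positive, so the square root exceeds |<eta,delta>_P| and the root taken here is the positive one. Only Delta enters this formula from the new call, which is why the cached inner products suffice to recompute the step for a smaller radius.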
function trsoutput = tCG_rejectedstep(problem, trsinput, options, store)
% Helper for trs_tCG_cached: mimics the latter's behavior, exploiting cache
%
% function trsoutput = tCG_rejectedstep(problem, trsinput, options, store)
%
% This function is a companion to trs_tCG_cached: it is not meant to be
% called directly by Manopt users.
%
% When running trustregions, the tCG subproblem solver issues a number of
% Hessian-vector calls. If the step is rejected, the trust-region radius is
% decreased, then tCG is called anew, at the same point. This triggers the
% same Hessian-vector calls to be issued. Instead of actually making those
% calls (which tend to be computationally expensive), trs_tCG_cached calls
% this function, which exploits information cached in the previous call to
% avoid redundant computations. The output is exactly the same as what one
% would have obtained if calling tCG without caching.
%
% There can be two situations:
%
% 1. The same eta and Heta as the previous tCG loop is returned and
%    trustregions decreases Delta.
%    (Either d_Hd <= 0 or store_last is used.)
%
% 2. A new eta and Heta is returned when some previously calculated eta_new
%    from store_iters satisfies normsq := <eta_new,eta_new>_x >= Delta^2
%    at the current Delta (exceeding trust region). Then the returned point
%    is the Steihaug–Toint point calculated using the eta computed before
%    eta_new.
%
% Refer to trs_tCG_cached for a description of the inputs and outputs.
%
% See also: trustregions trs_tCG_cached trs_tCG

% This file is part of Manopt: www.manopt.org.
% Original author: Victor Liao, Jun. 24, 2022.
% Contributors: Nicolas Boumal
% Change log:

    x = trsinput.x;
    Delta = trsinput.Delta;

    lincomb = @(a, u, b, v) problem.M.lincomb(x, a, u, b, v);

    store_iters = store.store_iters;
    stats.memorytCG_MB = getsize(store_iters(1))/1024^2 * length(store_iters);
    numstored = length(store_iters);
    if isfield(store, 'store_last')
        store_last = store.store_last;
        stats.memorytCG_MB = stats.memorytCG_MB + getsize(store_last)/1024^2;
        numstored = numstored + 1;
    end

    limitedbyTR = false;
    printstr = '';

    for ii = 1:length(store_iters)
        normsq = store_iters(ii).normsq;
        d_Hd = store_iters(ii).d_Hd;
        if d_Hd <= 0 || normsq >= Delta^2
            % We exit after computing new eta, Heta dependent on Delta
            e_Pe = store_iters(ii).e_Pe;
            e_Pd = store_iters(ii).e_Pd;
            d_Pd = store_iters(ii).d_Pd;
            eta = store_iters(ii).eta;
            mdelta = store_iters(ii).mdelta;
            Hmdelta = store_iters(ii).Hmdelta;
            Heta = store_iters(ii).Heta;

            tau = (-e_Pd + sqrt(e_Pd*e_Pd + d_Pd*(Delta^2-e_Pe))) / d_Pd;
            if options.debug > 2
                fprintf('DBG: tau : %e\n', tau);
            end
            eta = lincomb(1, eta, -tau, mdelta);

            % If only a nonlinear Hessian approximation is available, this
            % is only approximately correct, but saves an additional
            % Hessian call.
            Heta = lincomb(1, Heta, -tau, Hmdelta);

            % Technically, we may want to verify that the new eta is indeed
            % better than the previous eta before returning it (this is
            % always the case if the Hessian approximation is linear, but
            % unsure whether it is the case for nonlinear approximations.)
            % At any rate, the impact should be limited, so in the interest
            % of code conciseness, we omit this.

            if d_Hd <= 0
                stopreason_str = 'negative curvature';
            else
                stopreason_str = 'exceeded trust region';
            end

            limitedbyTR = true;

            stats.numinner = store_iters(ii).numinner;
            stats.hessvecevals = 0;

            if options.verbosity == 2
                printstr = sprintf('%9d %9d %9d %s', ...
                                   stats.numinner, 0, numstored, ...
                                   stopreason_str);
            elseif options.verbosity > 2
                printstr = sprintf('%9d %9d %9d %9.2f %s', ...
                                   stats.numinner, 0, numstored, ...
                                   stats.memorytCG_MB, stopreason_str);
            end

            trsoutput.eta = eta;
            trsoutput.Heta = Heta;
            trsoutput.limitedbyTR = limitedbyTR;
            trsoutput.printstr = printstr;
            trsoutput.stats = stats;
            return;
        end
    end

    % If no stored struct in store_iters satisfies negative curvature or
    % violates the trust-region radius we exit with last eta, Heta and
    % limitedbyTR = false from store_last.
    eta = store_last.eta;
    Heta = store_last.Heta;
    stats.numinner = store_last.numinner;
    stats.hessvecevals = 0;
    if options.verbosity == 2
        printstr = sprintf('%9d %9d %9d %s', ...
                           stats.numinner, 0, numstored, ...
                           store_last.stopreason_str);
    elseif options.verbosity > 2
        printstr = sprintf('%9d %9d %9d %9.2f %s', ...
                           stats.numinner, 0, numstored, ...
                           stats.memorytCG_MB, store_last.stopreason_str);
    end

    trsoutput.eta = eta;
    trsoutput.Heta = Heta;
    trsoutput.limitedbyTR = limitedbyTR;
    trsoutput.printstr = printstr;
    trsoutput.stats = stats;

end
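The tau update in the listing can be checked numerically outside Manopt. The script below is a minimal sketch under stated assumptions: plain Euclidean vectors, the identity in place of any preconditioner, and a small symmetric matrix H standing in for a linear Hessian. All numeric values are hypothetical; only the variable names e_Pe, e_Pd, d_Pd, mdelta, Hmdelta and the tau formula come from the listing.

```matlab
% Hypothetical Euclidean stand-in for one cached tCG iterate (not Manopt code).
H      = [2 0 1; 0 3 0; 1 0 4];   % assumed linear (symmetric) Hessian
eta    = [0.3; -0.1; 0.2];        % cached iterate, strictly inside the region
delta  = [1.0; 0.5; -2.0];        % search direction
mdelta = -delta;                  % "minus delta", mirroring the stored field
Delta  = 1.0;                     % new, smaller trust-region radius

% Cached quantities: no new Hessian-vector product is needed below.
Heta    = H * eta;
Hmdelta = H * mdelta;
e_Pe = eta' * eta;                % <eta, eta>
e_Pd = eta' * delta;              % <eta, delta>
d_Pd = delta' * delta;            % <delta, delta>

% Same formula as in the listing: positive root of the boundary quadratic.
tau = (-e_Pd + sqrt(e_Pd*e_Pd + d_Pd*(Delta^2 - e_Pe))) / d_Pd;

% Same updates as in the listing, with lincomb written out for vectors.
eta_boundary  = eta  - tau * mdelta;
Heta_boundary = Heta - tau * Hmdelta;

% The step lies on the boundary, and Heta stays consistent when H is linear.
assert(tau > 0);
assert(abs(norm(eta_boundary) - Delta) < 1e-12);
assert(norm(H * eta_boundary - Heta_boundary) < 1e-12);
fprintf('tau = %g, norm(eta_boundary) = %.15g\n', tau, norm(eta_boundary));
```

The last assertion holds because H is linear here. With a nonlinear Hessian approximation, the combined Heta would only be approximately equal to the Hessian applied to the new eta, which is the caveat the comments in the listing point out.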