
stochasticgradient

PURPOSE

Stochastic gradient (SG) minimization algorithm for Manopt.

SYNOPSIS

function [x, info, options] = stochasticgradient(problem, x, options)

DESCRIPTION

 Stochastic gradient (SG) minimization algorithm for Manopt.

 function [x, info, options] = stochasticgradient(problem)
 function [x, info, options] = stochasticgradient(problem, x0)
 function [x, info, options] = stochasticgradient(problem, x0, options)
 function [x, info, options] = stochasticgradient(problem, [], options)

 Apply the Riemannian stochastic gradient algorithm to the problem defined
 in the problem structure, starting at x0 if it is provided (otherwise, at
 a random point on the manifold). To specify options whilst not specifying
 an initial guess, give x0 as [] (the empty matrix).

 The problem structure must contain the following fields:

  problem.M:
       Defines the manifold to optimize over, given by a factory.

  problem.partialgrad or problem.partialegrad (or equivalent)
       Describes the partial gradients of the cost function. If the cost
       function is of the form f(x) = sum_{k=1}^N f_k(x),
       then partialgrad(x, K) = sum_{k \in K} grad f_k(x), and likewise
       for partialegrad with Euclidean gradients. As usual, partialgrad
       must define the Riemannian gradient, whereas partialegrad defines
       a Euclidean (classical) gradient which will be converted
       automatically to a Riemannian gradient. Use the tool
       checkgradient(problem) to check it. K is a /row/ vector, which
       makes it natural to loop over the selected terms with
       for k = K, ..., end (where ... stands for the loop body).

  problem.ncostterms
       An integer specifying how many terms are in the cost function (in
       the example above, that would be N).

 Importantly, the cost function itself need not be specified.
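
 As an illustration, here is a minimal sketch of a compatible problem
 structure; the sphere manifold and the data A are hypothetical choices
 for this example, not requirements of the solver. It minimizes
 f(x) = sum_k (a_k'*x)^2 over the unit sphere in R^n:

   n = 100; N = 1000;
   A = randn(n, N);               % column a_k defines the cost term f_k
   problem.M = spherefactory(n);  % manifold factory: unit sphere in R^n
   problem.ncostterms = N;
   % Euclidean gradient of the partial sum sum_{k in K} f_k at x:
   problem.partialegrad = @(x, K) 2 * A(:, K) * (A(:, K)' * x);
   x = stochasticgradient(problem);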

 Some of the options of the solver are specific to this file. Please have
 a look inside the code.
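
 For instance, the local defaults defined in the code below can be
 overridden as follows (the values are hypothetical):

   options.maxiter = 5000;      % default: 1000
   options.batchsize = 10;      % cost terms sampled per iteration (default: 1)
   options.checkperiod = 200;   % stats/stopping checks every 200 iterations
   x = stochasticgradient(problem, [], options);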

 To record statistics that the algorithm does not require (and hence
 does not compute by default), such as the value of the cost function or
 the norm of the gradient, one can set the following options:

   metrics.cost = @(problem, x) getCost(problem, x);
   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
   options.statsfun = statsfunhelper(metrics);

 Important caveat: stochastic algorithms usually return an average of the
 last few iterates. Computing averages on manifolds can be expensive.
 Currently, this solver does not compute averages and simply returns the
 last iterate. Using options.statsfun, it is possible for the user to
 compute averages manually. If you have ideas on how to do this
 generically, we welcome feedback. In particular, approximate means could
 be computed with M.pairmean which is available in many geometries.
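
 As a hedged sketch of such a manual computation (assuming problem.M
 provides pairmean, and keeping in mind that repeated pairwise averaging
 only crudely approximates a mean of iterates), one could write:

   function stats = averaging_statsfun(problem, x, stats)
       persistent xmean;                  % running averaged point
       if isempty(xmean)
           xmean = x;
       else
           xmean = problem.M.pairmean(xmean, x);
       end
       stats.xmean = xmean;               % recorded in the info array
   end

 and then set options.statsfun = @averaging_statsfun.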

 See also: steepestdescent

SOURCE CODE

function [x, info, options] = stochasticgradient(problem, x, options)
% Stochastic gradient (SG) minimization algorithm for Manopt.
%
% function [x, info, options] = stochasticgradient(problem)
% function [x, info, options] = stochasticgradient(problem, x0)
% function [x, info, options] = stochasticgradient(problem, x0, options)
% function [x, info, options] = stochasticgradient(problem, [], options)
%
% Apply the Riemannian stochastic gradient algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
%
% The problem structure must contain the following fields:
%
%  problem.M:
%       Defines the manifold to optimize over, given by a factory.
%
%  problem.partialgrad or problem.partialegrad (or equivalent)
%       Describes the partial gradients of the cost function. If the cost
%       function is of the form f(x) = sum_{k=1}^N f_k(x),
%       then partialgrad(x, K) = sum_{k \in K} grad f_k(x), and likewise
%       for partialegrad with Euclidean gradients. As usual, partialgrad
%       must define the Riemannian gradient, whereas partialegrad defines
%       a Euclidean (classical) gradient which will be converted
%       automatically to a Riemannian gradient. Use the tool
%       checkgradient(problem) to check it. K is a /row/ vector, which
%       makes it natural to loop over the selected terms with
%       for k = K, ..., end (where ... stands for the loop body).
%
%  problem.ncostterms
%       An integer specifying how many terms are in the cost function (in
%       the example above, that would be N).
%
% Importantly, the cost function itself need not be specified.
%
% Some of the options of the solver are specific to this file. Please have
% a look inside the code.
%
% To record statistics that the algorithm does not require (and hence
% does not compute by default), such as the value of the cost function or
% the norm of the gradient, one can set the following options:
%
%   metrics.cost = @(problem, x) getCost(problem, x);
%   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
%   options.statsfun = statsfunhelper(metrics);
%
% Important caveat: stochastic algorithms usually return an average of the
% last few iterates. Computing averages on manifolds can be expensive.
% Currently, this solver does not compute averages and simply returns the
% last iterate. Using options.statsfun, it is possible for the user to
% compute averages manually. If you have ideas on how to do this
% generically, we welcome feedback. In particular, approximate means could
% be computed with M.pairmean which is available in many geometries.
%
% See also: steepestdescent

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>,
%                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
%                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
% Contributors: Nicolas Boumal
% Change log:

    % Verify that the problem description is sufficient for the solver.
    if ~canGetPartialGradient(problem)
        warning('manopt:getPartialGradient', ...
         'No partial gradient provided. The algorithm will likely abort.');
    end
    
    
    % Set local defaults.
    localdefaults.maxiter = 1000;       % Maximum number of iterations
    localdefaults.batchsize = 1;        % Batchsize (# cost terms per iter)
    localdefaults.verbosity = 2;        % Output verbosity (0, 1 or 2)
    localdefaults.storedepth = 20;      % Limit amount of caching
    
    % Check stopping criteria and save stats every checkperiod iterations.
    localdefaults.checkperiod = 100;
    
    % stepsizefun is a function implementing a step size selection
    % algorithm. See that function for help with options, which can be
    % specified in the options structure passed to the solver directly.
    localdefaults.stepsizefun = @stepsize_sg;
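    
    % Note: a user-supplied replacement for stepsize_sg must match the
    % call signature used in the main loop below:
    %
    %  [stepsize, newx, newkey, ssstats] = ...
    %           stepsizefun(problem, x, pgrad, iter, options, storedb, key)
    %
    % As a hypothetical sketch (not part of this solver), a 1/(1+iter)
    % decay could be implemented as:
    %
    %  function [stepsize, newx, newkey, ssstats] = stepsize_decay( ...
    %                      problem, x, pgrad, iter, options, storedb, key)
    %      stepsize = 1/(1 + iter);
    %      % Step along -pgrad via the retraction to stay on the manifold.
    %      newx = problem.M.retr(x, pgrad, -stepsize);
    %      newkey = storedb.getNewKey();
    %      ssstats = [];    % no extra step size statistics recorded
    %  end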
    
    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);
    
    
    assert(options.checkperiod >= 1, ...
                 'options.checkperiod must be a positive integer (>= 1).');
    
    
    % If no initial point x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = problem.M.rand();
    end
    
    % Create a store database and get a key for the current x
    storedb = StoreDB(options.storedepth);
    key = storedb.getNewKey();
    
    
    % Elapsed time for the current set of iterations, where a set of
    % iterations comprises options.checkperiod iterations. We do not
    % count time spent for such things as logging statistics, as these are
    % not relevant to the actual optimization process.
    elapsed_time = 0;
    
    % Total number of completed steps
    iter = 0;
    
    
    % Total number of saved stats at this point.
    savedstats = 0;
    
    % Collect and save stats in a struct array info, and preallocate.
    stats = savestats();
    info(1) = stats;
    savedstats = savedstats + 1;
    if isinf(options.maxiter)
        % We trust that if the user set maxiter = inf, then they defined
        % another stopping criterion.
        preallocate = 1e5;
    else
        preallocate = ceil(options.maxiter / options.checkperiod) + 1;
    end
    info(preallocate).iter = [];
    
    
    % Display information header for the user.
    if options.verbosity >= 2
        fprintf('    iter       time [s]       step size\n');
    end
    
    
    % Main loop.
    stop = false;
    while iter < options.maxiter
        
        % Record start time.
        start_time = tic();
        
        % Draw the samples with replacement.
        idx_batch = randi(problem.ncostterms, options.batchsize, 1);
        
        % Compute partial gradient on this batch.
        pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);
        
        % Compute a step size and the corresponding new point x.
        [stepsize, newx, newkey, ssstats] = ...
                           options.stepsizefun(problem, x, pgrad, iter, ...
                                               options, storedb, key);
        
        % Make the step: transfer iterate, remove cache from previous x.
        storedb.removefirstifdifferent(key, newkey);
        x = newx;
        key = newkey;
        
        % Make sure we do not use too much memory for the store database.
        storedb.purge();
        
        % Total number of completed steps.
        iter = iter + 1;
        
        % Elapsed time doing actual optimization work so far in this
        % set of options.checkperiod iterations.
        elapsed_time = elapsed_time + toc(start_time);
        
        
        % Check stopping criteria and save stats every checkperiod iters.
        if mod(iter, options.checkperiod) == 0
            
            % Log statistics for freshly executed iteration.
            stats = savestats();
            info(savedstats+1) = stats;
            savedstats = savedstats + 1;
            
            % Reset timer.
            elapsed_time = 0;
            
            % Print output.
            if options.verbosity >= 2
                fprintf('%8d     %10.2f       %.3e\n', ...
                                               iter, stats.time, stepsize);
            end
            
            % Run standard stopping criterion checks.
            [stop, reason] = stoppingcriterion(problem, x, ...
                                               options, info, savedstats);
            if stop
                if options.verbosity >= 1
                    fprintf([reason '\n']);
                end
                break;
            end
        
        end

    end
    
    
    % Keep only the relevant portion of the info struct-array.
    info = info(1:savedstats);
    
    
    % Display a final information message.
    if options.verbosity >= 1
        if ~stop
            % We stopped not because of stoppingcriterion but because the
            % loop came to an end, which means maxiter triggered.
            msg = 'Max iteration count reached; options.maxiter = %g.\n';
            fprintf(msg, options.maxiter);
        end
        fprintf('Total time is %f [s] (excludes statsfun)\n', ...
                info(end).time + elapsed_time);
    end
    
    
    % Helper function to collect statistics to be saved at
    % index savedstats+1 in info.
    function stats = savestats()
        stats.iter = iter;
        if savedstats == 0
            stats.time = 0;
            stats.stepsize = NaN;
            stats.stepsize_stats = [];
        else
            stats.time = info(savedstats).time + elapsed_time;
            stats.stepsize = stepsize;
            stats.stepsize_stats = ssstats;
        end
        stats = applyStatsfun(problem, x, storedb, key, options, stats);
    end
    
end
