
stochasticgradient

PURPOSE

Stochastic gradient (SG) minimization algorithm for Manopt.

SYNOPSIS

function [x, info, options] = stochasticgradient(problem, x, options)

DESCRIPTION

 Stochastic gradient (SG) minimization algorithm for Manopt.

 function [x, info, options] = stochasticgradient(problem)
 function [x, info, options] = stochasticgradient(problem, x0)
 function [x, info, options] = stochasticgradient(problem, x0, options)
 function [x, info, options] = stochasticgradient(problem, [], options)

 Apply the Riemannian stochastic gradient algorithm to the problem defined
 in the problem structure, starting at x0 if it is provided (otherwise, at
 a random point on the manifold). To specify options whilst not specifying
 an initial guess, give x0 as [] (the empty matrix).

 The problem structure must contain the following fields:

  problem.M:
       Defines the manifold to optimize over, given by a factory.

  problem.partialgrad or problem.partialegrad (or equivalent)
       Describes the partial gradients of the cost function. If the cost
       function is of the form f(x) = sum_{k=1}^N f_k(x),
        then partialgrad(x, K) = sum_{k \in K} grad f_k(x).
       As usual, partialgrad must define the Riemannian gradient, whereas
       partialegrad defines a Euclidean (classical) gradient which will be
       converted automatically to a Riemannian gradient. Use the tool
       checkgradient(problem) to check it.

  problem.ncostterms
       An integer specifying how many terms are in the cost function (in
        the example above, that would be N).

 Importantly, the cost function itself need not be specified.
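 
 For illustration, here is a hypothetical setup (the manifold, data and cost
 below are arbitrary choices, not prescribed by this solver): minimize
 f(x) = -sum_k (a_k'*x)^2 over the unit sphere, supplying only the partial
 Euclidean gradients.
 
   % Illustrative data: columns a_1, ..., a_N of a random matrix A.
   d = 50;  N = 1000;
   A = randn(d, N);
 
   problem.M = spherefactory(d);    % optimize over the unit sphere in R^d
   problem.ncostterms = N;          % the cost is a sum of N terms
 
   % Partial Euclidean gradient: sum over k in K of egrad f_k(x), where
   % f_k(x) = -(a_k'*x)^2, so that egrad f_k(x) = -2*(a_k'*x)*a_k.
   problem.partialegrad = @(x, K) -2 * A(:, K) * (A(:, K)' * x);
 
   % The cost function itself is not needed by the solver.
   x = stochasticgradient(problem);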

 Some of the options of the solver are specific to this file. Please have
 a look inside the code.
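 
 For example (the values below are only illustrative), the solver-specific
 options set in localdefaults inside the code can be overridden as follows:
 
   options.maxiter = 5000;        % number of iterations (default: 1000)
   options.batchsize = 10;        % cost terms sampled per iteration (default: 1)
   options.checkperiod = 100;     % save stats / check stopping every 100 iterations
   options.stepsizefun = @stepsize_sg;   % step size selection routine (default)
   x = stochasticgradient(problem, [], options);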

 To record, for example, the value of the cost function or the norm of the
 gradient (statistics which the algorithm does not require and hence does
 not compute by default), one can set the following options:

   metrics.cost = @(problem, x) getCost(problem, x);
   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
   options.statsfun = statsfunhelper(metrics);
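 
 The recorded values then appear as fields of the returned info struct-array
 (assuming the metrics above were set), for example:
 
   [x, info] = stochasticgradient(problem, [], options);
   plot([info.iter], [info.cost]);   % cost, recorded every checkperiod iterations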

 Important caveat: stochastic algorithms usually return an average of the
 last few iterates. Computing averages on manifolds can be expensive.
 Currently, this solver does not compute averages and simply returns the
 last iterate. Using options.statsfun, it is possible for the user to
 compute averages manually. If you have ideas on how to do this
 generically, we welcome feedback. In particular, approximate means could
 be computed with M.pairmean which is available in many geometries.
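 
 As a rough sketch only (the function name, its three-argument statsfun
 signature and the use of a persistent variable are choices made here for
 illustration), such a running approximate average could be maintained in a
 user-supplied statsfun, passed via options.statsfun = @averaging_statsfun.
 Note that statsfun is only invoked every options.checkperiod iterations, so
 this averages the checkpointed iterates rather than every step.
 
   function stats = averaging_statsfun(problem, x, stats)
       % Running approximate mean of the iterates seen so far.
       persistent xmean;
       if isempty(xmean) || stats.iter == 0
           xmean = x;
       else
           % Midpoint between the previous mean and the new iterate;
           % this is only approximate: recent iterates weigh more.
           xmean = problem.M.pairmean(xmean, x);
       end
       stats.xmean = xmean;
   end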

 See also: steepestdescent

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SUBFUNCTIONS

function stats = savestats()

SOURCE CODE

0001 function [x, info, options] = stochasticgradient(problem, x, options)
0002 % Stochastic gradient (SG) minimization algorithm for Manopt.
0003 %
0004 % function [x, info, options] = stochasticgradient(problem)
0005 % function [x, info, options] = stochasticgradient(problem, x0)
0006 % function [x, info, options] = stochasticgradient(problem, x0, options)
0007 % function [x, info, options] = stochasticgradient(problem, [], options)
0008 %
0009 % Apply the Riemannian stochastic gradient algorithm to the problem defined
0010 % in the problem structure, starting at x0 if it is provided (otherwise, at
0011 % a random point on the manifold). To specify options whilst not specifying
0012 % an initial guess, give x0 as [] (the empty matrix).
0013 %
0014 % The problem structure must contain the following fields:
0015 %
0016 %  problem.M:
0017 %       Defines the manifold to optimize over, given by a factory.
0018 %
0019 %  problem.partialgrad or problem.partialegrad (or equivalent)
0020 %       Describes the partial gradients of the cost function. If the cost
0021 %       function is of the form f(x) = sum_{k=1}^N f_k(x),
0022 %       then partialgrad(x, K) = sum_{k \in K} grad f_k(x).
0023 %       As usual, partialgrad must define the Riemannian gradient, whereas
0024 %       partialegrad defines a Euclidean (classical) gradient which will be
0025 %       converted automatically to a Riemannian gradient. Use the tool
0026 %       checkgradient(problem) to check it.
0027 %
0028 %  problem.ncostterms
0029 %       An integer specifying how many terms are in the cost function (in
0030 %       the example above, that would be N).
0031 %
0032 % Importantly, the cost function itself need not be specified.
0033 %
0034 % Some of the options of the solver are specific to this file. Please have
0035 % a look inside the code.
0036 %
0037 % To record, for example, the value of the cost function or the norm of the
0038 % gradient (statistics which the algorithm does not require and hence does
0039 % not compute by default), one can set the following options:
0040 %
0041 %   metrics.cost = @(problem, x) getCost(problem, x);
0042 %   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
0043 %   options.statsfun = statsfunhelper(metrics);
0044 %
0045 % Important caveat: stochastic algorithms usually return an average of the
0046 % last few iterates. Computing averages on manifolds can be expensive.
0047 % Currently, this solver does not compute averages and simply returns the
0048 % last iterate. Using options.statsfun, it is possible for the user to
0049 % compute averages manually. If you have ideas on how to do this
0050 % generically, we welcome feedback. In particular, approximate means could
0051 % be computed with M.pairmean which is available in many geometries.
0052 %
0053 % See also: steepestdescent
0054 
0055 % This file is part of Manopt: www.manopt.org.
0056 % Original authors: Bamdev Mishra <bamdevm@gmail.com>,
0057 %                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
0058 %                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
0059 % Contributors: Nicolas Boumal
0060 % Change log:
0061     
0062 
0063     % Verify that the problem description is sufficient for the solver.
0064     if ~canGetPartialGradient(problem)
0065         warning('manopt:getPartialGradient', ...
0066          'No partial gradient provided. The algorithm will likely abort.');
0067     end
0068     
0069    
0070     % Set local default
0071     localdefaults.maxiter = 1000;       % Maximum number of iterations
0072     localdefaults.batchsize = 1;        % Batchsize (# cost terms per iter)
0073     localdefaults.verbosity = 2;        % Output verbosity (0, 1 or 2)
0074     localdefaults.storedepth = 20;      % Limit amount of caching
0075     
0076     % Check stopping criteria and save stats every checkperiod iterations.
0077     localdefaults.checkperiod = 100;
0078     
0079     % stepsizefun is a function implementing a step size selection
0080     % algorithm. See that function for help with options, which can be
0081     % specified in the options structure passed to the solver directly.
0082     localdefaults.stepsizefun = @stepsize_sg;
0083     
0084     % Merge global and local defaults, then merge w/ user options, if any.
0085     localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
0086     if ~exist('options', 'var') || isempty(options)
0087         options = struct();
0088     end
0089     options = mergeOptions(localdefaults, options);
0090     
0091     
0092     assert(options.checkperiod >= 1, ...
0093                  'options.checkperiod must be a positive integer (>= 1).');
0094     
0095     
0096     % If no initial point x is given by the user, generate one at random.
0097     if ~exist('x', 'var') || isempty(x)
0098         x = problem.M.rand();
0099     end
0100     
0101     % Create a store database and get a key for the current x
0102     storedb = StoreDB(options.storedepth);
0103     key = storedb.getNewKey();
0104     
0105     
0106     % Elapsed time for the current set of iterations, where a set of
0107     % iterations comprises options.checkperiod iterations. We do not
0108     % count time spent for such things as logging statistics, as these are
0109     % not relevant to the actual optimization process.
0110     elapsed_time = 0;
0111     
0112     % Total number of completed steps
0113     iter = 0;
0114     
0115     
0116     % Total number of saved stats at this point.
0117     savedstats = 0;
0118     
0119     % Collect and save stats in a struct array info, and preallocate.
0120     stats = savestats();
0121     info(1) = stats;
0122     savedstats = savedstats + 1;
0123     if isinf(options.maxiter)
0124         % We trust that if the user set maxiter = inf, then they defined
0125         % another stopping criterion.
0126         preallocate = 1e5;
0127     else
0128         preallocate = ceil(options.maxiter / options.checkperiod) + 1;
0129     end
0130     info(preallocate).iter = [];
0131     
0132     
0133     % Display information header for the user.
0134     if options.verbosity >= 2
0135         fprintf('    iter       time [s]       step size\n');
0136     end
0137     
0138     
0139     % Main loop.
0140     stop = false;
0141     while iter < options.maxiter
0142         
0143         % Record start time.
0144         start_time = tic();
0145         
0146         % Draw the samples with replacement.
0147         idx_batch = randi(problem.ncostterms, options.batchsize, 1);
0148         
0149         % Compute partial gradient on this batch.
0150         pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);
0151         
0152         % Compute a step size and the corresponding new point x.
0153         [stepsize, newx, newkey, ssstats] = ...
0154                            options.stepsizefun(problem, x, pgrad, iter, ...
0155                                                options, storedb, key);
0156         
0157         % Make the step.
0158         x = newx;
0159         key = newkey;
0160         
0161         % Total number of completed steps.
0162         iter = iter + 1;
0163         
0164         % Make sure we do not use too much memory for the store database.
0165         storedb.purge();
0166         
0167         % Elapsed time doing actual optimization work so far in this
0168         % set of options.checkperiod iterations.
0169         elapsed_time = elapsed_time + toc(start_time);
0170         
0171         
0172         % Check stopping criteria and save stats every checkperiod iters.
0173         if mod(iter, options.checkperiod) == 0
0174             
0175             % Log statistics for freshly executed iteration.
0176             stats = savestats();
0177             info(savedstats+1) = stats;
0178             savedstats = savedstats + 1;
0179             
0180             % Reset timer.
0181             elapsed_time = 0;
0182             
0183             % Print output.
0184             if options.verbosity >= 2
0185                 fprintf('%8d     %10.2f       %.3e\n', ...
0186                                                iter, stats.time, stepsize);
0187             end
0188             
0189             % Run standard stopping criterion checks.
0190             [stop, reason] = stoppingcriterion(problem, x, ...
0191                                                options, info, savedstats);
0192             if stop
0193                 if options.verbosity >= 1
0194                     fprintf([reason '\n']);
0195                 end
0196                 break;
0197             end
0198         
0199         end
0200 
0201     end
0202     
0203     
0204     % Keep only the relevant portion of the info struct-array.
0205     info = info(1:savedstats);
0206     
0207     
0208     % Display a final information message.
0209     if options.verbosity >= 1
0210         if ~stop
0211             % We stopped not because of stoppingcriterion but because the
0212             % loop came to an end, which means maxiter triggered.
0213             msg = 'Max iteration count reached; options.maxiter = %g.\n';
0214             fprintf(msg, options.maxiter);
0215         end
0216         fprintf('Total time is %f [s] (excludes statsfun)\n', ...
0217                 info(end).time + elapsed_time);
0218     end
0219     
0220     
0221     % Helper function to collect statistics to be saved at
0222     % index savedstats+1 in info.
0223     function stats = savestats()
0224         stats.iter = iter;
0225         if savedstats == 0
0226             stats.time = 0;
0227             stats.stepsize = NaN;
0228             stats.stepsize_stats = [];
0229         else
0230             stats.time = info(savedstats).time + elapsed_time;
0231             stats.stepsize = stepsize;
0232             stats.stepsize_stats = ssstats;
0233         end
0234         stats = applyStatsfun(problem, x, storedb, key, options, stats);
0235     end
0236     
0237 end
