Home > manopt > solvers > stochasticgradient > stochasticgradient.m

PURPOSE

Stochastic gradient (SG) minimization algorithm for Manopt.

SYNOPSIS

function [x, info, options] = stochasticgradient(problem, x, options)

DESCRIPTION

 Stochastic gradient (SG) minimization algorithm for Manopt.

function [x, info, options] = stochasticgradient(problem)
function [x, info, options] = stochasticgradient(problem, x0)
function [x, info, options] = stochasticgradient(problem, x0, options)
function [x, info, options] = stochasticgradient(problem, [], options)

Apply the Riemannian stochastic gradient algorithm to the problem defined
in the problem structure, starting at x0 if it is provided (otherwise, at
a random point on the manifold). To specify options whilst not specifying
an initial guess, give x0 as [] (the empty matrix).

The problem structure must contain the following fields:

problem.M:
Defines the manifold to optimize over, given by a factory.

problem.partialgrad or problem.partialegrad (or equivalent)
Describes the partial gradients of the cost function. If the cost
function is of the form f(x) = sum_{k=1}^N f_k(x),
then partialgrad(x, K) = sum_{k \in K} grad f_k(x).
As usual, partialgrad must define the Riemannian gradient, whereas
partialegrad defines a Euclidean (classical) gradient which will be
converted automatically to a Riemannian gradient. Use the tool
checkgradient(problem) to check it. K is a /row/ vector, which
makes it natural to loop over its entries with: for k = K ... end.

problem.ncostterms
An integer specifying how many terms are in the cost function (in
the example above, that would be N).

Importantly, the cost function itself need not be specified.
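For illustration, here is a minimal, hypothetical problem setup (the data matrix A and the cost f(x) = .5*sum_k ||x - A(:,k)||^2 are not part of this file; they are only an example of a finite-sum cost over the unit sphere):

```matlab
% Hypothetical example: minimize f(x) = .5*sum_{k=1}^N ||x - A(:,k)||^2
% over the unit sphere in R^n.
n = 5; N = 100;
A = randn(n, N);                 % example data, for illustration only

problem.M = spherefactory(n);    % the manifold to optimize over
problem.ncostterms = N;          % number of terms in the finite sum

% Partial Euclidean gradient over a batch K (a row vector of indices):
% sum over k in K of the Euclidean gradient of f_k, here (x - A(:,k)),
% which simplifies to numel(K)*x - sum(A(:,K), 2).
problem.partialegrad = @(x, K) numel(K)*x - sum(A(:, K), 2);

x = stochasticgradient(problem);
```

Note that the cost function itself is never specified: the solver only ever evaluates partial gradients.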

Some of the options of the solver are specific to this file. Please have
a look inside the code.
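For reference, these are the solver-specific options visible in the source below, shown here with their default values (except batchsize, changed for illustration):

```matlab
options.maxiter = 1000;          % maximum number of iterations
options.batchsize = 10;          % number of cost terms sampled per iteration
options.checkperiod = 100;       % check stopping criteria / log stats every 100 iters
options.verbosity = 2;           % output verbosity (0, 1 or 2)
options.stepsizefun = @stepsize_sg;  % step size selection algorithm

[x, info] = stochasticgradient(problem, [], options);
```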

To record the value of the cost function or the norm of the gradient for
example (which are statistics the algorithm does not require and hence
does not compute by default), one can set the following options:

metrics.cost = @(problem, x) getCost(problem, x);
metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
options.statsfun = statsfunhelper(metrics);

Important caveat: stochastic algorithms usually return an average of the
last few iterates. Computing averages on manifolds can be expensive.
Currently, this solver does not compute averages and simply returns the
last iterate. Using options.statsfun, it is possible for the user to
compute averages manually. If you have ideas on how to do this
generically, we welcome feedback. In particular, approximate means could
be computed with M.pairmean which is available in many geometries.
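As a sketch of that idea, a hypothetical statsfun (this is not part of the solver) could keep a crude running mean of the iterates with M.pairmean. Note that this produces an exponentially-weighted, approximate mean, and that statsfun is only called every options.checkperiod iterations by this solver:

```matlab
function stats = averaging_statsfun(problem, x, stats)
% Hypothetical statsfun: maintain a crude running mean of the iterates
% using M.pairmean (available in many geometries), stored in stats.xmean.
    persistent xmean
    if isempty(xmean)
        xmean = x;
    else
        % pairmean returns a point midway between two points on the manifold.
        xmean = problem.M.pairmean(xmean, x);
    end
    stats.xmean = xmean;
end
```

Usage: set options.statsfun = @averaging_statsfun before calling the solver; the running mean then appears in the info struct-array.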

See also: steepestdescent

CROSS-REFERENCE INFORMATION

This function calls:
• StoreDB
• applyStatsfun Apply the statsfun function to a stats structure (for solvers).
• canGetPartialGradient Checks whether the partial gradient can be computed for a given problem.
• getGlobalDefaults Returns a structure with default option values for Manopt.
• getPartialGradient Computes the gradient of a subset of terms in the cost function at x.
• mergeOptions Merges two options structures with one having precedence over the other.
• stoppingcriterion Checks for standard stopping criteria, as a helper to solvers.
• stepsize_sg Standard step size selection algorithm for the stochastic gradient method.
This function is called by:
• PCA_stochastic Example of stochastic gradient algorithm in Manopt on a PCA problem.

SOURCE CODE

0001 function [x, info, options] = stochasticgradient(problem, x, options)
0002 % Stochastic gradient (SG) minimization algorithm for Manopt.
0003 %
0004 % function [x, info, options] = stochasticgradient(problem)
0005 % function [x, info, options] = stochasticgradient(problem, x0)
0006 % function [x, info, options] = stochasticgradient(problem, x0, options)
0007 % function [x, info, options] = stochasticgradient(problem, [], options)
0008 %
0009 % Apply the Riemannian stochastic gradient algorithm to the problem defined
0010 % in the problem structure, starting at x0 if it is provided (otherwise, at
0011 % a random point on the manifold). To specify options whilst not specifying
0012 % an initial guess, give x0 as [] (the empty matrix).
0013 %
0014 % The problem structure must contain the following fields:
0015 %
0016 %  problem.M:
0017 %       Defines the manifold to optimize over, given by a factory.
0018 %
0019 %  problem.partialgrad or problem.partialegrad (or equivalent)
0020 %       Describes the partial gradients of the cost function. If the cost
0021 %       function is of the form f(x) = sum_{k=1}^N f_k(x),
0022 %       then partialgrad(x, K) = sum_{k \in K} grad f_k(x).
0023 %       As usual, partialgrad must define the Riemannian gradient, whereas
0024 %       partialegrad defines a Euclidean (classical) gradient which will be
0025 %       converted automatically to a Riemannian gradient. Use the tool
0026 %       checkgradient(problem) to check it. K is a /row/ vector, which
0027 %       makes it natural to loop over its entries with: for k = K ... end.
0028 %
0029 %  problem.ncostterms
0030 %       An integer specifying how many terms are in the cost function (in
0031 %       the example above, that would be N).
0032 %
0033 % Importantly, the cost function itself need not be specified.
0034 %
0035 % Some of the options of the solver are specific to this file. Please have
0036 % a look inside the code.
0037 %
0038 % To record the value of the cost function or the norm of the gradient for
0039 % example (which are statistics the algorithm does not require and hence
0040 % does not compute by default), one can set the following options:
0041 %
0042 %   metrics.cost = @(problem, x) getCost(problem, x);
0043 %   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
0044 %   options.statsfun = statsfunhelper(metrics);
0045 %
0046 % Important caveat: stochastic algorithms usually return an average of the
0047 % last few iterates. Computing averages on manifolds can be expensive.
0048 % Currently, this solver does not compute averages and simply returns the
0049 % last iterate. Using options.statsfun, it is possible for the user to
0050 % compute averages manually. If you have ideas on how to do this
0051 % generically, we welcome feedback. In particular, approximate means could
0052 % be computed with M.pairmean which is available in many geometries.
0053 %
0054 % See also: steepestdescent
0055
0056 % This file is part of Manopt: www.manopt.org.
0057 % Original authors: Bamdev Mishra <bamdevm@gmail.com>,
0058 %                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
0059 %                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
0060 % Contributors: Nicolas Boumal
0061 % Change log:
0062
0063
0064     % Verify that the problem description is sufficient for the solver.
0065     if ~canGetPartialGradient(problem)
0066         warning('manopt:getPartialGradient', ...
0067          'No partial gradient provided. The algorithm will likely abort.');
0068     end
0069
0070
0071     % Set local default
0072     localdefaults.maxiter = 1000;       % Maximum number of iterations
0073     localdefaults.batchsize = 1;        % Batchsize (# cost terms per iter)
0074     localdefaults.verbosity = 2;        % Output verbosity (0, 1 or 2)
0075     localdefaults.storedepth = 20;      % Limit amount of caching
0076
0077     % Check stopping criteria and save stats every checkperiod iterations.
0078     localdefaults.checkperiod = 100;
0079
0080     % stepsizefun is a function implementing a step size selection
0081     % algorithm. See that function for help with options, which can be
0082     % specified in the options structure passed to the solver directly.
0083     localdefaults.stepsizefun = @stepsize_sg;
0084
0085     % Merge global and local defaults, then merge w/ user options, if any.
0086     localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
0087     if ~exist('options', 'var') || isempty(options)
0088         options = struct();
0089     end
0090     options = mergeOptions(localdefaults, options);
0091
0092
0093     assert(options.checkperiod >= 1, ...
0094                  'options.checkperiod must be a positive integer (>= 1).');
0095
0096
0097     % If no initial point x is given by the user, generate one at random.
0098     if ~exist('x', 'var') || isempty(x)
0099         x = problem.M.rand();
0100     end
0101
0102     % Create a store database and get a key for the current x
0103     storedb = StoreDB(options.storedepth);
0104     key = storedb.getNewKey();
0105
0106
0107     % Elapsed time for the current set of iterations, where a set of
0108     % iterations comprises options.checkperiod iterations. We do not
0109     % count time spent for such things as logging statistics, as these are
0110     % not relevant to the actual optimization process.
0111     elapsed_time = 0;
0112
0113     % Total number of completed steps
0114     iter = 0;
0115
0116
0117     % Total number of saved stats at this point.
0118     savedstats = 0;
0119
0120     % Collect and save stats in a struct array info, and preallocate.
0121     stats = savestats();
0122     info(1) = stats;
0123     savedstats = savedstats + 1;
0124     if isinf(options.maxiter)
0125         % We trust that if the user set maxiter = inf, then they defined
0126         % another stopping criterion.
0127         preallocate = 1e5;
0128     else
0129         preallocate = ceil(options.maxiter / options.checkperiod) + 1;
0130     end
0131     info(preallocate).iter = [];
0132
0133
0134     % Display information header for the user.
0135     if options.verbosity >= 2
0136         fprintf('    iter       time [s]       step size\n');
0137     end
0138
0139
0140     % Main loop.
0141     stop = false;
0142     while iter < options.maxiter
0143
0144         % Record start time.
0145         start_time = tic();
0146
0147         % Draw the samples with replacement.
0148         idx_batch = randi(problem.ncostterms, options.batchsize, 1);
0149
0150         % Compute partial gradient on this batch.
0151         pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);
0152
0153         % Compute a step size and the corresponding new point x.
0154         [stepsize, newx, newkey, ssstats] = ...
0155                            options.stepsizefun(problem, x, pgrad, iter, ...
0156                                                options, storedb, key);
0157
0158         % Make the step: transfer iterate, remove cache from previous x.
0159         storedb.removefirstifdifferent(key, newkey);
0160         x = newx;
0161         key = newkey;
0162
0163         % Make sure we do not use too much memory for the store database.
0164         storedb.purge();
0165
0166         % Total number of completed steps.
0167         iter = iter + 1;
0168
0169         % Elapsed time doing actual optimization work so far in this
0170         % set of options.checkperiod iterations.
0171         elapsed_time = elapsed_time + toc(start_time);
0172
0173
0174         % Check stopping criteria and save stats every checkperiod iters.
0175         if mod(iter, options.checkperiod) == 0
0176
0177             % Log statistics for freshly executed iteration.
0178             stats = savestats();
0179             info(savedstats+1) = stats;
0180             savedstats = savedstats + 1;
0181
0182             % Reset timer.
0183             elapsed_time = 0;
0184
0185             % Print output.
0186             if options.verbosity >= 2
0187                 fprintf('%8d     %10.2f       %.3e\n', ...
0188                                                iter, stats.time, stepsize);
0189             end
0190
0191             % Run standard stopping criterion checks.
0192             [stop, reason] = stoppingcriterion(problem, x, ...
0193                                                options, info, savedstats);
0194             if stop
0195                 if options.verbosity >= 1
0196                     fprintf([reason '\n']);
0197                 end
0198                 break;
0199             end
0200
0201         end
0202
0203     end
0204
0205
0206     % Keep only the relevant portion of the info struct-array.
0207     info = info(1:savedstats);
0208
0209
0210     % Display a final information message.
0211     if options.verbosity >= 1
0212         if ~stop
0213             % We stopped not because of stoppingcriterion but because the
0214             % loop came to an end, which means maxiter triggered.
0215             msg = 'Max iteration count reached; options.maxiter = %g.\n';
0216             fprintf(msg, options.maxiter);
0217         end
0218         fprintf('Total time is %f [s] (excludes statsfun)\n', ...
0219                 info(end).time + elapsed_time);
0220     end
0221
0222
0223     % Helper function to collect statistics to be saved at
0224     % index checkperiodcount+1 in info.
0225     function stats = savestats()
0226         stats.iter = iter;
0227         if savedstats == 0
0228             stats.time = 0;
0229             stats.stepsize = NaN;
0230             stats.stepsize_stats = [];
0231         else
0232             stats.time = info(savedstats).time + elapsed_time;
0233             stats.stepsize = stepsize;
0234             stats.stepsize_stats = ssstats;
0235         end
0236         stats = applyStatsfun(problem, x, storedb, key, options, stats);
0237     end
0238
0239 end

Generated on Mon 10-Sep-2018 11:48:06 by m2html © 2005