
neldermead

PURPOSE

Nelder-Mead optimization algorithm for derivative-free minimization.

SYNOPSIS

function [x, cost, info, options] = neldermead(problem, x, options)

DESCRIPTION

 Nelder-Mead optimization algorithm for derivative-free minimization.

 function [x, cost, info, options] = neldermead(problem)
 function [x, cost, info, options] = neldermead(problem, x0)
 function [x, cost, info, options] = neldermead(problem, x0, options)
 function [x, cost, info, options] = neldermead(problem, [], options)

 Apply a Nelder-Mead minimization algorithm to the problem defined in
 the problem structure, starting with the population x0 if it is provided
 (otherwise, a random population on the manifold is generated). A
 population is a cell containing points on the manifold. The number of
 elements in the cell must be dim+1, where dim is the dimension of the
 manifold: problem.M.dim().

 To specify options whilst not specifying an initial guess, give x0 as []
 (the empty matrix).
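
 For illustration, a minimal call could look as follows. The sphere
 manifold and the quadratic cost are placeholder choices for this sketch;
 any Manopt manifold factory and cost function can be used instead.

     % Minimize x'*A*x over the unit sphere in R^n, without derivatives.
     n = 5;
     A = randn(n); A = .5*(A + A');
     problem.M = spherefactory(n);
     problem.cost = @(x) x'*A*x;

     % Either let the solver generate a random initial simplex ...
     [x, xcost, info] = neldermead(problem);

     % ... or supply one: a cell of dim+1 points on the manifold.
     dim = problem.M.dim();
     x0 = cell(dim+1, 1);
     for i = 1 : dim+1
         x0{i} = problem.M.rand();
     end
     [x, xcost, info] = neldermead(problem, x0);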

 This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
 to the Riemannian setting. It comes with no convergence guarantees and
 there is room for improvement. In particular, we compute centroids as
 Karcher means, which seems overly expensive: cheaper forms of
 average-like quantities might work better.
 This solver is useful nonetheless for problems for which no derivatives
 are available, and it may constitute a starting point for the development
 of other Riemannian derivative-free methods.
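
 For intuition, a Karcher mean (Riemannian center of mass) can be
 approximated with a fixed-point iteration built on the exponential and
 logarithmic maps of the manifold. The sketch below, with the hypothetical
 name karcher_mean_sketch, only illustrates the idea; it is not the
 centroid routine this solver calls.

     % Hypothetical sketch: approximate the Karcher mean of the points in
     % the cell xs by repeatedly stepping towards the average of the log
     % maps (tangent vectors) taken at the current estimate.
     function xbar = karcher_mean_sketch(M, xs)
         xbar = xs{1};
         for it = 1 : 20
             v = M.zerovec(xbar);
             for i = 1 : numel(xs)
                 v = M.lincomb(xbar, 1, v, 1/numel(xs), M.log(xbar, xs{i}));
             end
             xbar = M.exp(xbar, v);
         end
     end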

 None of the options are mandatory. See the code below for details.
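
 For example, the options recognized in the code below can be set as
 fields of a structure passed as the third input (the values shown are
 merely illustrative, not recommendations):

     options.maxcostevals = 5000;  % stop after this many cost evaluations
     options.maxiter      = 10000; % stop after this many iterations
     options.reflection   = 1;     % reflection coefficient
     options.expansion    = 2;     % expansion coefficient
     options.contraction  = .5;    % contraction coefficient
     options.verbosity    = 2;     % 2: per-iteration output, 1: minimal, 0: silent
     options.storedepth   = 0;     % number of store structures to cache
     x = neldermead(problem, [], options);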

 Requires problem.M.pairmean(x, y) to be defined (computes the average
 between two points, x and y).
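
 If the manifold structure in use does not define pairmean, the geodesic
 midpoint can often serve, assuming the exponential and logarithmic maps
 are available (a sketch, not part of the solver):

     % Midpoint of the geodesic from x to y: walk halfway along log_x(y).
     M = problem.M;
     if ~isfield(M, 'pairmean')
         problem.M.pairmean = @(x, y) M.exp(x, M.log(x, y), .5);
     end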

 If options.statsfun is defined, it will receive a cell of points x (the
 current simplex being considered at that iteration), and, if required,
 one store structure corresponding to the best point, x{1}. The points are
 ordered by increasing cost: f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}),
 where dim = problem.M.dim().
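
 For instance, assuming the usual Manopt signature
 stats = statsfun(problem, x, stats), a statsfun recording the diameter of
 the current simplex could look like this (the field name diameter and the
 function name simplex_stats are arbitrary):

     options.statsfun = @simplex_stats;

     % x is the cell of simplex points, ordered by increasing cost.
     function stats = simplex_stats(problem, x, stats)
         D = 0;
         for i = 1 : numel(x)
             for j = i+1 : numel(x)
                 D = max(D, problem.M.dist(x{i}, x{j}));
             end
         end
         stats.diameter = D;
     end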

 Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.

 See also: manopt/solvers/pso/pso

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SUBFUNCTIONS

function stats = savestats()

SOURCE CODE

0001 function [x, cost, info, options] = neldermead(problem, x, options)
0002 % Nelder-Mead optimization algorithm for derivative-free minimization.
0003 %
0004 % function [x, cost, info, options] = neldermead(problem)
0005 % function [x, cost, info, options] = neldermead(problem, x0)
0006 % function [x, cost, info, options] = neldermead(problem, x0, options)
0007 % function [x, cost, info, options] = neldermead(problem, [], options)
0008 %
0009 % Apply a Nelder-Mead minimization algorithm to the problem defined in
0010 % the problem structure, starting with the population x0 if it is provided
0011 % (otherwise, a random population on the manifold is generated). A
0012 % population is a cell containing points on the manifold. The number of
0013 % elements in the cell must be dim+1, where dim is the dimension of the
0014 % manifold: problem.M.dim().
0015 %
0016 % To specify options whilst not specifying an initial guess, give x0 as []
0017 % (the empty matrix).
0018 %
0019 % This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
0020 % to the Riemannian setting. It comes with no convergence guarantees and
0021 % there is room for improvement. In particular, we compute centroids as
0022 % Karcher means, which seems overly expensive: cheaper forms of
0023 % average-like quantities might work better.
0024 % This solver is useful nonetheless for problems for which no derivatives
0025 % are available, and it may constitute a starting point for the development
0026 % of other Riemannian derivative-free methods.
0027 %
0028 % None of the options are mandatory. See the code below for details.
0029 %
0030 % Requires problem.M.pairmean(x, y) to be defined (computes the average
0031 % between two points, x and y).
0032 %
0033 % If options.statsfun is defined, it will receive a cell of points x (the
0034 % current simplex being considered at that iteration), and, if required,
0035 % one store structure corresponding to the best point, x{1}. The points are
0036 % ordered by increasing cost: f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}),
0037 % where dim = problem.M.dim().
0038 %
0039 % Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.
0040 %
0041 % See also: manopt/solvers/pso/pso
0042 
0043 % This file is part of Manopt: www.manopt.org.
0044 % Original author: Nicolas Boumal, Dec. 30, 2012.
0045 % Contributors:
0046 % Change log:
0047 %
0048 %   April 4, 2015 (NB):
0049 %       Working with the new StoreDB class system.
0050 %       Clarified interactions with statsfun and store.
0051 %
0052 %   Nov. 11, 2016 (NB):
0053 %       If options.verbosity is < 2, prints minimal output.
0054 
0055     
0056     % Verify that the problem description is sufficient for the solver.
0057     if ~canGetCost(problem)
0058         warning('manopt:getCost', ...
0059                 'No cost provided. The algorithm will likely abort.');  
0060     end
0061     
0062     % Dimension of the manifold
0063     dim = problem.M.dim();
0064 
0065     % Set local defaults here
0066     localdefaults.storedepth = 0;                     % no need for caching
0067     localdefaults.maxcostevals = max(1000, 2*dim);
0068     localdefaults.maxiter = max(2000, 4*dim);
0069     
0070     localdefaults.reflection = 1;
0071     localdefaults.expansion = 2;
0072     localdefaults.contraction = .5;
0073     % The shrinkage coefficient is fixed at .5 so that the manifold's
0074     % pairmean function can be used. % localdefaults.shrinkage = .5;
0075     
0076     % Merge global and local defaults, then merge w/ user options, if any.
0077     localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
0078     if ~exist('options', 'var') || isempty(options)
0079         options = struct();
0080     end
0081     options = mergeOptions(localdefaults, options);
0082     
0083     % Start timing for initialization.
0084     timetic = tic();
0085     
0086     % If no initial simplex x is given by the user, generate one at random.
0087     if ~exist('x', 'var') || isempty(x)
0088         x = cell(dim+1, 1);
0089         for i = 1 : dim+1
0090             x{i} = problem.M.rand();
0091         end
0092     end
0093     
0094     % Create a store database and a key for each point.
0095     storedb = StoreDB(options.storedepth);
0096     key = cell(size(x));
0097     for i = 1 : dim+1
0098         key{i} = storedb.getNewKey();
0099     end
0100     
0101     % Compute objective-related quantities for x, and setup a
0102     % function evaluations counter.
0103     costs = zeros(dim+1, 1);
0104     for i = 1 : dim+1
0105         costs(i) = getCost(problem, x{i}, storedb, key{i});
0106     end
0107     costevals = dim+1;
0108     
0109     % Sort simplex points by cost.
0110     [costs, order] = sort(costs);
0111     x = x(order);
0112     key = key(order);
0113     
0114     % Iteration counter.
0115     % At any point, iter is the number of fully executed iterations so far.
0116     iter = 0;
0117     
0118     % Save stats in a struct array info, and preallocate.
0119     % savestats will be called twice for the initial iterate (number 0),
0120     % which is unfortunate, but not problematic.
0121     stats = savestats();
0122     info(1) = stats;
0123     info(min(10000, options.maxiter+1)).iter = [];
0124     
0125     % Start iterating until stopping criterion triggers.
0126     while true
0127         
0128         % Make sure we don't use too much memory for the store database.
0129         storedb.purge();
0130         
0131         stats = savestats();
0132         info(iter+1) = stats; %#ok<AGROW>
0133         iter = iter + 1;
0134         
0135         % Start timing this iteration.
0136         timetic = tic();
0137         
0138         % Sort simplex points by cost.
0139         [costs, order] = sort(costs);
0140         x = x(order);
0141         key = key(order);
0142 
0143         % Log / display iteration information here.
0144         if options.verbosity >= 2
0145             fprintf('Cost evals: %7d\tBest cost: %+.4e\t', ...
0146                     costevals, costs(1));
0147         end
0148         
0149         % Run standard stopping criterion checks.
0150         [stop, reason] = stoppingcriterion(problem, x, options, info, iter);
0151     
0152         if stop
0153             if options.verbosity >= 1
0154                 fprintf([reason '\n']);
0155             end
0156             break;
0157         end
0158         
0159         % Compute a centroid for the dim best points.
0160         xbar = centroid(problem.M, x(1:end-1));
0161         
0162         % Compute the direction for moving along the axis xbar - worst x.
0163         vec = problem.M.log(xbar, x{end});
0164         
0165         % Reflection step
0166         xr = problem.M.exp(xbar, vec, -options.reflection);
0167         keyr = storedb.getNewKey();
0168         costr = getCost(problem, xr, storedb, keyr);
0169         costevals = costevals + 1;
0170         
0171         % If the reflected point is honorable, drop the worst point,
0172         % replace it with the reflected point and start a new iteration.
0173         if costr >= costs(1) && costr < costs(end-1)
0174             if options.verbosity >= 2
0175                 fprintf('Reflection\n');
0176             end
0177             costs(end) = costr;
0178             x{end} = xr;
0179             key{end} = keyr;
0180             continue;
0181         end
0182         
0183         % If the reflected point is better than the best point, expand.
0184         if costr < costs(1)
0185             xe = problem.M.exp(xbar, vec, -options.expansion);
0186             keye = storedb.getNewKey();
0187             coste = getCost(problem, xe, storedb, keye);
0188             costevals = costevals + 1;
0189             if coste < costr
0190                 if options.verbosity >= 2
0191                     fprintf('Expansion\n');
0192                 end
0193                 costs(end) = coste;
0194                 x{end} = xe;
0195                 key{end} = keye;
0196                 continue;
0197             else
0198                 if options.verbosity >= 2
0199                     fprintf('Reflection (failed expansion)\n');
0200                 end
0201                 costs(end) = costr;
0202                 x{end} = xr;
0203                 key{end} = keyr;
0204                 continue;
0205             end
0206         end
0207         
0208         % If the reflected point is worse than the second to worst point,
0209         % contract.
0210         if costr >= costs(end-1)
0211             if costr < costs(end)
0212                 % do an outside contraction
0213                 xoc = problem.M.exp(xbar, vec, -options.contraction);
0214                 keyoc = storedb.getNewKey();
0215                 costoc = getCost(problem, xoc, storedb, keyoc);
0216                 costevals = costevals + 1;
0217                 if costoc <= costr
0218                     if options.verbosity >= 2
0219                         fprintf('Outside contraction\n');
0220                     end
0221                     costs(end) = costoc;
0222                     x{end} = xoc;
0223                     key{end} = keyoc;
0224                     continue;
0225                 end
0226             else
0227                 % do an inside contraction
0228                 xic = problem.M.exp(xbar, vec, options.contraction);
0229                 keyic = storedb.getNewKey();
0230                 costic = getCost(problem, xic, storedb, keyic);
0231                 costevals = costevals + 1;
0232                 if costic <= costs(end)
0233                     if options.verbosity >= 2
0234                         fprintf('Inside contraction\n');
0235                     end
0236                     costs(end) = costic;
0237                     x{end} = xic;
0238                     key{end} = keyic;
0239                     continue;
0240                 end
0241             end
0242         end
0243         
0244         % If we get here, shrink the simplex around x{1}.
0245         if options.verbosity >= 2
0246             fprintf('Shrinkage\n');
0247         end
0248         for i = 2 : dim+1
0249             x{i} = problem.M.pairmean(x{1}, x{i});
0250             key{i} = storedb.getNewKey();
0251             costs(i) = getCost(problem, x{i}, storedb, key{i});
0252         end
0253         costevals = costevals + dim;
0254         
0255     end
0256     
0257     
0258     info = info(1:iter);
0259     
0260     % Iteration done: return only the best point found.
0261     cost = costs(1);
0262     x = x{1};
0263     key = key{1};
0264     
0265     
0266     
0267     % Routine in charge of collecting the current iteration stats.
0268     function stats = savestats()
0269         stats.iter = iter;
0270         stats.cost = costs(1);
0271         stats.costevals = costevals;
0272         if iter == 0
0273             stats.time = toc(timetic);
0274         else
0275             stats.time = info(iter).time + toc(timetic);
0276         end
0277         % The statsfun can only possibly receive one store structure. We
0278         % pass the key to the best point, so that the best point's store
0279         % will be passed. But the whole cell x of points is passed through.
0280         stats = applyStatsfun(problem, x, storedb, key{1}, options, stats);
0281     end
0282     
0283 end
