function [x, cost, info, options] = neldermead(problem, x, options)
% Nelder-Mead optimization algorithm for derivative-free minimization.
%
% function [x, cost, info, options] = neldermead(problem)
% function [x, cost, info, options] = neldermead(problem, x0)
% function [x, cost, info, options] = neldermead(problem, x0, options)
% function [x, cost, info, options] = neldermead(problem, [], options)
%
% Apply a Nelder-Mead minimization algorithm to the problem defined in
% the problem structure, starting with the population x0 if it is provided
% (otherwise, a random population on the manifold is generated). A
% population is a cell containing points on the manifold. The number of
% elements in the cell must be dim+1, where dim is the dimension of the
% manifold: problem.M.dim().
%
% To specify options whilst not specifying an initial guess, give x0 as []
% (the empty matrix).
%
% This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
% to the Riemannian setting. It comes with no convergence guarantees and
% there is room for improvement. In particular, we compute centroids as
% Karcher means, which seems overly expensive: cheaper forms of
% average-like quantities might work better.
% This solver is useful nonetheless for problems for which no derivatives
% are available, and it may constitute a starting point for the development
% of other Riemannian derivative-free methods.
%
% None of the options are mandatory. See in code for details.
%
% Requires problem.M.pairmean(x, y) to be defined (it computes the average
% of two points, x and y).
%
% If options.statsfun is defined, it will receive a cell of points x (the
% current simplex being considered at that iteration) and, if required,
% one store structure corresponding to the best point, x{1}.
% The points are ordered by increasing cost:
% f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}), where dim = problem.M.dim().
%
% Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.
%
% See also: manopt/solvers/pso/pso

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   Apr. 4, 2015 (NB):
%       Working with the new StoreDB class system.
%       Clarified interactions with statsfun and store.
%
%   Nov. 11, 2016 (NB):
%       If options.verbosity is < 2, prints minimal output.
%
%   Sep. 6, 2018 (NB):
%       Using retraction instead of exponential.


    % Verify that the problem description is sufficient for the solver.
    if ~canGetCost(problem)
        warning('manopt:getCost', ...
                'No cost provided. The algorithm will likely abort.');
    end

    % Dimension of the manifold
    dim = problem.M.dim();

    % Set local defaults here.
    localdefaults.storedepth = 0;   % no need for caching
    localdefaults.maxiter = max(2000, 4*dim);

    localdefaults.reflection = 1;
    localdefaults.expansion = 2;
    localdefaults.contraction = .5;
    % The shrinkage coefficient is forced to .5 so that the pairmean
    % function of manifolds can be used for the shrink step.
    % localdefaults.shrinkage = .5;

    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);

    % Start timing for initialization.
    timetic = tic();

    % If no initial simplex x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = cell(dim+1, 1);
        for i = 1 : dim+1
            x{i} = problem.M.rand();
        end
    end

    % Create a store database and a key for each point.
    storedb = StoreDB(options.storedepth);
    key = cell(size(x));
    for i = 1 : dim+1
        key{i} = storedb.getNewKey();
    end

    % Compute objective-related quantities for x, and set up a
    % function evaluations counter.
    costs = zeros(dim+1, 1);
    for i = 1 : dim+1
        costs(i) = getCost(problem, x{i}, storedb, key{i});
    end
    costevals = dim+1;

    % Sort simplex points by cost.
    [costs, order] = sort(costs);
    x = x(order);
    key = key(order);

    % Iteration counter.
    % At any point, iter is the number of fully executed iterations so far.
    iter = 0;

    % Save stats in a struct array info, and preallocate.
    % savestats will be called twice for the initial iterate (number 0),
    % which is unfortunate, but not problematic.
    stats = savestats();
    info(1) = stats;
    info(min(10000, options.maxiter+1)).iter = [];

    % Start iterating until a stopping criterion triggers.
    while true

        % Make sure we do not use too much memory for the store database.
        storedb.purge();

        stats = savestats();
        info(iter+1) = stats; %#ok<AGROW>
        iter = iter + 1;

        % Start timing this iteration.
        timetic = tic();

        % Sort simplex points by cost.
        [costs, order] = sort(costs);
        x = x(order);
        key = key(order);

        % Log / display iteration information here.
        if options.verbosity >= 2
            fprintf('Cost evals: %7d\tBest cost: %+.4e\t', ...
                    costevals, costs(1));
        end

        % Run standard stopping criterion checks.
        [stop, reason] = stoppingcriterion(problem, x, options, info, iter);

        if stop
            if options.verbosity >= 1
                fprintf([reason '\n']);
            end
            break;
        end

        % Compute a centroid for the dim best points.
        xbar = centroid(problem.M, x(1:end-1));

        % Compute the direction for moving along the axis xbar - worst x.
        vec = problem.M.log(xbar, x{end});

        % Reflection step
        xr = problem.M.retr(xbar, vec, -options.reflection);
        keyr = storedb.getNewKey();
        costr = getCost(problem, xr, storedb, keyr);
        costevals = costevals + 1;

        % If the reflected point is honorable, drop the worst point,
        % replace it with the reflected point and start a new iteration.
        if costr >= costs(1) && costr < costs(end-1)
            if options.verbosity >= 2
                fprintf('Reflection\n');
            end
            costs(end) = costr;
            x{end} = xr;
            key{end} = keyr;
            continue;
        end

        % If the reflected point is better than the best point, expand.
        if costr < costs(1)
            xe = problem.M.retr(xbar, vec, -options.expansion);
            keye = storedb.getNewKey();
            coste = getCost(problem, xe, storedb, keye);
            costevals = costevals + 1;
            if coste < costr
                if options.verbosity >= 2
                    fprintf('Expansion\n');
                end
                costs(end) = coste;
                x{end} = xe;
                key{end} = keye;
                continue;
            else
                if options.verbosity >= 2
                    fprintf('Reflection (failed expansion)\n');
                end
                costs(end) = costr;
                x{end} = xr;
                key{end} = keyr;
                continue;
            end
        end

        % If the reflected point is worse than the second-to-worst point,
        % contract.
        if costr >= costs(end-1)
            if costr < costs(end)
                % Do an outside contraction.
                xoc = problem.M.retr(xbar, vec, -options.contraction);
                keyoc = storedb.getNewKey();
                costoc = getCost(problem, xoc, storedb, keyoc);
                costevals = costevals + 1;
                if costoc <= costr
                    if options.verbosity >= 2
                        fprintf('Outside contraction\n');
                    end
                    costs(end) = costoc;
                    x{end} = xoc;
                    key{end} = keyoc;
                    continue;
                end
            else
                % Do an inside contraction.
                xic = problem.M.retr(xbar, vec, options.contraction);
                keyic = storedb.getNewKey();
                costic = getCost(problem, xic, storedb, keyic);
                costevals = costevals + 1;
                if costic <= costs(end)
                    if options.verbosity >= 2
                        fprintf('Inside contraction\n');
                    end
                    costs(end) = costic;
                    x{end} = xic;
                    key{end} = keyic;
                    continue;
                end
            end
        end

        % If we get here, shrink the simplex around x{1}.
        if options.verbosity >= 2
            fprintf('Shrinkage\n');
        end
        for i = 2 : dim+1
            x{i} = problem.M.pairmean(x{1}, x{i});
            key{i} = storedb.getNewKey();
            costs(i) = getCost(problem, x{i}, storedb, key{i});
        end
        costevals = costevals + dim;

    end


    info = info(1:iter);

    % Iterations are done: return only the best point found.
    cost = costs(1);
    x = x{1};
    key = key{1};



    % Routine in charge of collecting the current iteration stats.
    function stats = savestats()
        stats.iter = iter;
        stats.cost = costs(1);
        stats.costevals = costevals;
        if iter == 0
            stats.time = toc(timetic);
        else
            stats.time = info(iter).time + toc(timetic);
        end
        % The statsfun can only possibly receive one store structure. We
        % pass the key to the best point, so that the best point's store
        % will be passed. But the whole cell x of points is passed through.
        stats = applyStatsfun(problem, x, storedb, key{1}, options, stats);
    end

end
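A minimal usage sketch (not part of the file itself): since this solver only needs a cost function, it can be tried on a derivative-free version of the dominant-eigenvector problem on the unit sphere. The `spherefactory` manifold and the quadratic cost below are illustrative choices, assuming Manopt is on the MATLAB path; note that `spherefactory` provides the `pairmean` function this solver requires.

```matlab
% Illustrative example: maximize x'*A*x over the unit sphere in R^n
% (i.e., find a dominant eigenvector of A), using only cost evaluations.
n = 5;
A = randn(n);
A = .5*(A + A');                  % make A symmetric

problem.M = spherefactory(n);     % sphere manifold; defines pairmean
problem.cost = @(x) -x'*(A*x);    % cost only: no gradient is supplied

options.maxiter = 500;            % cap the number of iterations
options.verbosity = 1;            % minimal printed output

% Random initial simplex (pass [] for x0), then read off the best point.
[x, xcost, info] = neldermead(problem, [], options);
```

The returned `x` is the best vertex of the final simplex and `xcost` its cost; `info` holds the per-iteration statistics recorded by `savestats` above.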