function [xbest, fbest, info, options] = pso(problem, x, options)
% Particle swarm optimization (PSO) of a cost function on a manifold.
%
% function [xbest, fbest, info, options] = pso(problem)
% function [xbest, fbest, info, options] = pso(problem, x)
% function [xbest, fbest, info, options] = pso(problem, x, options)
% function [xbest, fbest, info, options] = pso(problem, [], options)
%
% Derivative-free solver: only cost evaluations are used (through getCost).
%
% Inputs:
%   problem  Problem structure. problem.M must provide dim, rand, randvec,
%            log, transp, lincomb and retr; a cost must be obtainable from
%            the problem (a warning is issued otherwise).
%   x        Optional initial population: a cell array of points on
%            problem.M. If omitted or empty, options.populationsize points
%            are drawn with problem.M.rand(). If given, its length
%            overrides options.populationsize (with a warning).
%   options  Optional structure; fields set here take precedence over the
%            defaults below (which are themselves merged over the global
%            defaults returned by getGlobalDefaults()):
%              storedepth     (0)
%              maxiter        (max(500, 4*dim))
%              populationsize (min(40, 10*dim))
%              nostalgia      (1.4)  weight of the pull toward each
%                                    particle's own best point so far
%              social         (1.4)  weight of the pull toward the best
%                                    point found by the whole swarm
%            options.verbosity is read as well (expected to come from the
%            global defaults): >= 1 prints the stopping reason, >= 2 prints
%            one line per iteration.
%
% Outputs:
%   xbest    Best point encountered (lowest cost value seen).
%   fbest    Cost value at xbest.
%   info     Struct array with one entry per recorded iteration; fields
%            iter, cost (best cost so far), costevals, x, v, xbest, time
%            (cumulative), plus any fields added by the statsfun mechanism
%            (stored as one cell entry per particle).
%   options  The options structure actually used.
%
% Errors if problem.M has no field 'log', or if x is given but is not a
% cell array.


    % Verify that the problem description is sufficient for the solver.
    if ~canGetCost(problem)
        warning('manopt:getCost', ...
            'No cost provided. The algorithm will likely abort.');
    end

    % Dimension of the manifold, used to scale the default options.
    dim = problem.M.dim();

    % Local defaults for the solver.
    localdefaults.storedepth = 0;
    localdefaults.maxiter = max(500, 4*dim);

    localdefaults.populationsize = min(40, 10*dim);
    localdefaults.nostalgia = 1.4;
    localdefaults.social = 1.4;

    % Merge global and local defaults, then merge with user options, if
    % any. The later argument of mergeOptions is presumed to take
    % precedence -- NOTE(review): confirm against mergeOptions.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);

    % The velocity update needs a (possibly approximate) logarithmic map.
    if ~isfield(problem.M, 'log')
        error(['The manifold problem.M must provide a logarithmic map, ' ...
               'M.log(x, y). An approximate logarithm will do too.']);
    end

    % Start timing for initialization.
    timetic = tic();

    % If no initial population x is given by the user, generate one at
    % random. Otherwise, validate it and adopt its size.
    if ~exist('x', 'var') || isempty(x)
        x = cell(options.populationsize, 1);
        for i = 1 : options.populationsize
            x{i} = problem.M.rand();
        end
    else
        if ~iscell(x)
            error('The initial guess x0 must be a cell (a population).');
        end
        if length(x) ~= options.populationsize
            options.populationsize = length(x);
            warning('manopt:pso:size', ...
                    ['The option populationsize was forced to the size' ...
                     ' of the given initial population x0.']);
        end
    end

    % Create a store database and one key per particle.
    storedb = StoreDB(options.storedepth);
    xkey = cell(size(x));
    for i = 1 : numel(x)
        xkey{i} = storedb.getNewKey();
    end

    % y{i} is the best point particle i has visited so far; initially,
    % that is its starting point.
    y = x;
    ykey = xkey; %#ok<NASGU>  bookkeeping only; not read in this function

    % Previous position of each particle (needed to transport velocities).
    xprev = x;
    xprevkey = xkey; %#ok<NASGU>  bookkeeping only; not read in this function

    % Initialize velocities with random tangent vectors at each particle.
    v = cell(size(x));
    for i = 1 : numel(x)
        v{i} = problem.M.randvec(x{i});
    end

    % Compute the cost of each particle; fy(i) is the cost at y{i}.
    costs = zeros(size(x));
    for i = 1 : numel(x)
        costs(i) = getCost(problem, x{i}, storedb, xkey{i});
    end
    fy = costs;
    costevals = options.populationsize;

    % Identify the best particle of the initial population.
    [fbest, imin] = min(costs);
    xbest = x{imin};
    xbestkey = xkey{imin}; %#ok<NASGU>  bookkeeping only; not read here

    % Iteration counter (at any point, iter is the number of fully executed
    % iterations so far).
    iter = 0;

    % Save stats in a struct array info, and preallocate it (capped at
    % 10000 entries). savestats is called here and once more at the top of
    % the first pass of the loop, so info(1) is written twice.
    stats = savestats();
    info(1) = stats;
    info(min(10000, options.maxiter+1)).iter = [];

    % Start iterating until a stopping criterion triggers.
    while true

        % Log the state at the start of this iteration.
        stats = savestats();
        info(iter+1) = stats;
        iter = iter + 1;

        % Let the store database discard what it no longer needs.
        storedb.purge();

        % Log / display iteration information here.
        if options.verbosity >= 2
            fprintf('Cost evals: %7d\tBest cost: %+.8e\n', costevals, fbest);
        end

        % Start timing this iteration.
        timetic = tic();

        % Stop as soon as ANY particle triggers a stopping criterion.
        % (The loop must run over 1 : numel(x); iterating over numel(x)
        % alone would only ever test the last particle.)
        stop = false;
        reason = '';
        for i = 1 : numel(x)
            [stop, reason] = stoppingcriterion(problem, x{i}, options, info, iter);
            if stop
                break;
            end
        end

        if stop
            if options.verbosity >= 1
                fprintf([reason '\n']);
            end
            break;
        end

        % Inertia weight: decreases linearly with iter, from (about) 0.9
        % toward 0.4 as iter approaches options.maxiter.
        w = 0.4 + 0.5*(1-iter/options.maxiter);

        % Compute the new velocity of each particle.
        for i = 1 : numel(x)

            % Current position and personal best of particle i.
            xi = x{i};
            yi = y{i};

            % Previous position and current velocity of particle i; the
            % velocity is transported from xiprev to xi before being used.
            xiprev = xprev{i};
            vi = v{i};

            % New velocity = inertia + nostalgia + social, where the last
            % two terms are randomly scaled tangent vectors at xi pointing
            % toward the personal best yi and the swarm best xbest.
            inertia = problem.M.lincomb(xi, w , problem.M.transp(xiprev, xi, vi));
            nostalgia = problem.M.lincomb(xi, rand(1)*options.nostalgia, problem.M.log(xi, yi) );
            social = problem.M.lincomb(xi, rand(1) * options.social, problem.M.log(xi, xbest));

            v{i} = problem.M.lincomb(xi, 1, inertia, 1, problem.M.lincomb(xi, 1, nostalgia, 1, social));

        end

        % Back up the current positions before moving the particles.
        xprev = x;
        xprevkey = xkey; %#ok<NASGU>  bookkeeping only; not read here

        % Move each particle along its velocity and update the bests.
        for i = 1 : numel(x)

            % Retract to the new position; it gets a fresh store key.
            x{i} = problem.M.retr(x{i}, v{i});
            xkey{i} = storedb.getNewKey();

            fxi = getCost(problem, x{i}, storedb, xkey{i});
            costevals = costevals + 1;

            costs(i) = fxi;

            % Update the personal best of particle i if it improved ...
            if fxi < fy(i)

                fy(i) = fxi;
                y{i} = x{i};
                ykey{i} = xkey{i};

                % ... and, in that case, possibly the swarm best as well.
                if fy(i) < fbest
                    fbest = fy(i);
                    xbest = y{i};
                    xbestkey = ykey{i}; %#ok<NASGU>
                end
            end
        end
    end

    % Keep only the entries of info that were actually filled in.
    info = info(1:iter);


    % Routine in charge of collecting the current iteration stats.
    % Nested function: reads iter, fbest, costevals, x, xkey, v, xbest,
    % info, timetic, storedb, problem and options from the workspace of
    % pso.
    function stats = savestats()
        stats.iter = iter;
        stats.cost = fbest;
        stats.costevals = costevals;
        stats.x = x;
        stats.v = v;
        stats.xbest = xbest;
        % Time is cumulative: add this iteration's time to the last total.
        if iter == 0
            stats.time = toc(timetic);
        else
            stats.time = info(iter).time + toc(timetic);
        end

        % Determine which fields the statsfun mechanism adds, by applying
        % it once on the first particle and comparing field lists. This
        % relies on added fields being listed after the existing ones.
        num_old_fields = size(fieldnames(stats), 1);
        trialstats = applyStatsfun(problem, x{1}, storedb, xkey{1}, options, stats);
        new_fields = fieldnames(trialstats);
        num_new_fields = size(fieldnames(trialstats), 1);
        num_additional_fields = num_new_fields - num_old_fields;

        % Each additional field becomes a cell with one entry per particle.
        for jj = 1 : num_additional_fields
            tempfield = new_fields(num_old_fields + jj);
            stats.(char(tempfield)) = cell(options.populationsize, 1);
        end
        for ii = 1 : options.populationsize
            tempstats = applyStatsfun(problem, x{ii}, storedb, xkey{ii}, options, stats);
            for jj = 1 : num_additional_fields
                tempfield = new_fields(num_old_fields + jj);
                tempfield_value = tempstats.(char(tempfield));
                stats.(char(tempfield)){ii} = tempfield_value;
            end
        end

    end


end