
pso

PURPOSE

Particle swarm optimization (PSO) for derivative-free minimization.

SYNOPSIS

function [xbest, fbest, info, options] = pso(problem, x, options)

DESCRIPTION

 Particle swarm optimization (PSO) for derivative-free minimization.

 function [x, cost, info, options] = pso(problem)
 function [x, cost, info, options] = pso(problem, x0)
 function [x, cost, info, options] = pso(problem, x0, options)
 function [x, cost, info, options] = pso(problem, [], options)

 Apply the particle swarm optimization (PSO) minimization algorithm to
 the problem defined in the problem structure, starting with the
 population x0 if it is provided (otherwise, a random population on the
 manifold is generated). A population is a cell array of points on the
 manifold. If the number of elements in x0 differs from
 options.populationsize, that option is overridden with the size of x0
 and a warning is issued.
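To make the notion of a population concrete, the following sketch builds one by hand for the unit sphere in R^n, without Manopt. It assumes that a random point on the sphere can be drawn as a normalized Gaussian vector, standing in for problem.M.rand(); all variable names and sizes are hypothetical. The second half mirrors how the solver reacts when the population size and options.populationsize disagree.

```matlab
% Hypothetical example: a population of points on the unit sphere in R^n,
% stored as a cell array (one point per cell), as pso expects for x0.
n = 3;                          % ambient dimension (sphere S^{n-1})
popsize = 5;                    % number of particles
x0 = cell(popsize, 1);
for i = 1 : popsize
    p = randn(n, 1);            % Gaussian sample...
    x0{i} = p / norm(p);        % ...normalized onto the sphere
end

% Mirror of the solver's size check: when the given population and the
% option disagree, the option is overridden by the population size.
options.populationsize = 40;
if numel(x0) ~= options.populationsize
    options.populationsize = numel(x0);
    warning('Example:popsize', ...
            'populationsize forced to the size of the given population.');
end
```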

 To specify options whilst not specifying an initial guess, give x0 as []
 (the empty matrix).

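 None of the options are mandatory; see the code for their default values.
 The manifold problem.M must, however, provide a logarithmic map
 M.log(x, y) (an approximate logarithm will do).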
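The solver's own defaults can be read off the source listing: storedepth = 0, maxcostevals = max(5000, 2*dim), maxiter = max(500, 4*dim), populationsize = min(40, 10*dim), and nostalgia = social = 1.4. The sketch below reproduces them for a given manifold dimension and then overlays user options with a plain field-by-field copy. That copy is a simplified, hypothetical stand-in for Manopt's mergeOptions (which in the solver also merges in getGlobalDefaults()), not its actual implementation.

```matlab
% Sketch of pso's local defaults for a manifold of dimension dim,
% followed by a simple override with user-supplied options.
dim = 2;                                   % e.g., the sphere S^2 in R^3
defaults.storedepth = 0;                   % no caching needed
defaults.maxcostevals = max(5000, 2*dim);
defaults.maxiter = max(500, 4*dim);
defaults.populationsize = min(40, 10*dim);
defaults.nostalgia = 1.4;                  % weight of the pull to the personal best
defaults.social = 1.4;                     % weight of the pull to the global best

% Hypothetical user options: only the fields to change need to be given.
useropts.populationsize = 10;

% Simplified stand-in for mergeOptions: user fields overwrite defaults.
options = defaults;
names = fieldnames(useropts);
for k = 1 : numel(names)
    options.(names{k}) = useropts.(names{k});
end
```

With dim = 2 the default population has min(40, 20) = 20 particles, and the user override brings it down to 10 while every other default is kept.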

 Based on the original PSO description in
   http://particleswarm.info/nn951942.ps.
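For orientation, here is a minimal Euclidean PSO in plain MATLAB that uses the same three velocity contributions as the solver: inertia with a weight w decreasing linearly toward 0.4, a nostalgia pull toward each particle's personal best, and a social pull toward the global best, the last two each scaled by a fresh uniform random number. In R^n the logarithm log(x, y) reduces to y - x, the retraction to x + v, and vector transport to the identity. As in the solver, all velocities are computed before any position moves. The objective, sizes, and zero initial velocities (the solver's commented-out alternative) are hypothetical choices; this is a sketch, not a reproduction of the referenced description.

```matlab
% Minimal Euclidean PSO sketch on a hypothetical test problem.
f = @(x) sum(x.^2);            % objective to minimize
n = 2;  popsize = 20;  maxiter = 100;
nostalgia = 1.4;  social = 1.4;

x = 4*rand(n, popsize) - 2;    % particles as columns, uniform in [-2, 2]^n
v = zeros(n, popsize);         % null initial velocities
fx = zeros(1, popsize);
for i = 1 : popsize
    fx(i) = f(x(:, i));
end
y = x;  fy = fx;               % personal best positions and costs
[fbest, imin] = min(fy);  xbest = y(:, imin);
history = fbest;               % global best cost after each iteration

for iter = 1 : maxiter
    w = 0.4 + 0.5*(1 - iter/maxiter);          % inertia weight
    % Compute all new velocities first (log(x, y) = y - x in R^n).
    for i = 1 : popsize
        v(:, i) = w*v(:, i) ...
                + rand*nostalgia*(y(:, i) - x(:, i)) ...
                + rand*social*(xbest - x(:, i));
    end
    % Then move every particle and update personal and global bests.
    for i = 1 : popsize
        x(:, i) = x(:, i) + v(:, i);           % "retraction" x + v
        fx(i) = f(x(:, i));
        if fx(i) < fy(i)
            fy(i) = fx(i);  y(:, i) = x(:, i);
            if fy(i) < fbest
                fbest = fy(i);  xbest = y(:, i);
            end
        end
    end
    history(end+1) = fbest;                    %#ok<AGROW>
end
```

Because personal and global bests are only replaced by strictly better points, the recorded best cost can never increase from one iteration to the next.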

 See also: manopt/solvers/neldermead/neldermead

SOURCE CODE

function [xbest, fbest, info, options] = pso(problem, x, options)
% Particle swarm optimization (PSO) for derivative-free minimization.
%
% function [x, cost, info, options] = pso(problem)
% function [x, cost, info, options] = pso(problem, x0)
% function [x, cost, info, options] = pso(problem, x0, options)
% function [x, cost, info, options] = pso(problem, [], options)
%
% Apply the Particle Swarm Optimization minimization algorithm to
% the problem defined in the problem structure, starting with the
% population x0 if it is provided (otherwise, a random population on the
% manifold is generated). A population is a cell containing points on the
% manifold. If the number of elements in the cell differs from
% options.populationsize, that option is overridden (with a warning).
%
% To specify options whilst not specifying an initial guess, give x0 as []
% (the empty matrix).
%
% None of the options are mandatory. See the code for details.
%
% Based on the original PSO description in
%   http://particleswarm.info/nn951942.ps.
%
% See also: manopt/solvers/neldermead/neldermead

% This file is part of Manopt: www.manopt.org.
% Original author: Pierre Borckmans, Dec. 30, 2012.
% Contributors: Bamdev Mishra, June 18, 2014.
% Change log:
%
%   June 18, 2014 (BM) :
%       Modified for handling product manifolds. Still need overall cleanup
%       to avoid potential issues, in particular w.r.t. logarithms.
%
%   June 23, 2014 (NB) :
%       Added some logic for handling of the populationsize option.
%
%   April 5, 2015 (NB):
%       Working with the new StoreDB class system. The code keeps track of
%       storedb keys for all points, even though it is not strictly
%       necessary. This extra bookkeeping should help maintain the code.


    % Verify that the problem description is sufficient for the solver.
    if ~canGetCost(problem)
        warning('manopt:getCost', ...
            'No cost provided. The algorithm will likely abort.');
    end

    % Dimension of the manifold
    dim = problem.M.dim();

    % Set local defaults here
    localdefaults.storedepth = 0;                   % no need for caching
    localdefaults.maxcostevals = max(5000, 2*dim);
    localdefaults.maxiter = max(500, 4*dim);

    localdefaults.populationsize = min(40, 10*dim);
    localdefaults.nostalgia = 1.4;
    localdefaults.social = 1.4;

    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);


    if ~isfield(problem.M, 'log') % BM
        error(['The manifold problem.M must provide a logarithmic map, ' ...
               'M.log(x, y). An approximate logarithm will do too.']);
    end

    % Start timing for initialization
    timetic = tic();

    % If no initial population x is given by the user,
    % generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = cell(options.populationsize, 1);
        for i = 1 : options.populationsize
            x{i} = problem.M.rand();
        end
    else
        if ~iscell(x)
            error('The initial guess x0 must be a cell (a population).');
        end
        if length(x) ~= options.populationsize
            options.populationsize = length(x);
            warning('manopt:pso:size', ...
                    ['The option populationsize was forced to the size' ...
                     ' of the given initial population x0.']);
        end
    end


    % Create a store database and a key for each point x{i}
    storedb = StoreDB(options.storedepth);
    xkey = cell(size(x));
    for i = 1 : numel(x)
        xkey{i} = storedb.getNewKey();
    end

    % Initialize personal best positions to the initial population
    y = x;
    ykey = xkey;

    % Save a copy of the swarm at the previous iteration
    xprev = x;
    xprevkey = xkey; %#ok<NASGU>

    % Initialize velocities for each particle
    v = cell(size(x));
    for i = 1 : numel(x)
        % random velocity to improve initial exploration
        v{i} = problem.M.randvec(x{i});
        % or null velocity
        % v{i} = problem.M.zerovec(x{i});
    end

    % Compute cost for each particle xi,
    % initialize personal best costs,
    % and set up a cost evaluation counter.
    costs = zeros(size(x));
    for i = 1 : numel(x)
        costs(i) = getCost(problem, x{i}, storedb, xkey{i});
    end
    fy = costs;
    costevals = options.populationsize;

    % Identify the best particle and store its cost/position
    [fbest, imin] = min(costs);
    xbest = x{imin};
    xbestkey = xkey{imin}; %#ok<NASGU>

    % Iteration counter (at any point, iter is the number of fully executed
    % iterations so far)
    iter = 0;

    % Save stats in a struct array info, and preallocate.
    % savestats will be called twice for the initial iterate (number 0),
    % which is unfortunate, but not problematic.
    stats = savestats();
    info(1) = stats;
    info(min(10000, options.maxiter+1)).iter = [];

    % Start iterating until stopping criterion triggers
    while true

        stats = savestats();
        info(iter+1) = stats; %#ok<AGROW>
        iter = iter + 1;

        % Make sure we don't use too much memory for the store database
        storedb.purge();

        % Log / display iteration information here.
        if options.verbosity >= 2
            fprintf('Cost evals: %7d\tBest cost: %+.8e\n', costevals, fbest);
        end

        % Start timing this iteration
        timetic = tic();

        % BM: Run standard stopping criterion checks.
        % BM: Stop if any particle triggers a stopping criterion.
        % (Loop over all particles, not only the last one.)
        for i = 1 : numel(x)
            [stop, reason] = stoppingcriterion(problem, x{i}, options, info, iter);
            if stop
                break;
            end
        end

        if stop
            if options.verbosity >= 1
                % Print reason as data so that '%' or '\' in it is not
                % interpreted as a format directive.
                fprintf('%s\n', reason);
            end
            break;
        end


        % Compute the inertia factor
        % (linearly decreasing from .9 to .4, from iter=0 to maxiter)
        w = 0.4 + 0.5*(1-iter/options.maxiter);

        % Compute velocities
        for i = 1 : numel(x)

            % Get the position and past best position of particle i
            xi = x{i};
            yi = y{i};

            % Get the previous position and velocity of particle i
            xiprev = xprev{i};
            vi = v{i};

            % Compute new velocity of particle i,
            % composed of 3 contributions
            inertia = problem.M.lincomb(xi, w, problem.M.transp(xiprev, xi, vi));
            nostalgia = problem.M.lincomb(xi, rand(1)*options.nostalgia, problem.M.log(xi, yi));
            social = problem.M.lincomb(xi, rand(1)*options.social, problem.M.log(xi, xbest));

            v{i} = problem.M.lincomb(xi, 1, inertia, 1, problem.M.lincomb(xi, 1, nostalgia, 1, social));

        end

        % Back up the current swarm positions
        xprev = x;
        xprevkey = xkey; %#ok<NASGU>

        % Update positions, personal bests and global best
        for i = 1 : numel(x)
            % compute new position of particle i
            x{i} = problem.M.retr(x{i}, v{i});
            xkey{i} = storedb.getNewKey();
            % compute new cost of particle i
            fxi = getCost(problem, x{i}, storedb, xkey{i});
            costevals = costevals + 1;

            % update costs of the swarm
            costs(i) = fxi;
            % update self-best if necessary
            if fxi < fy(i)
                % update self-best cost and position
                fy(i) = fxi;
                y{i} = x{i};
                ykey{i} = xkey{i};
                % update global-best if necessary
                if fy(i) < fbest
                    fbest = fy(i);
                    xbest = y{i};
                    xbestkey = ykey{i}; %#ok<NASGU>
                end
            end
        end
    end


    info = info(1:iter);

    % Routine in charge of collecting the current iteration stats
    function stats = savestats()
        stats.iter = iter;
        stats.cost = fbest;
        stats.costevals = costevals;
        stats.x = x;
        stats.v = v;
        stats.xbest = xbest;
        if iter == 0
            stats.time = toc(timetic);
        else
            stats.time = info(iter).time + toc(timetic);
        end

        % BM: Begin storing user-defined stats for the entire population
        num_old_fields = size(fieldnames(stats), 1);
        trialstats = applyStatsfun(problem, x{1}, storedb, xkey{1}, options, stats); % BM
        new_fields = fieldnames(trialstats);
        num_new_fields = size(fieldnames(trialstats), 1);
        num_additional_fields = num_new_fields - num_old_fields; % User has defined new fields
        for jj = 1 : num_additional_fields % New fields added
            tempfield = new_fields(num_old_fields + jj);
            stats.(char(tempfield)) = cell(options.populationsize, 1);
        end
        for ii = 1 : options.populationsize % Adding information for each element of the population
            tempstats = applyStatsfun(problem, x{ii}, storedb, xkey{ii}, options, stats);
            for jj = 1 : num_additional_fields
                tempfield = new_fields(num_old_fields + jj);
                tempfield_value = tempstats.(char(tempfield));
                stats.(char(tempfield)){ii} = tempfield_value;
            end
        end
        % BM: End storing

    end


end
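The velocity update in the listing above relies on three manifold tools: a logarithm M.log (exact or approximate), a vector transport M.transp that carries the previous velocity from xprev{i} to x{i}, and a retraction M.retr that moves a particle along its new velocity. The sketch below spells out one such step on the unit sphere with hand-written formulas: the exact spherical logarithm, transport by orthogonal projection onto the tangent space, and the retraction (x + v)/||x + v||. These are common choices for the sphere but are assumptions here; the names proj, retr and sphlog are hypothetical and not necessarily the formulas of any particular Manopt manifold factory.

```matlab
% One hypothetical Riemannian PSO velocity/position step on the unit sphere.
proj = @(x, u) u - (x'*u)*x;                 % projection onto T_x S^{n-1}
retr = @(x, v) (x + v) / norm(x + v);        % metric-projection retraction
% Spherical logarithm log_x(y): geodesic distance times the unit tangent
% direction toward y (close to zero when y is numerically equal to x).
sphlog = @(x, y) acos(max(-1, min(1, x'*y))) * proj(x, y) ...
                 / max(norm(proj(x, y)), eps);

n = 3;
unitv = @(u) u / norm(u);
xprev = unitv(randn(n, 1));      % previous position of the particle
x     = unitv(randn(n, 1));      % current position
y     = unitv(randn(n, 1));      % personal best
xbest = unitv(randn(n, 1));      % global best
v     = proj(xprev, randn(n, 1));% previous velocity, tangent at xprev

w = 0.7;  nostalgia = 1.4;  social = 1.4;    % example coefficients
inertia = w * proj(x, v);                    % transport v to T_x, then scale
nost    = rand * nostalgia * sphlog(x, y);   % pull toward personal best
soc     = rand * social * sphlog(x, xbest);  % pull toward global best
vnew    = inertia + nost + soc;              % new velocity, tangent at x
xnew    = retr(x, vnew);                     % new position, on the sphere
```

Every term is a tangent vector at x before being summed, which is why the solver first transports the old velocity and only then combines it with the two logarithms through M.lincomb.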

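The inertia factor in the listing, w = 0.4 + 0.5*(1 - iter/maxiter), decreases linearly from 0.9 at iter = 0 to 0.4 at iter = maxiter. Since iter has already been incremented when w is computed inside the loop, the first iteration actually uses 0.9 - 0.5/maxiter. A quick check of the schedule, with a hypothetical maxiter of 500 (the default for small dimensions):

```matlab
% Inertia weight schedule used in pso, evaluated for every iteration index.
maxiter = 500;
wfun = @(iter) 0.4 + 0.5*(1 - iter/maxiter);
w = wfun(0 : maxiter);           % w(1) is iter = 0, w(end) is iter = maxiter
```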
Generated on Sat 12-Nov-2016 14:11:22 by m2html © 2005