Home > manopt > core > getGradient.m

getGradient

PURPOSE

Computes the gradient of the cost function at x.

SYNOPSIS

function grad = getGradient(problem, x, storedb, key)

DESCRIPTION

 Computes the gradient of the cost function at x.

 function grad = getGradient(problem, x)
 function grad = getGradient(problem, x, storedb)
 function grad = getGradient(problem, x, storedb, key)

 Returns the gradient at x of the cost function described in the problem
 structure.

 storedb is a StoreDB object, key is the StoreDB key to point x.

 See also: getDirectionalDerivative canGetGradient

CROSS-REFERENCE INFORMATION

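This function calls: StoreDB, canGetDirectionalDerivative, canGetEuclideanGradient, canGetPartialGradient, getApproxGradient, getDirectionalDerivative, getEuclideanGradient, getPartialGradient, lincomb, tangentorthobasis.
This function is called by: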

SOURCE CODE

0001 function grad = getGradient(problem, x, storedb, key)
0002 % Computes the gradient of the cost function at x.
0003 %
0004 % function grad = getGradient(problem, x)
0005 % function grad = getGradient(problem, x, storedb)
0006 % function grad = getGradient(problem, x, storedb, key)
0007 %
0008 % Returns the gradient at x of the cost function described in the problem
0009 % structure.
0010 %
0011 % storedb is a StoreDB object, key is the StoreDB key to point x.
0012 %
0013 % See also: getDirectionalDerivative canGetGradient
0014
0015 % This file is part of Manopt: www.manopt.org.
0016 % Original author: Nicolas Boumal, Dec. 30, 2012.
0017 % Contributors:
0018 % Change log:
0019 %
0020 %   April 3, 2015 (NB):
0021 %       Works with the new StoreDB class system.
0022 %
0023 %   June 28, 2016 (NB):
0024 %       Works with getPartialGradient.
0025 %
0026 %   Nov. 1, 2016 (NB):
0027 %       Added support for gradient from directional derivatives.
0028 %       Last resort is call to getApproxGradient instead of an exception.
0029 %
0030 %   Sep. 6, 2018 (NB):
0031 %       The gradient is now cached by default. This is made practical by
0032 %       the new storedb 'remove' functionalities that keep the number of
0033 %       cached points down to a minimum. If the gradient is obtained via
0034 %       costgrad, the cost is also cached.
0035
0036     % Sources for the gradient are tried in this order:
0037     %   1. store.grad__, if an earlier call cached it for this key;
0038     %   2. problem.grad;
0039     %   3. problem.costgrad (the cost it also returns is cached as well);
0040     %   4. the Euclidean gradient, mapped through problem.M.egrad2rgrad;
0041     %   5. the partial gradient over all terms 1:problem.ncostterms;
0042     %   6. directional derivatives along a tangent basis at x (one call
0043     %      per basis vector: expensive);
0044     %   7. getApproxGradient, as a last resort.
0045     % A freshly computed gradient is cached in the store as grad__.
0046
0047     % Allow omission of the key, and even of storedb.
0048     if ~exist('key', 'var')
0049         if ~exist('storedb', 'var')
0050             storedb = StoreDB();
0051         end
0052         key = storedb.getNewKey();
0053     end
0054
0055     % Contrary to most similar functions, here, we get the store by
0056     % default. This is for the caching functionality described below.
0057     store = storedb.getWithShared(key);
0058     % Becomes true whenever storedb itself is handed to a helper, which
0059     % may modify the store at key and leave our local copy out of date.
0060     store_is_stale = false;
0061
0062     % If the gradient has been computed before at this point (and its
0063     % memory is still in storedb), then we just look up the value.
0064     force_grad_caching = true;
0065     if force_grad_caching && isfield(store, 'grad__')
0066         grad = store.grad__;
0067         return;
0068     end
0069
0070     % The cost is not normally computed here, but costgrad returns it as
0071     % a by-product. Record that with an explicit flag (rather than testing
0072     % whether a variable named 'cost' exists) so it can be cached below.
0073     cost_computed = false;
0074
0075     if isfield(problem, 'grad')
0076     %% Compute the gradient using grad.
0077
0078         % Check whether this function wants to deal with storedb or not.
0079         switch nargin(problem.grad)
0080             case 1
0081                 grad = problem.grad(x);
0082             case 2
0083                 [grad, store] = problem.grad(x, store);
0084             case 3
0085                 % Pass along the whole storedb (by reference), with key.
0086                 grad = problem.grad(x, storedb, key);
0087                 % The store structure in storedb might have been modified
0088                 % (since it is passed by reference), so before caching
0089                 % we'll have to update (see below).
0090                 store_is_stale = true;
0091             otherwise
0092                 up = MException('manopt:getGradient:badgrad', ...
0093                     'grad should accept 1, 2 or 3 inputs.');
0094                 throw(up);
0095         end
0096
0097     elseif isfield(problem, 'costgrad')
0098     %% Compute the gradient using costgrad.
0099
0100         % Check whether this function wants to deal with storedb or not.
0101         switch nargin(problem.costgrad)
0102             case 1
0103                 [cost, grad] = problem.costgrad(x);
0104             case 2
0105                 [cost, grad, store] = problem.costgrad(x, store);
0106             case 3
0107                 % Pass along the whole storedb (by reference), with key.
0108                 [cost, grad] = problem.costgrad(x, storedb, key);
0109                 store_is_stale = true;
0110             otherwise
0111                 up = MException('manopt:getGradient:badcostgrad', ...
0112                     'costgrad should accept 1, 2 or 3 inputs.');
0113                 throw(up);
0114         end
0115         % Every non-throwing case above has assigned cost.
0116         cost_computed = true;
0117
0118     elseif canGetEuclideanGradient(problem)
0119     %% Compute the Riemannian gradient using the Euclidean gradient.
0120
0121         egrad = getEuclideanGradient(problem, x, storedb, key);
0122         grad = problem.M.egrad2rgrad(x, egrad);
0123         store_is_stale = true;
0124
0125     elseif canGetPartialGradient(problem)
0126     %% Compute the gradient using a full partial gradient.
0127
0128         % Request the partial gradient over all cost terms, 1 to d.
0129         d = problem.ncostterms;
0130         grad = getPartialGradient(problem, x, 1:d, storedb, key);
0131         store_is_stale = true;
0132
0133     elseif canGetDirectionalDerivative(problem)
0134     %% Compute gradient based on directional derivatives; expensive!
0135
0136         % grad = sum_k df(k) * B{k}, where df(k) is the directional
0137         % derivative along B{k}. This relies on B being orthonormal,
0138         % which the name tangentorthobasis presumably indicates.
0139         B = tangentorthobasis(problem.M, x);
0140         df = zeros(size(B));
0141         for k = 1 : numel(B)
0142             df(k) = getDirectionalDerivative(problem, x, B{k}, storedb, key);
0143         end
0144         grad = lincomb(problem.M, x, B, df);
0145         store_is_stale = true;
0146
0147     else
0148     %% Attempt the computation of an approximation of the gradient.
0149
0150         grad = getApproxGradient(problem, x, storedb, key);
0151         store_is_stale = true;
0152
0153     end
0154
0155     % If we are not sure that the store structure is up to date, update.
0156     if store_is_stale
0157         store = storedb.getWithShared(key);
0158     end
0159
0160     % Cache here.
0161     if force_grad_caching
0162         store.grad__ = grad;
0163     end
0164     % If we got the gradient via costgrad, then the cost has also been
0165     % computed and we can cache it.
0166     if cost_computed
0167         store.cost__ = cost;
0168     end
0169
0170     storedb.setWithShared(store, key);
0171
0172 end

Generated on Mon 10-Sep-2018 11:48:06 by m2html © 2005