
getGradient

PURPOSE

Computes the gradient of the cost function at x.

SYNOPSIS

function grad = getGradient(problem, x, storedb, key)

DESCRIPTION

 Computes the gradient of the cost function at x.

 function grad = getGradient(problem, x)
 function grad = getGradient(problem, x, storedb)
 function grad = getGradient(problem, x, storedb, key)

 Returns the gradient at x of the cost function described in the problem
 structure.

 storedb is a StoreDB object, and key is the StoreDB key corresponding to the point x.

 See also: getDirectionalDerivative canGetGradient
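
 A minimal usage sketch (not part of getGradient.m), assuming a problem
 structure built with Manopt's spherefactory; the cost, manifold, and
 variable names below are illustrative only:

     n = 5;
     A = randn(n); A = (A + A')/2;
     M = spherefactory(n);                          % unit sphere in R^n
     problem.M = M;
     problem.cost = @(x) x'*A*x;
     problem.grad = @(x) M.egrad2rgrad(x, 2*A*x);   % Riemannian gradient

     x = M.rand();
     grad = getGradient(problem, x);                % storedb and key created internally

     % Optionally, pass an explicit StoreDB and key so that repeated
     % queries at the same point can share cached computations.
     storedb = StoreDB();
     key = storedb.getNewKey();
     grad = getGradient(problem, x, storedb, key);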

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

function grad = getGradient(problem, x, storedb, key)
% Computes the gradient of the cost function at x.
%
% function grad = getGradient(problem, x)
% function grad = getGradient(problem, x, storedb)
% function grad = getGradient(problem, x, storedb, key)
%
% Returns the gradient at x of the cost function described in the problem
% structure.
%
% storedb is a StoreDB object, and key is the StoreDB key corresponding to
% the point x.
%
% See also: getDirectionalDerivative canGetGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%   June 28, 2016 (NB):
%       Works with getPartialGradient.
%
%   Nov. 1, 2016 (NB):
%       Added support for gradient from directional derivatives.
%       Last resort is call to getApproxGradient instead of an exception.

    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end

    if isfield(problem, 'grad')
    %% Compute the gradient using grad.

        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.grad)
            case 1
                grad = problem.grad(x);
            case 2
                % Obtain, pass along, and save the store for x.
                store = storedb.getWithShared(key);
                [grad, store] = problem.grad(x, store);
                storedb.setWithShared(store, key);
            case 3
                % Pass along the whole storedb (by reference), with key.
                grad = problem.grad(x, storedb, key);
            otherwise
                up = MException('manopt:getGradient:badgrad', ...
                    'grad should accept 1, 2 or 3 inputs.');
                throw(up);
        end

    elseif isfield(problem, 'costgrad')
    %% Compute the gradient using costgrad.

        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.costgrad)
            case 1
                [unused, grad] = problem.costgrad(x); %#ok
            case 2
                % Obtain, pass along, and save the store for x.
                store = storedb.getWithShared(key);
                [unused, grad, store] = problem.costgrad(x, store); %#ok
                storedb.setWithShared(store, key);
            case 3
                % Pass along the whole storedb (by reference), with key.
                [unused, grad] = problem.costgrad(x, storedb, key); %#ok
            otherwise
                up = MException('manopt:getGradient:badcostgrad', ...
                    'costgrad should accept 1, 2 or 3 inputs.');
                throw(up);
        end

    elseif canGetEuclideanGradient(problem)
    %% Compute the gradient using the Euclidean gradient.

        egrad = getEuclideanGradient(problem, x, storedb, key);
        grad = problem.M.egrad2rgrad(x, egrad);

    elseif canGetPartialGradient(problem)
    %% Compute the gradient using a full partial gradient.

        d = problem.ncostterms;
        grad = getPartialGradient(problem, x, 1:d, storedb, key);

    elseif canGetDirectionalDerivative(problem)
    %% Compute gradient based on directional derivatives; expensive!

        B = tangentorthobasis(problem.M, x);
        df = zeros(size(B));
        for k = 1 : numel(B)
            df(k) = getDirectionalDerivative(problem, x, B{k}, storedb, key);
        end
        grad = lincomb(problem.M, x, B, df);

    else
    %% Attempt the computation of an approximation of the gradient.

        grad = getApproxGradient(problem, x, storedb, key);

    end

end
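
For comparison, here is a hedged sketch (not part of getGradient.m) of the
canGetEuclideanGradient branch above: when the problem only supplies a
Euclidean gradient through problem.egrad, getGradient obtains it via
getEuclideanGradient and converts it with M.egrad2rgrad. The manifold
factory and fields below follow standard Manopt conventions and are used
purely for illustration:

    n = 5;
    A = randn(n); A = (A + A')/2;
    problem.M = spherefactory(n);
    problem.cost  = @(x) x'*A*x;
    problem.egrad = @(x) 2*A*x;        % Euclidean gradient only

    x = problem.M.rand();
    grad = getGradient(problem, x);    % converted internally by M.egrad2rgrad

    % On the sphere, the Riemannian gradient is tangent at x, so its
    % component along x should be (numerically) zero.
    fprintf('normal component: %g\n', abs(x.' * grad));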

Generated on Sat 12-Nov-2016 14:11:22 by m2html © 2005