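Computes the gradient of a subset of terms in the cost function at x.

function grad = getPartialGradient(problem, x, I)
function grad = getPartialGradient(problem, x, I, storedb)
function grad = getPartialGradient(problem, x, I, storedb, key)

Assume the cost function described in the problem structure is a sum of many terms, as

f(x) = sum_i f_i(x) for i = 1:d,

where d is specified as d = problem.ncostterms. For a subset I of 1:d, getPartialGradient obtains the gradient of the partial cost function

f_I(x) = sum_i f_i(x) for i = I.

storedb is a StoreDB object, and key is the StoreDB key for point x.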
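As a concrete illustration, consider a hypothetical least-squares cost on R^n with f_i(x) = 0.5*(A(i,:)*x - b(i))^2. On a Euclidean space the Riemannian gradient equals the Euclidean gradient, so the gradient of f_I is A(I,:).'*(A(I,:)*x - b(I)). The sketch below stores this in the two-input form problem.partialgrad(x, I), sets problem.ncostterms, and checks that partial gradients over a partition of 1:d add up to the full gradient. The matrix A, the vector b and the cost itself are assumptions made for illustration; the script calls the handle directly and does not require Manopt.

```matlab
% Hypothetical sum-of-terms cost on R^n (an assumption, not from Manopt):
%   f_i(x) = 0.5*(A(i,:)*x - b(i))^2,  f(x) = sum_i f_i(x) for i = 1:d.
A = [1 0; 0 2; 1 1; 3 -1];
b = [1; 2; 3; 4];
d = size(A, 1);
x = [0.5; -1];

problem.ncostterms = d;
% Two-input partialgrad: gradient of f_I(x), the sum of f_i(x) over i in I.
problem.partialgrad = @(x, I) A(I, :).' * (A(I, :)*x - b(I));

% Full gradient of f, computed directly for comparison.
fullgrad = A.' * (A*x - b);

% Split 1:d into two disjoint subsets and add their partial gradients.
I1 = [1 3];
I2 = [2 4];
g = problem.partialgrad(x, I1) + problem.partialgrad(x, I2);

fprintf('max difference: %g\n', max(abs(g - fullgrad)));
```

Any subset works, including a single term (I = 2) or all of 1:d, in which case the partial gradient is the full gradient.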
```matlab
function grad = getPartialGradient(problem, x, I, storedb, key)
% Computes the gradient of a subset of terms in the cost function at x.
%
% function grad = getPartialGradient(problem, x, I)
% function grad = getPartialGradient(problem, x, I, storedb)
% function grad = getPartialGradient(problem, x, I, storedb, key)
%
% Assume the cost function described in the problem structure is a sum of
% many terms, as
%
%    f(x) = sum_i f_i(x) for i = 1:d,
%
% where d is specified as d = problem.ncostterms.
%
% For a subset I of 1:d, getPartialGradient obtains the gradient of the
% partial cost function
%
%    f_I(x) = sum_i f_i(x) for i = I.
%
% storedb is a StoreDB object, and key is the StoreDB key for point x.
%
% See also: getGradient canGetPartialGradient getPartialEuclideanGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, June 28, 2016
% Contributors:
% Change log:
%
%   Feb. 10, 2020 (NB):
%       Allowing M.egrad2rgrad to take (storedb, key) as extra inputs.


    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end


    % Make sure I is a row vector, so that it is natural to loop over it
    % with " for i = I ".
    I = (I(:)).';


    if isfield(problem, 'partialgrad')
    %% Compute the partial gradient using partialgrad.

        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.partialgrad)
            case 2
                grad = problem.partialgrad(x, I);
            case 3
                % Obtain, pass along, and save the store for x.
                store = storedb.getWithShared(key);
                [grad, store] = problem.partialgrad(x, I, store);
                storedb.setWithShared(store, key);
            case 4
                % Pass along the whole storedb (by reference), with key.
                grad = problem.partialgrad(x, I, storedb, key);
            otherwise
                up = MException('manopt:getPartialGradient:badpartialgrad', ...
                    'partialgrad should accept 2, 3 or 4 inputs.');
                throw(up);
        end

    elseif canGetPartialEuclideanGradient(problem)
    %% Compute the partial gradient using the Euclidean partial gradient.

        egrad = getPartialEuclideanGradient(problem, x, I, storedb, key);
        % Convert to the Riemannian gradient.
        switch nargin(problem.M.egrad2rgrad)
            case 2
                grad = problem.M.egrad2rgrad(x, egrad);
            case 4
                grad = problem.M.egrad2rgrad(x, egrad, storedb, key);
            otherwise
                up = MException('manopt:getPartialGradient:egrad2rgrad', ...
                    'egrad2rgrad should accept 2 or 4 inputs.');
                throw(up);
        end

    else
    %% Abandon computing the partial gradient.

        up = MException('manopt:getPartialGradient:fail', ...
            ['The problem description is not explicit enough to ' ...
             'compute the partial gradient of the cost.']);
        throw(up);

    end

end
```
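When problem.partialgrad is absent, the function falls back on the Euclidean partial gradient and maps it through M.egrad2rgrad, choosing the call form from nargin. The sketch below mimics that branch for a hypothetical manifold, the unit sphere in R^n, where the Riemannian gradient is the orthogonal projection of the Euclidean gradient onto the tangent space at x, that is u - (x.'*u)*x. The handles partialegrad and egrad2rgrad are local stand-ins, not Manopt calls, and the cost f_i(x) = 0.5*(A(i,:)*x - b(i))^2 is an assumption. Because the projection is linear, partial Riemannian gradients over a partition of 1:d add up to the full Riemannian gradient, and each one is tangent at x.

```matlab
% Hypothetical stand-ins for the Euclidean partial gradient and for
% M.egrad2rgrad on the unit sphere; no Manopt functions are called.
A = [1 0 2; 0 2 -1; 1 1 1; 3 -1 0];
b = [1; 2; 3; 4];
d = size(A, 1);
x = [1; 2; 2] / 3;                 % unit-norm point on the sphere

% Euclidean gradient of f_I(x), the sum of f_i(x) over i in I.
partialegrad = @(x, I) A(I, :).' * (A(I, :)*x - b(I));
% Two-input egrad2rgrad: project onto the tangent space {u : x.'*u = 0}.
egrad2rgrad = @(x, egrad) egrad - (x.' * egrad) * x;

I = [4; 1];
I = (I(:)).';                      % row vector, as in getPartialGradient

% Second branch: Euclidean partial gradient, then conversion.
% nargin on an anonymous handle counts its declared inputs (here, 2).
switch nargin(egrad2rgrad)
    case 2
        grad = egrad2rgrad(x, partialegrad(x, I));
    otherwise
        error('sketch:egrad2rgrad', 'egrad2rgrad should accept 2 inputs.');
end

% Complementary subset: the two partial gradients add up to the full one.
J = setdiff(1:d, I);
gradJ = egrad2rgrad(x, partialegrad(x, J));
fullgrad = egrad2rgrad(x, partialegrad(x, 1:d));

fprintf('tangency: %g, sum mismatch: %g\n', abs(x.' * grad), ...
        norm(grad + gradJ - fullgrad));
```

Dispatching on nargin lets a simple manifold supply a two-input egrad2rgrad, while one that benefits from caching can accept (storedb, key) as well.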