function grad = getGradient(problem, x, storedb, key)
% Returns the gradient at x of the cost function described by problem.
%
% function grad = getGradient(problem, x)
% function grad = getGradient(problem, x, storedb)
% function grad = getGradient(problem, x, storedb, key)
%
% Inputs:
%   problem  Problem structure. Depending on what it provides, the fields
%            grad, costgrad, M.egrad2rgrad and ncostterms are read here.
%   x        Point at which the gradient is requested.
%   storedb  (optional) Store database object. A fresh StoreDB is created
%            when it is omitted.
%   key      (optional) Key under which x is tracked in storedb. A new key
%            is requested from storedb when it is omitted.
%
% Output:
%   grad     The gradient at x.
%
% The gradient is obtained from the first source that is available, tried
% in this order:
%   1. problem.grad;
%   2. problem.costgrad (the cost comes for free and is cached as well);
%   3. the Euclidean gradient, converted with problem.M.egrad2rgrad;
%   4. the partial gradient over all terms 1:problem.ncostterms;
%   5. directional derivatives along each element returned by
%      tangentorthobasis, recombined with lincomb;
%   6. getApproxGradient as a last resort.
%
% The result is written to the field grad__ of the store associated with
% key; if that field is already present on entry, its value is returned
% immediately and nothing is recomputed.
%
% Errors:
%   manopt:getGradient:badgrad      problem.grad takes neither 1, 2 nor 3 inputs.
%   manopt:getGradient:badcostgrad  problem.costgrad takes neither 1, 2 nor 3 inputs.
%   manopt:getGradient:egrad2rgrad  problem.M.egrad2rgrad takes neither 2 nor 4 inputs.

    % Both trailing arguments are optional: create whatever was left out.
    if nargin < 4
        if nargin < 3
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end

    % The store is fetched up front because a previously cached gradient,
    % if there is one, lives in it.
    store = storedb.getWithShared(key);

    % Caching is hard-wired to on. A cached gradient short-circuits
    % everything below.
    force_grad_caching = true;
    if force_grad_caching && isfield(store, 'grad__')
        grad = store.grad__;
        return;
    end

    % Set whenever a callee was handed storedb directly: it may have
    % modified the database, so the local copy of the store is re-read
    % before being written back.
    reload_store = false;

    % Set when the cost value was obtained as a by-product.
    have_cost = false;

    if isfield(problem, 'grad')

        % User-supplied gradient; the calling convention is selected from
        % the number of inputs the function declares.
        n_in = nargin(problem.grad);
        if n_in == 1
            grad = problem.grad(x);
        elseif n_in == 2
            [grad, store] = problem.grad(x, store);
        elseif n_in == 3
            grad = problem.grad(x, storedb, key);
            reload_store = true;
        else
            throw(MException('manopt:getGradient:badgrad', ...
                'grad should accept 1, 2 or 3 inputs.'));
        end

    elseif isfield(problem, 'costgrad')

        % User-supplied joint cost/gradient; same conventions as above.
        n_in = nargin(problem.costgrad);
        if n_in == 1
            [cost, grad] = problem.costgrad(x);
        elseif n_in == 2
            [cost, grad, store] = problem.costgrad(x, store);
        elseif n_in == 3
            [cost, grad] = problem.costgrad(x, storedb, key);
            reload_store = true;
        else
            throw(MException('manopt:getGradient:badcostgrad', ...
                'costgrad should accept 1, 2 or 3 inputs.'));
        end

        have_cost = true;

    elseif canGetEuclideanGradient(problem)

        % Obtain the Euclidean gradient, then let the manifold convert it.
        egrad = getEuclideanGradient(problem, x, storedb, key);

        n_in = nargin(problem.M.egrad2rgrad);
        if n_in == 2
            grad = problem.M.egrad2rgrad(x, egrad);
        elseif n_in == 4
            grad = problem.M.egrad2rgrad(x, egrad, storedb, key);
        else
            throw(MException('manopt:getGradient:egrad2rgrad', ...
                'egrad2rgrad should accept 2 or 4 inputs.'));
        end
        reload_store = true;

    elseif canGetPartialGradient(problem)

        % Request the partial gradient over every cost term at once.
        nterms = problem.ncostterms;
        grad = getPartialGradient(problem, x, 1:nterms, storedb, key);
        reload_store = true;

    elseif canGetDirectionalDerivative(problem)

        % One directional derivative per basis element, then recombine.
        basis = tangentorthobasis(problem.M, x);
        coeffs = zeros(size(basis));
        for idx = 1 : numel(basis)
            coeffs(idx) = getDirectionalDerivative(problem, x, ...
                                                   basis{idx}, storedb, key);
        end
        grad = lincomb(problem.M, x, basis, coeffs);
        reload_store = true;

    else

        % Nothing else is available: fall back to an approximation.
        grad = getApproxGradient(problem, x, storedb, key);
        reload_store = true;

    end

    % Re-read the store if storedb was passed to a callee above.
    if reload_store
        store = storedb.getWithShared(key);
    end

    % Cache the gradient, and the cost if it was obtained along the way.
    if force_grad_caching
        store.grad__ = grad;
    end
    if have_cost
        store.cost__ = cost;
    end

    storedb.setWithShared(store, key);

end