Home > examples > low_rank_tensor_completion.m

low_rank_tensor_completion

PURPOSE ^

Given a partial observation of a low-rank tensor, attempts to complete it.

SYNOPSIS ^

function low_rank_tensor_completion()

DESCRIPTION ^

 Given a partial observation of a low-rank tensor, attempts to complete it.

 function low_rank_tensor_completion()

 This example demonstrates how to use the geometry factory for the
 quotient manifold of fixed-rank tensors, 
 fixedrankfactory_tucker_preconditioned.

 This geometry is described in the technical report
 "Riemannian preconditioning for tensor completion"
 Hiroyuki Kasai and Bamdev Mishra, arXiv:1506.02159, 2015.

 This can be a starting point for many optimization problems of the form:

 minimize f(X) such that rank(X) = [r1, r2, r3], size(X) = [n1, n2, n3],
 where rank(X) denotes the multilinear (Tucker) rank of the tensor X.

 Input:  None. This example file generates random data.
 
 Output: None.

 Please cite the Manopt paper as well as the research paper:
     @Techreport{kasai2015,
       Title   = {{R}iemannian preconditioning for tensor completion},
       Author  = {Kasai, H. and Mishra, B.},
       Journal = {Arxiv preprint arXiv:1506.02159},
       Year    = {2015}
     }

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function low_rank_tensor_completion()
0002 % Given a partial observation of a low-rank tensor, attempts to complete it.
0003 %
0004 % function low_rank_tensor_completion()
0005 %
0006 % This example demonstrates how to use the geometry factory for the
0007 % quotient manifold of fixed-rank tensors,
0008 % fixedrankfactory_tucker_preconditioned.
0009 %
0010 % This geometry is described in the technical report
0011 % "Riemannian preconditioning for tensor completion"
0012 % Hiroyuki Kasai and Bamdev Mishra, arXiv:1506.02159, 2015.
0013 %
0014 % This can be a starting point for many optimization problems of the form:
0015 %
0016 % minimize f(X) such that rank(X) = [r1, r2, r3], size(X) = [n1, n2, n3],
0017 % where rank(X) denotes the multilinear (Tucker) rank of the tensor X.
0018 %
0019 % Input:  None. This example file generates random data.
0020 %
0021 % Output: None.
0022 %
0023 % Please cite the Manopt paper as well as the research paper:
0024 %     @Techreport{kasai2015,
0025 %       Title   = {{R}iemannian preconditioning for tensor completion},
0026 %       Author  = {Kasai, H. and Mishra, B.},
0027 %       Journal = {Arxiv preprint arXiv:1506.02159},
0028 %       Year    = {2015}
0029 %     }
0030
0031 % This file is part of Manopt and is copyrighted. See the license file.
0032 %
0033 % Main authors: Hiroyuki Kasai and Bamdev Mishra, June 16, 2015.
0034 % Contributors:
0035 %
0036 % Change log:
0037 %
0038
0039     % Random data generation with pseudo-random numbers from a
0040     % uniform distribution on [0, 1].
0041     % First, choose the size of the problem.
0042     % We will complete a tensor of size n1-by-n2-by-n3 of rank (r1, r2, r3):
0043     n1 = 70;
0044     n2 = 60;
0045     n3 = 50;
0046     r1 = 3;
0047     r2 = 4;
0048     r3 = 5;
0049     tensor_dims = [n1 n2 n3];
0050     core_dims = [r1 r2 r3];
0051     total_entries = n1*n2*n3;
0052
0053     % Generate a random tensor A of size n1-by-n2-by-n3 of rank (r1, r2, r3).
0054     % The factors U1, U2, U3 are orthonormal (thin QR); the triangular
0055     % factors R1, R2, R3 are folded into the core, so that A equals the
0056     % Tucker tensor built directly from the random matrices rand(nk, rk).
0057     [U1,R1] = qr(rand(n1, r1), 0);
0058     [U2,R2] = qr(rand(n2, r2), 0);
0059     [U3,R3] = qr(rand(n3, r3), 0);
0060
0061     Z.U1 = R1;
0062     Z.U2 = R2;
0063     Z.U3 = R3;
0064     Z.G = rand( core_dims );
0065     Core = tucker2multiarray(Z); % Converts tucker format tensor to full tensor.
0066
0067     Y.U1 = U1;
0068     Y.U2 = U2;
0069     Y.U3 = U3;
0070     Y.G = Core;
0071     A = tucker2multiarray(Y);
0072
0073     % Generate a random mask P for observed entries: P(i, j, k) = 1 if the entry
0074     % (i, j, k) of A is observed, and 0 otherwise.
0075     fraction = 0.1; % Fraction of known entries (observation ratio).
0076     nr = round(fraction * total_entries);
0077     ind = randperm(total_entries);
0078     ind = ind(1 : nr);
0079     P = false(tensor_dims);
0080     P(ind) = true;
0081     % Hence, we know the nonzero entries in PA:
0082     PA = P.*A;
0083
0084     % Pick the manifold of tensors of size n1-by-n2-by-n3 of rank (r1, r2, r3).
0085     problem.M = fixedrankfactory_tucker_preconditioned(tensor_dims, core_dims);
0086
0087     % Define the problem cost function. The input X is a structure with
0088     % fields U1, U2, U3, G representing a rank (r1,r2,r3) tensor.
0089     % f(X) = 1/2 * || P.*(X - A) ||^2
0090     problem.cost = @cost;
0091     function f = cost(X)
0092         S = residual(X);
0093         f = .5*norm(S(:))^2;
0094     end
0095
0096     % Define the Euclidean gradient of the cost function, that is, the
0097     % gradient of f(X) seen as a standard function of X.
0098     % nabla f(X) = P.*(X-A)
0099     % Since X = G x1 U1 x2 U2 x3 U3, the chain rule turns this into the
0100     % partial gradients with respect to U1, U2, U3 and G computed below.
0101     % We only need to give the Euclidean gradient. Manopt converts it
0102     % internally to the Riemannian counterpart.
0103     problem.egrad = @egrad;
0104     function g = egrad(X)
0105         % S = P.*(X - A), its unfolding matrices, and those of the core.
0106         S = residual(X);
0107         [S1, S2, S3] = unfold(S);
0108         [G1, G2, G3] = unfold(X.G);
0109
0110         g.U1 = S1 * kron(X.U3, X.U2) * G1';
0111         g.U2 = S2 * kron(X.U3, X.U1) * G2';
0112         g.U3 = S3 * kron(X.U2, X.U1) * G3';
0113         g.G = reshape(X.U1' * S1 * kron(X.U3, X.U2), r1, r2, r3);
0114     end
0115
0116     % Define the Euclidean Hessian of the cost at X, along eta, where eta is
0117     % represented as a tangent vector: a structure with fields U1, U2, U3, G.
0118     % This is the directional derivative of nabla f(X) at X along Xdot:
0119     % nabla^2 f(X)[Xdot] = P.*Xdot
0120     % We only need to give the Euclidean Hessian. Manopt converts it
0121     % internally to the Riemannian counterpart.
0122     problem.ehess = @ehess;
0123     function Hess = ehess(X, eta)
0124         % S = P.*(X - A) and its unfolding matrices S1, S2 and S3.
0125         S = residual(X);
0126         [S1, S2, S3] = unfold(S);
0127
0128         % Sdot = P.*Xdot, the derivative of S along eta, and its unfoldings.
0129         Sdot = P.*tangent2multiarray(X, eta);
0130         [S1dot, S2dot, S3dot] = unfold(Sdot);
0131
0132         % Unfolding matrices of X.G and eta.G.
0133         [X_G1, X_G2, X_G3] = unfold(X.G);
0134         [eta_G1, eta_G2, eta_G3] = unfold(eta.G);
0135
0136         % Hessians for U1, U2 and U3: product rule applied to egrad.
0137         Hess.U1 = S1dot * (kron(X.U3,X.U2)*X_G1') ...
0138             + S1 * (kron(eta.U3,X.U2)*X_G1' ...
0139             + kron(X.U3,eta.U2)*X_G1' + kron(X.U3,X.U2)*eta_G1');
0140
0141         Hess.U2 = S2dot * (kron(X.U3,X.U1)*X_G2') ...
0142             + S2 * (kron(eta.U3,X.U1)*X_G2' ...
0143             + kron(X.U3,eta.U1)*X_G2' + kron(X.U3,X.U1)*eta_G2');
0144
0145         Hess.U3 = S3dot * (kron(X.U2,X.U1)*X_G3') ...
0146             + S3 * (kron(eta.U2,X.U1)*X_G3' ...
0147             + kron(X.U2,eta.U1)*X_G3' + kron(X.U2,X.U1)*eta_G3');
0148
0149         % Hessian for G: derivative of g.G = S x1 U1' x2 U2' x3 U3'.
0150         Hess.G = tucker_full(Sdot, X.U1', X.U2', X.U3') ...
0151                + tucker_full(S, eta.U1', X.U2', X.U3') ...
0152                + tucker_full(S, X.U1', eta.U2', X.U3') ...
0153                + tucker_full(S, X.U1', X.U2', eta.U3');
0154     end
0155
0156     % Residual on the observed entries, P.*(X - A), as a full array.
0157     function S = residual(X)
0158         S = P.*tucker2multiarray(X) - PA;
0159     end
0160
0161     % Full array of Xdot, the derivative of X = G x1 U1 x2 U2 x3 U3 along
0162     % the tangent vector eta:
0163     %   Xdot = G x1 eta.U1 x2 U2 x3 U3 + G x1 U1 x2 eta.U2 x3 U3
0164     %        + G x1 U1 x2 U2 x3 eta.U3 + eta.G x1 U1 x2 U2 x3 U3.
0165     % The four terms are summed by a single call to tucker2multiarray on a
0166     % block-diagonal core of size 4*[r1 r2 r3].
0167     function Xdot = tangent2multiarray(X, eta)
0168         Gbig = zeros(4*core_dims);
0169         Gbig(1:r1, 1:r2, 1:r3) = X.G;
0170         Gbig(r1 + 1 : 2*r1, r2 + 1 : 2*r2, r3 + 1 : 2*r3) = X.G;
0171         Gbig(2*r1 + 1 : 3*r1, 2*r2 + 1 : 3*r2, 2*r3 + 1 : 3*r3) = X.G;
0172         Gbig(3*r1 + 1 : 4*r1, 3*r2 + 1 : 4*r2, 3*r3 + 1 : 4*r3) = eta.G;
0173         Xbig.G = Gbig;
0174         Xbig.U1 = [eta.U1 X.U1 X.U1 X.U1];
0175         Xbig.U2 = [X.U2 eta.U2 X.U2 X.U2];
0176         Xbig.U3 = [X.U3 X.U3 eta.U3 X.U3];
0177         Xdot = tucker2multiarray(Xbig);
0178     end
0179
0180     % Full array G x1 V1 x2 V2 x3 V3 for a core G and factor matrices V1, V2, V3.
0181     function T = tucker_full(G, V1, V2, V3)
0182         Tk.G = G;
0183         Tk.U1 = V1;
0184         Tk.U2 = V2;
0185         Tk.U3 = V3;
0186         T = tucker2multiarray(Tk);
0187     end
0188
0189     % Mode-1, mode-2 and mode-3 unfolding matrices of a 3-way array T.
0190     function [F1, F2, F3] = unfold(T)
0191         [d1, d2, d3] = size(T);
0192         F1 = reshape(T, [d1, d2*d3]);
0193         F2 = reshape(permute(T, [2 1 3]), [d2, d1*d3]);
0194         F3 = reshape(permute(T, [3 1 2]), [d3, d1*d2]);
0195     end
0196
0197     % Frobenius norm ||X - A||_F of the error of a Tucker-format estimate X.
0198     function err = reconstruction_error(X)
0199         Xfull = tucker2multiarray(X);
0200         err = norm(Xfull(:) - A(:));
0201     end
0202
0203     % Check consistency of the gradient and the Hessian. Useful if you
0204     % adapt this example for a new cost function and you would like to make
0205     % sure there is no mistake.
0206     %
0207     % Notice that the checkhessian test fails: the slope is not right.
0208     % This is because the retraction is not second-order compatible with
0209     % the Riemannian exponential on this manifold, making
0210     % the checkhessian tool unusable. The Hessian is correct though.
0211     % % warning('off', 'manopt:fixedrankfactory_tucker_preconditioned:exp');
0212     % % checkgradient(problem);
0213     % % drawnow;
0214     % % pause;
0215     % % checkhessian(problem);
0216     % % drawnow;
0217     % % pause;
0218
0219     % options
0220     options.maxiter = 200;
0221     options.maxinner = 30;
0222     options.maxtime = inf;
0223     options.tolgradnorm = 1e-5;
0224
0225     % Minimize the cost function using Riemannian trust-regions
0226     Xtr = trustregions(problem, [], options);
0227
0228     % The reconstructed tensor is X, represented as a structure with fields
0229     % U1, U2, U3 and G.
0230     fprintf('||X-A||_F = %g\n', reconstruction_error(Xtr));
0231
0232     % Alternatively, we could decide to use a solver such as steepestdescent (SD)
0233     % or conjugategradient (CG). These solvers need to solve a
0234     % line-search problem at each iteration. Standard line searches in
0235     % Manopt have generic purpose systems to do this. But for the problem
0236     % at hand, we could exploit the least-squares structure to compute an
0237     % approximate stepsize for the line-search problem. The approximation
0238     % is obtained by linearizing the nonlinear manifold locally and further
0239     % approximating it with a degree 2 polynomial approximation.
0240     % The specific derivation is in the paper referenced above.
0241
0242     problem.linesearch = @linesearch_helper;
0243     function tmin = linesearch_helper(X, eta)
0244         % Linearizing X along eta, the cost at step t is approximately
0245         % 1/2 * || term0 + t*term1 ||^2, with term0 = P.*(X - A) and
0246         % term1 = P.*Xdot, both vectorized.
0247         term0 = residual(X);
0248         term0 = term0(:);
0249         term1 = P.*tangent2multiarray(X, eta);
0250         term1 = term1(:);
0251
0252         % tmin is the solution to the problem argmin a2*t^2 + a1*t, where
0253         % the coefficients a1 and a2 are shown below.
0254         a2 = (term1'*term1);
0255         a1 = 2*(term1'*term0);
0256         if a2 > 0
0257             tmin = - 0.5*(a1 / a2);
0258         else
0259             % eta vanishes on every observed entry: the model is flat and
0260             % a1/a2 would be 0/0 = NaN. Return a unit step instead and
0261             % leave it to the calling line search to accept or shrink it.
0262             tmin = 1;
0263         end
0264     end
0265
0266     % Notice that for this solver, the Hessian is not needed.
0267     [Xcg, costcg, infocg, optionscg] = conjugategradient(problem, [], options);
0268     fprintf('CG with the linesearch helper: f(X) = %g, ||X-A||_F = %g\n', ...
0269             costcg, reconstruction_error(Xcg));
0270
0271     fprintf('Take a look at the options that CG used:\n');
0272     disp(optionscg);
0273     fprintf('And see how many trials were made at each line search call:\n');
0274     info_ls = [infocg.linesearch];
0275     disp([info_ls.costevals]);
0276
0277     fprintf('Try it again without the linesearch helper.\n');
0278
0279     % Remove the linesearch helper from the problem structure.
0280     problem = rmfield(problem, 'linesearch');
0281
0282     % Use the same user options as above so that both CG runs are comparable.
0283     [Xcg, xcost, info, optionscg] = conjugategradient(problem, [], options);
0284     fprintf('CG without the linesearch helper: f(X) = %g, ||X-A||_F = %g\n', ...
0285             xcost, reconstruction_error(Xcg));
0286
0287     fprintf('Take a look at the options that CG used:\n');
0288     disp(optionscg);
0289     fprintf('And see how many trials were made at each line search call:\n');
0290     info_ls = [info.linesearch];
0291     disp([info_ls.costevals]);
0292 end

Generated on Sat 12-Nov-2016 14:11:22 by m2html © 2005