0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
function [X,cost,test,stats] = completion_als( A_Omega, Omega, A_Gamma, Gamma, X, opts )
%COMPLETION_ALS  Tensor completion by alternating least-squares sweeps.
%
%   [X,cost,test,stats] = COMPLETION_ALS(A_Omega, Omega, A_Gamma, Gamma, X, opts)
%   alternates forward half-sweeps (cores 1..d-1) and backward half-sweeps
%   (cores d..2). Each visited core X.U{mu} is replaced by the least-squares
%   fit to the sampled values A_Omega (see SOLVE_LEAST_SQUARES).
%
%   Inputs
%     A_Omega        - sampled values; A_Omega(k) belongs to row k of Omega.
%     Omega          - sample index tuples, one row per sample; column mu holds
%                      the mode-mu index.
%     A_Gamma, Gamma - held-out test samples, same layout as A_Omega/Omega.
%     X              - initial iterate. Must provide size, rank, order, U and
%                      sampling X(Omega) (presumably a TTeMPS tensor -- confirm).
%     opts           - struct with optional fields:
%                        maxiter (default 100)  maximum number of full sweeps
%                        tol     (default 1e-6) stop once the relative residual
%                                               on Omega drops below tol
%                        reltol  (default 1e-6) abort once the relative change
%                                               of that residual between two
%                                               consecutive half-sweeps drops
%                                               below reltol
%
%   Outputs
%     X          - final iterate.
%     cost       - norm(A_Omega - X(Omega)) / norm(A_Omega) after each half-sweep.
%     test       - the same quantity on (A_Gamma, Gamma) after each half-sweep.
%     stats.time - cumulative wall-clock time after each half-sweep. Time spent
%                  evaluating the test error is not counted.
%     stats.conv - true iff the tol criterion was met.

if ~isfield( opts, 'maxiter'); opts.maxiter = 100; end
if ~isfield( opts, 'tol'); opts.tol = 1e-6; end
if ~isfield( opts, 'reltol'); opts.reltol = 1e-6; end

d = X.order;

% One entry per half-sweep (two half-sweeps per full sweep).
cost = zeros(2*opts.maxiter,1);
test = zeros(2*opts.maxiter,1);

norm_A_Omega = norm( A_Omega );
norm_A_Gamma = norm( A_Gamma );

% Presumably puts the orthogonality center on core 1, where the first
% forward half-sweep starts -- confirm against orthogonalize.
X = orthogonalize( X, 1 );

t = tic;
stats.time = 0;      % leading sentinel, removed before returning
stats.conv = false;

for k = 1:2*opts.maxiter
    % Odd half-sweeps run forward, even half-sweeps run backward.
    forward = mod(k, 2) == 1;

    fprintf(1,'Currently optimizing core: ')
    if forward
        for mu = 1:d-1
            fprintf(1,'%i ', mu)
            X.U{mu} = solve_least_squares( A_Omega, Omega, X, mu );
            X = orth_at( X, mu, 'left' );
        end
    else
        for mu = d:-1:2
            fprintf(1,'%i ', mu)
            X.U{mu} = solve_least_squares( A_Omega, Omega, X, mu );
            X = orth_at( X, mu, 'right' );
        end
    end

    % sqrt(2*func) == norm of the residual on Omega.
    cost(k) = sqrt(2*func(A_Omega, X, Omega)) / norm_A_Omega;

    converged = cost(k) < opts.tol;

    % Stagnation is only checked from the second full sweep on (k > 2).
    stalled = false;
    if ~converged && k > 2
        rel_change = abs(cost(k) - cost(k-1)) / cost(k);
        stalled = rel_change < opts.reltol;
    end

    if converged
        fprintf(1, 'CONVERGED AFTER %i HALF-SWEEPS. Rel. residual smaller than %0.3g\n', ...
            k, opts.tol);
    elseif stalled
        fprintf(1, 'No more progress in gradient change, but not converged after %i half-sweeps. ABORTING!. \nRelative change is smaller than %0.3g\n', ...
            k, opts.reltol);
    end

    % Record the elapsed time before the test error is evaluated, so that
    % evaluation is excluded from stats.time.
    stats.time(end+1) = stats.time(end) + toc(t);
    test(k) = sqrt(2*func(A_Gamma, X, Gamma)) / norm_A_Gamma;

    if converged || stalled
        stats.conv = converged;
        cost = cost(1:k);
        test = test(1:k);
        break
    end

    t = tic;

    if forward
        fprintf(1,'\nFinished forward sweep.\n Cost: %e\n Test: %e\n', cost(k), test(k) );
    else
        fprintf(1,'\nFinished backward sweep.\n Cost: %e\n Test: %e\n', cost(k), test(k) );
        disp('_______________________________________________________________')
    end
end

stats.time = stats.time(2:end);

end
0131
0132
function res = func(A_Omega, X, Omega)
%FUNC  Half the squared 2-norm of the misfit between A_Omega and X sampled at Omega.
%   res = FUNC(A_Omega, X, Omega) returns 0.5*norm(A_Omega - X(Omega))^2.
misfit = A_Omega - X(Omega);
res = 0.5 * norm(misfit)^2;
end
0136
0137
function res = solve_least_squares( A_Omega, Omega, X, mu )
%SOLVE_LEAST_SQUARES  Least-squares update of core mu with all other cores fixed.
%
%   res = SOLVE_LEAST_SQUARES(A_Omega, Omega, X, mu) returns an array with the
%   size of X.U{mu}. For each mode-mu index i, the slice res(:,i,:) is the
%   least-squares solution fitted to the samples whose mode-mu index equals i.
%   Raises an error if some index i in 1..X.size(mu) has no samples.
%
%   Assumes X.U{mu} is r(mu) x n(mu) x r(mu+1) and that Omega(:,mu) holds
%   integer indices in 1..n(mu) -- confirm against the caller.

n = X.size;
d = X.order;
r = X.rank;

% Sort the samples by their mode-mu index, so that all samples of one slice
% occupy a single contiguous block of rows.
[jmu, idx] = sort( Omega(:,mu), 'ascend' );
Omega = Omega(idx,:);
A_Omega = A_Omega(idx);

% Cores with dimensions 2 and 3 swapped: r(i) x r(i+1) x n(i) under the
% shape assumption above.
C = cell(1,d);
for i = 1:d
    C{i} = permute( X.U{i}, [1 3 2] );
end
res = zeros( size(C{mu}) );

% Presumably row k of B holds the coefficients mapping core mu to the k-th
% (sorted) sample -- confirm against als_solve_mex.
B = als_solve_mex( n, r, C, Omega', mu )';

% Samples per slice, and the first/last row of each slice in the sorted data.
counts = accumarray( jmu(:), 1, [n(mu) 1] );
last = cumsum( counts );
first = last - counts + 1;

for i = 1:n(mu)
    if counts(i) == 0
        error('No samples for this slice!')
    end
    rows = first(i):last(i);
    res(:,:,i) = reshape( B(rows,:) \ A_Omega(rows), r(mu), r(mu+1) );
end

% Restore the original dimension order of X.U{mu}.
res = permute( res, [1 3 2] );

end