Home > manopt > tools > diagsum.m

diagsum

PURPOSE ^

C = DIAGSUM(A, d1, d2) Performs the trace

SYNOPSIS ^

function [tracedtensor] = diagsum(tensor1, d1, d2)

DESCRIPTION ^

 C = DIAGSUM(A, d1, d2) performs the trace
 C(i[1],...,i[d1-1],i[d1+1],...,i[d2-1],i[d2+1],...,i[n]) =
              A(i[1],...,i[d1-1],k,i[d1+1],...,i[d2-1],k,i[d2+1],...,i[n])
 (summed over k; the indices are written here for d1 < d2).

 C = DIAGSUM(A, d1, d2) traces A along the diagonal formed by dimensions d1
 and d2. If the lengths of these dimensions are not equal, DIAGSUM traces
 until the end of the shorter of dimensions d1 and d2 is reached. This is
 an analogue of the built-in TRACE function.

 Wynton Moore, January 2006

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [tracedtensor] = diagsum(tensor1, d1, d2)
% C = DIAGSUM(A, d1, d2) Performs the trace
% C(i[1],...,i[d1-1],i[d1+1],...,i[d2-1],i[d2+1],...,i[n]) =
%              A(i[1],...,i[d1-1],k,i[d1+1],...,i[d2-1],k,i[d2+1],...,i[n])
% (Sum on k; indices written for d1 < d2).
%
% C = DIAGSUM(A, d1, d2) traces A along the diagonal formed by dimensions d1
% and d2. If the lengths of these dimensions are not equal, DIAGSUM traces
% until the end of the shorter of dimensions d1 and d2 is reached. This is
% an analogue of the built-in TRACE function; unlike TRACE, a 2-D input
% need not be square: DIAGSUM(A, 1, 2) = sum of A(k,k), k = 1..min(size(A)).
%
% Only the diagonal entries are read, so Inf or NaN values stored off the
% diagonal do not affect C. Dimensions d1, d2 beyond ndims(A) are treated
% as singleton.
%
% Singleton dimensions of the result are removed with SQUEEZE, followed by
% the row/column correction at the end of this function.
%
% If d1 == d2, the result is squeeze(sum(A, d1)) instead of a trace.
%
% Wynton Moore, January 2006

dim1 = size(tensor1);
% SIZE drops trailing singleton dimensions. Pad with ones so that dim1(d1)
% and dim1(d2) are defined when d1 or d2 refers to such a dimension.
if max(d1, d2) > numel(dim1)
    dim1(end+1:max(d1, d2)) = 1;
end
numdims = numel(dim1);

if d1 == d2
    tracedtensor = squeeze(sum(tensor1, d1));
else
    % Bring the traced dimensions to the back, keeping the remaining ones
    % in their original order, then flatten so that each row holds the
    % m-by-n (d1, d2) slice for one combination of the other indices, in
    % column-major order.
    others = setdiff(1:numdims, [d1, d2]);
    m = dim1(d1);
    n = dim1(d2);
    flat = reshape(permute(tensor1, [others, d1, d2]), ...
                   prod(dim1(others)), m*n);

    % Column-major linear indices of entries (k, k), k = 1..min(m, n), of an
    % m-by-n matrix. Selecting these columns (instead of multiplying by a
    % 0/1 mask) keeps off-diagonal Inf/NaN out of the sum and avoids
    % allocating a mask as large as A. An empty diagonal sums to zero.
    diagidx = 1 + (0:min(m, n)-1) * (m + 1);
    traced = sum(flat(:, diagidx), 2);

    % Give the result the layout of sum(sum(A, d1), d2): the traced
    % dimensions become singleton and are then squeezed away.
    outsize = dim1;
    outsize([d1, d2]) = 1;
    tracedtensor = squeeze(reshape(traced, outsize));
end


% Correction for an aberration in the SQUEEZE function, which returns a
% column for inputs such as rand(1,1,2): size(squeeze(rand(1,1,2)))=[2 1].
% nontracedimensions lists the sizes of the dimensions that survive the
% trace, in their original order.
nontracedimensions = dim1;
nontracedimensions(d1) = [];
if d2 > d1
    nontracedimensions(d2-1) = [];
else
    % NOTE(review): when d1 == d2 this deletes a second entry, and it errors
    % if d1 is the last dimension (e.g. diagsum(A, 2, 2) on a matrix).
    nontracedimensions(d2) = [];
end
tracedsize = size(tracedtensor);
% Next line modified, Nicolas Boumal, April 30, 2012, such that
% diagsum(A, 1, 2) would compute the trace of A, a 2D matrix.
% A column result whose length does not match the first surviving dimension
% is turned into a row (a no-op for scalar results).
if length(tracedsize)==2 && tracedsize(2)==1 && ...
   (isempty(nontracedimensions) || tracedsize(1)~=nontracedimensions(1))

    tracedtensor = tracedtensor.';

end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005