// mex mexUpdateEta.cpp -largeArrayDims typedef ptrdiff_t intt; // Argument 0: V, whose dim is p x r // 1: Gsq, whose dim is p x ng (SPARSE) // 2: params //------------------------------------------------------------------------- void parse_params(const mxArray *params, double *epsilon, double *normparam); //------------------------------------------------------------------------- void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { double *V, *Eta, *inv_Zeta, *Gsq; intt *G_ir, *G_jc;; int p = mxGetM(prhs[0]), r = mxGetN(prhs[0]), ng = mxGetN(prhs[1]); V = (double*)mxGetPr(prhs[0]); Gsq = (double*)mxGetPr(prhs[1]); G_ir = (intt*)mxGetIr(prhs[1]); G_jc = (intt*)mxGetJc(prhs[1]); //----------------------------------------------------------------------- double epsilon, normparam; parse_params(prhs[2], &epsilon, &normparam); const double power1 = (2.0-normparam)/2.0, power2 = normparam/2.0, power3 = (normparam-1.0)/normparam; //----------------------------------------------------------------------- plhs[0] = mxCreateDoubleMatrix(p, r, mxREAL); inv_Zeta = (double*)mxGetPr( plhs[0] ); plhs[1] = mxCreateDoubleMatrix(ng, r, mxREAL); Eta = (double*)mxGetPr( plhs[1] ); for (intt i = 0; i < p*r; i++) inv_Zeta[i] = 0.0; for (intt i = 0; i < ng*r; i++) Eta[i] = 0.0; //----------------------------------------------------------------------- double *V_sq = new double[p*r]; double *Normalization = new double[r]; for (intt i = 0; i

#include "blas.h"
// mex mexUpdateU.cpp -l blas -largeArrayDims
// Signed, pointer-width integer used for matrix sizes/indices and passed by
// address to the Fortran-style BLAS routines below (dcopy_, dgemm_, dsyrk_,
// dgemv_, daxpy_). NOTE(review): with -largeArrayDims the MATLAB BLAS presumably
// expects 64-bit integers, which ptrdiff_t provides on 64-bit builds — confirm
// against the blas.h in use.
typedef ptrdiff_t intt;
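// Block-coordinate update of U, repeated max_it times over columns k=1:r:
//   U^k <- Proj_{\ell_2}( U^k + \|V^k\|_2^{-2} ( XV^k - UVtV^k ) )
// where \|V^k\|_2^2 = (V^T V)(k,k); columns with (V^T V)(k,k) <= 1e-12 are left
// unchanged. When the `pos` option is set, negative entries are zeroed first.
// Output 0: the updated U, whose dim is n x r
// Argument 0: V, whose dim is p x r
// 1: X, whose dim is n x p
// 2: U, whose dim is n x r (initial value; not modified in place)
// 3: params (parsed into max_it and pos)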
//-------------------------------------------------------------------------
// Reads the update options from params (prhs[3]) into:
//   *max_it - number of sweeps over the r columns of U in mexFunction
//   *pos    - if true, negative entries of each updated column are thresholded
// NOTE(review): definition not visible in this chunk; presumably it reads named
// fields of a MATLAB struct — confirm field names and defaults at the definition.
void parse_params(const mxArray *params, intt *max_it, bool *pos);
//-------------------------------------------------------------------------
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] )
{
double *U, *V, *X, *U_output;
double inv_VtV_kk, inv_norm2;
double one_db = 1.0, zero_db = 0.0, minus_one_db = -1.0;
intt one_intt = 1, zero_intt = 0;
intt p = mxGetM(prhs[0]), r = mxGetN(prhs[0]), n = mxGetM(prhs[1]);
intt nr = n*r;
V = (double*)mxGetPr(prhs[0]);
X = (double*)mxGetPr(prhs[1]);
U = (double*)mxGetPr(prhs[2]);
//-----------------------------------------------------------------------
intt max_it;
bool pos;
parse_params( prhs[3], &max_it, &pos );
//-----------------------------------------------------------------------
//U_output that is initialized as a copy of U
plhs[0] = mxCreateDoubleMatrix(n, r, mxREAL);
U_output = (double*)mxGetPr( plhs[0] );
dcopy_ ( &nr, U, &one_intt, U_output, &one_intt );
//-----------------------------------------------------------------------
double *XV = new double[n*r];
double *VtV = new double[r*r];
for (intt i; i < n*r; i++ )
XV[i] = 0.0;
for (intt i; i < r*r; i++ )
VtV[i] = 0.0;
// XV <- X*V
dgemm_ ("N","N",&n,&r,&p, &one_db,X,&n,V,&p,&zero_db, XV,&n);
//VtV <- V^T*V
dsyrk_ ( "L", "T", &r, &p, &one_db, V, &p, &zero_db, VtV, &r );
for (intt i = 0; i < r; i++)
for (intt j = i+1; j < r; j++)
VtV[j*r+i] = VtV[i*r+j];
//-----------------------------------------------------------------------
double *tmp = new double[n];
for (intt t = 1; t <= max_it; t ++ )
for (intt col = 0; col < r; col++)
{
if ( VtV[col + r*col] > 1e-12 )
{
inv_VtV_kk = 1.0 / VtV[col + r*col];
// Tmp <- X_V(:,k)
dcopy_ ( &n, XV+col*n, &one_intt, tmp, &one_intt );
// Tmp <- -U * VtV(:,k) + Tmp
dgemv_ ("N", &n, &r, &minus_one_db, U_output, &n, VtV+col*r, &one_intt, &one_db, tmp, &one_intt );
// U_output(:,k) <- Tmp/VtV(k,k) + U_output(:,k);
daxpy_ (&n, &inv_VtV_kk, tmp, &one_intt, U_output+col*n, &one_intt );
//---------------------------------------------------------
if ( pos )
{
inv_norm2 = 0.0;
// Threshold negative components and compute norm2
for (intt i=0; i