% L1-penalized logistic regression ranking of variables
%
% syntax: [ bestVariables, bestToWorst ] = sortVariablesLR( featureVect, classLabels, ...
%             topVarsToKeep, topVarsToSearch )
%
% Inputs:
%   featureVect:     all data samples, (dim x numSamples)
%   classLabels:     class labels (0 = not learned, 1 = learned, 2 = unsure).
%                    Samples labeled class 2 are not used.
%   topVarsToKeep:   number of best variables to return
%   topVarsToSearch: to find topVarsToKeep, we look at the most frequently
%                    top-ranked variables within the first topVarsToSearch
%                    variables across all folds of leave-one-subject-out
%                    cross validation
%
% Outputs:
%   bestVariables: indices of the best variables for separating the classes
%   bestToWorst:   ordering of all variables, best to worst, for all CV folds
%

function [ bestVariables, bestToWorst ] = sortVariablesLR( featureVect, classLabels, ...
    topVarsToKeep, topVarsToSearch )

if nargin < 3 || isempty( topVarsToKeep )
    topVarsToKeep = 10;
end
if nargin < 4 || isempty( topVarsToSearch )
    topVarsToSearch = topVarsToKeep;
end


% leave-one-subject-out cross validation
[ dim, numSamples ] = size( featureVect );
% expLabels = getLeave1OutLabels( numSamples, numSamplesPerSubj );
numTrials = 1; % length(expLabels);
bestToWorst = zeros( topVarsToSearch, numTrials );

% center variables and scale them to unit variance
featureVect = featureVect - repmat( mean( featureVect, 2 ), [1, numSamples] );
featStdev = std( featureVect, 0, 2 );
featureVect( featStdev ~= 0, : ) = featureVect( featStdev ~= 0, : ) ...
    ./ repmat( featStdev( featStdev ~= 0 ), [1, numSamples] );

% initialize for logistic regression
featureVect = [ ones(numSamples,1) featureVect' ]'; % add bias element to features (at top)
classLabels( classLabels == 0 ) = -1;               % convert y to {-1,1} representation
baseLambda = ones( dim+1, 1 ); % [ 1./(std( featureVect, 0, 2)+.001)]; % 15
options = struct( 'verbose', 0 );


for i1 = 1 %:numTrials

    % reinitialize the weights and the lambda scalar every trial
    w = zeros( dim+1, 1 ); % make sure it goes into the while loop
    lambdaScalar = 150;


    trainLabels = classLabels;   %(:,expLabels(i1).train);
    trainFeatures = featureVect; %(:,expLabels(i1).train);
    trainFeatures( :, trainLabels == 2 ) = []; % drop "unsure" samples
    trainLabels( trainLabels == 2 ) = [];
    funObj = @(w) LogisticLoss( w, trainFeatures', trainLabels' ); % LR objective

    % relax the L1 penalty until enough variables become active
    k1 = 0;
    while nnz( w ) < topVarsToSearch && k1 < 500
        k1 = k1 + 1;

        lambda = lambdaScalar * baseLambda;
        lambda(1) = 0; % do not penalize the bias variable
        w = L1GeneralProjection( funObj, w, lambda, options );
        lambdaScalar = lambdaScalar / 1.1;
    end

    % sort by importance (largest |w| first) and put in the matrix
    w(1) = []; % drop the bias weight
    [ ~, bestToWorst ] = sort( abs( w(:) ), 'descend' );
    % bestToWorst( :, i1 ) = wIdx( 1:topVarsToSearch );

end

% remove redundancies
featureVect(1,:) = []; % remove the bias row

unqIdx = findRedundancies( featureVect( bestToWorst, : ) );
bestToWorst = bestToWorst( unqIdx );


bestVariables = bestToWorst( 1:min( topVarsToKeep, length(bestToWorst) ) );
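% The core idea above — relax an L1 penalty until at least topVarsToSearch
% weights become nonzero, then rank variables by |w| — can be sketched outside
% MATLAB as well. The following is a minimal Python illustration using
% scikit-learn's L1-penalized LogisticRegression in place of the
% LogisticLoss/L1GeneralProjection solver used here; the synthetic data,
% top_k, and the C sweep schedule are made-up stand-ins, not part of this code.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 20))
# labels depend only on features 0 and 3 (ground truth for the demo)
y = (X[:, 0] + 2.0 * X[:, 3] > 0).astype(int)

# center and scale to unit variance, as sortVariablesLR does
X = (X - X.mean(axis=0)) / X.std(axis=0)

# sweep the penalty from strong to weak until enough weights are
# nonzero, mirroring the lambdaScalar/1.1 loop (C is the *inverse*
# regularization strength, so we multiply instead of divide)
top_k = 5
C = 1e-3
w = np.zeros(X.shape[1])
for _ in range(500):
    clf = LogisticRegression(penalty="l1", solver="liblinear", C=C)
    clf.fit(X, y)
    w = clf.coef_.ravel()
    if np.count_nonzero(w) >= top_k:
        break
    C *= 1.1

# rank variables by absolute weight, best first
best_to_worst = np.argsort(-np.abs(w))
print(best_to_worst[:top_k])
```

% The informative features 0 and 3 surface at the top of the ranking; a
% warm start across the penalty sweep (as the MATLAB loop does by passing
% the previous w back into L1GeneralProjection) would make this faster.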