From cbdd6e52d0bf02e308a2b885637f9a8b4fb1a5e0 Mon Sep 17 00:00:00 2001
From: Jon Whiteaker
Date: Thu, 1 Mar 2012 22:03:08 -0800
Subject: scriptSkeleton: parameterize via argv and rework evaluation

Read the data file, solver, scenario, and number of classes to keep
from the command line instead of hard-coding them; precompute the
train/test splits, supporting both offline (cross-validation) and
online scenarios; score predictions by the summed odds of the
competing classes instead of the normalized maximum posterior; and
save the precision/recall curve to pr.mat.
---
 data/combined/scriptSkeleton.m | 76 +++++++++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 26 deletions(-)
 mode change 100755 => 100644 data/combined/scriptSkeleton.m

diff --git a/data/combined/scriptSkeleton.m b/data/combined/scriptSkeleton.m
old mode 100755
new mode 100644
index 9752c48..378f158
--- a/data/combined/scriptSkeleton.m
+++ b/data/combined/scriptSkeleton.m
@@ -1,12 +1,17 @@
+mostFreqClasses = str2double(argv(){4});
+numEx = 10;
+% scenario = 'off';
+% solver = 'SHT';
+scenario = argv(){3};
+solver = argv(){2};
+
+% load data
 % trainD = dlmread('train_limbs.csv', ',', 1, 0);
 % testD = dlmread('test_limbs.csv', ',', 1, 0);
-% trainD = dlmread(argv(){1}, ',', 1, 0);
-% testD = dlmread(argv(){2}, ',', 1, 0);
 % D = [trainD; testD];
-% D = dlmread(argv(){1}, ',', 1, 0);
-D = dlmread('mean.csv', ',', 1, 0);
-X = D(:, 6 : end);
+D = dlmread(argv(){1}, ',', 1, 0);
+X = D(:, 7 : end);
 y = D(:, 2);
 session = D(:, 2);
 z = D(:, 5);
@@ -15,7 +20,7 @@ z = D(:, 5);
 active = all(X ~= -1, 2);
 active = active & (z > 2) & (z < 3);
 
-mostFreqClasses = 100;
+% remove least frequent classes
 tab = tabulate(y);
 [nil, delc] = sort(tab(:, 2), 'descend');
 delc = delc(mostFreqClasses + 1 : end);
@@ -23,6 +28,7 @@ for c = delc'
   active(y == c) = false;
 end
 
+% update data
 X = X(active, :);
 y = y(active);
 session = session(active);
@@ -34,19 +40,35 @@ K = size(X, 2);
 % normalization
 X = zscore(X);
 
-numEx = 10;
+% training and testing splits
 splits = [1; find(session(1 : end - 1) ~= session(2 : end)) + 1; N + 1];
 ns = length(splits) - 1;
 
-prec = [];
-recall = [];
+trains = cell(1, numEx);
+tests = cell(1, numEx);
 for ex = 1 : numEx
-  test = splits(round(ns * (ex - 1) / numEx) + 1) : ...
-         splits(round(ns * ex / numEx) + 1) - 1;
-  train = setdiff(1 : N, test);
-% train = 1 : splits(round(ns * ex / numEx) + 1) - 1;
-% test = splits(round(ns * ex / numEx) + 1) : ...
-%        splits(round(ns * (ex + 1) / numEx) + 1) - 1;
+  switch (scenario)
+    case 'off'
+      % offline learning (cross-validation)
+      test = splits(round(ns * (ex - 1) / numEx) + 1) : ...
+             splits(round(ns * ex / numEx) + 1) - 1;
+      train = setdiff(1 : N, test);
+
+    case 'on'
+      % online learning
+      train = 1 : splits(round(ns * ex / (numEx + 1)) + 1) - 1;
+      test = splits(round(ns * ex / (numEx + 1)) + 1) : ...
+             splits(round(ns * (ex + 1) / (numEx + 1)) + 1) - 1;
+  end
+  trains{ex} = train;
+  tests{ex} = test;
+end
+
+prec = zeros(numEx, 0);
+recall = zeros(numEx, 0);
+for ex = 1 : numEx
+  train = trains{ex};
+  test = tests{ex};
 
   % NB classifier with multivariate Gaussians
   Py = zeros(numClasses, 1);
@@ -60,9 +82,7 @@ for ex = 1 : numEx
       Sigma = Sigma + Py(c) * cov(X(sub, :));
     end
   end
-  % Sigma = diag(diag(Sigma));
-
-  solver = 'SHT';
+
   switch (solver)
     case 'NB'
       % NB inference
@@ -92,17 +112,20 @@ for ex = 1 : numEx
         nhyp(i) = nhyp(i - 1) + 1;
       end
     end
-    logp = logp ./ repmat(nhyp, 1, numClasses);
   end
 
   % prediction
   [conf, yp] = max(logp, [], 2);
-  conf = exp(conf) ./ sum(exp(logp), 2);
+
+  % sum up all but the highest probability
+  norm1 = logp - repmat(conf, 1, numClasses);
+  norm1((1 : N) + (yp' - 1) * N) = -Inf;
+  norm1 = log(sum(exp(norm1), 2));
 
   % evaluation
-  for i = 1 : 50
-    th = 1 - 1 / (1.33 ^ i);
-    sub = test(conf(test) > th);
+  for i = 1 : 1000
+    th = 5 - i / 10;
+    sub = test(norm1(test) < th);
     prec(ex, i) = mean(y(sub) == yp(sub));
     recall(ex, i) = length(sub) / length(test);
   end
@@ -110,13 +133,14 @@ end
 prec(isnan(prec)) = 1;
 
 hold on;
-plot(100 * mean(recall), 100 * mean(prec), '-ro', ...
-     'LineWidth', 1, 'MarkerSize', 4, 'MarkerFaceColor', 'w')
+plot(100 * mean(recall), 100 * mean(prec), '-', ...
+     'LineWidth', 1, 'MarkerSize', 4, 'MarkerFaceColor', 'w');
 xlabel('Recall [%]');
 ylabel('Precision [%]');
 hold off;
 pause
-
+pr = [mean(recall)', mean(prec)'];
+save pr.mat pr;
 % A = X - Pxy(y, :);
 % for k = 1 : 9
 %   subplot(3, 3, k);
-- 
cgit v1.2.3-70-g09d2
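
For context on the "sum up all but the highest probability" block above: the
script now scores each prediction by the total odds of every competing class
against the predicted one, and accepts it when the log of that sum falls below
a threshold. A minimal standalone Octave sketch of the computation (the toy
logp values and the -0.5 threshold are invented for illustration; variable
names mirror the script):

  % two rows of normalized log-posteriors over three classes
  logp = log([0.70 0.20 0.10;    % confident: competitors hold 0.3
              0.40 0.35 0.25]); % uncertain: competitors hold 0.6
  N = size(logp, 1);
  numClasses = size(logp, 2);

  % winning log-posterior and predicted class per row
  [conf, yp] = max(logp, [], 2);

  % log-odds of each class against the winner; drop the winner itself
  norm1 = logp - repmat(conf, 1, numClasses);
  norm1((1 : N)' + (yp - 1) * N) = -Inf;

  % log of the summed competitor odds: very negative means confident
  norm1 = log(sum(exp(norm1), 2));

  % accept predictions whose competitor mass is small enough
  disp([yp, norm1, norm1 < -0.5]);   % -> 1 -0.85 1 ; 1 0.41 0

Assuming Octave's usual argv() semantics for script arguments, an invocation
that reproduces the previously hard-coded configuration (mean.csv, SHT solver,
offline scenario, 100 most frequent classes) would be

  octave scriptSkeleton.m mean.csv SHT off 100

though the exact calling convention is not part of the patch.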