diff options
Diffstat (limited to 'data/combined/scriptSkeleton.m')
| -rw-r--r--[-rwxr-xr-x] | data/combined/scriptSkeleton.m | 76 |
1 files changed, 50 insertions, 26 deletions
diff --git a/data/combined/scriptSkeleton.m b/data/combined/scriptSkeleton.m
index 9752c48..378f158 100755..100644
--- a/data/combined/scriptSkeleton.m
+++ b/data/combined/scriptSkeleton.m
@@ -1,12 +1,17 @@
+mostFreqClasses = str2double(argv(){4});
+numEx = 10;
+% scenario = 'off';
+% solver = 'SHT';
+scenario = argv(){3};
+solver = argv(){2};
+
+% load data
% trainD = dlmread('train_limbs.csv', ',', 1, 0);
% testD = dlmread('test_limbs.csv', ',', 1, 0);
-% trainD = dlmread(argv(){1}, ',', 1, 0);
-% testD = dlmread(argv(){2}, ',', 1, 0);
% D = [trainD; testD];
-% D = dlmread(argv(){1}, ',', 1, 0);
-D = dlmread('mean.csv', ',', 1, 0);
-X = D(:, 6 : end);
+D = dlmread(argv(){1}, ',', 1, 0);
+X = D(:, 7 : end);
y = D(:, 2);
session = D(:, 2);
z = D(:, 5);
@@ -15,7 +20,7 @@ z = D(:, 5);
active = all(X ~= -1, 2);
active = active & (z > 2) & (z < 3);
-mostFreqClasses = 100;
+% remove least frequent classes
tab = tabulate(y);
[nil, delc] = sort(tab(:, 2), 'descend');
delc = delc(mostFreqClasses + 1 : end);
@@ -23,6 +28,7 @@ for c = delc'
active(y == c) = false;
end
+% update data
X = X(active, :);
y = y(active);
session = session(active);
@@ -34,19 +40,35 @@ K = size(X, 2);
% normalization
X = zscore(X);
-numEx = 10;
+% training and testing splits
splits = [1; find(session(1 : end - 1) ~= session(2 : end)) + 1; N + 1];
ns = length(splits) - 1;
-prec = [];
-recall = [];
+trains = cell(1, numEx);
+tests = cell(1, numEx);
for ex = 1 : numEx
- test = splits(round(ns * (ex - 1) / numEx) + 1) : ...
- splits(round(ns * ex / numEx) + 1) - 1;
- train = setdiff(1 : N, test);
-% train = 1 : splits(round(ns * ex / numEx) + 1) - 1;
-% test = splits(round(ns * ex / numEx) + 1) : ...
-% splits(round(ns * (ex + 1) / numEx) + 1) - 1;
+ switch (scenario)
+ case 'off'
+ % offline learning (cross-validation)
+ test = splits(round(ns * (ex - 1) / numEx) + 1) : ...
+ splits(round(ns * ex / numEx) + 1) - 1;
+ train = setdiff(1 : N, test);
+
+ case 'on'
+ % online learning
+ train = 1 : splits(round(ns * ex / (numEx + 1)) + 1) - 1;
+ test = splits(round(ns * ex / (numEx + 1)) + 1) : ...
+ splits(round(ns * (ex + 1) / (numEx + 1)) + 1) - 1;
+ end
+ trains{ex} = train;
+ tests{ex} = test;
+end
+
+prec = zeros(numEx, 0);
+recall = zeros(numEx, 0);
+for ex = 1 : numEx
+ train = trains{ex};
+ test = tests{ex};
% NB classifier with multivariate Gaussians
Py = zeros(numClasses, 1);
@@ -60,9 +82,7 @@ for ex = 1 : numEx
Sigma = Sigma + Py(c) * cov(X(sub, :));
end
end
- % Sigma = diag(diag(Sigma));
-
- solver = 'SHT';
+
switch (solver)
case 'NB'
% NB inference
@@ -92,17 +112,20 @@ for ex = 1 : numEx
nhyp(i) = nhyp(i - 1) + 1;
end
end
- logp = logp ./ repmat(nhyp, 1, numClasses);
end
% prediction
[conf, yp] = max(logp, [], 2);
- conf = exp(conf) ./ sum(exp(logp), 2);
+
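+ % log of the summed probability of all classes except the predicted one,
+ % taken relative to the predicted class (smaller = more clear-cut prediction)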
+ norm1 = logp - repmat(conf, 1, numClasses);
+ norm1((1 : N) + (yp' - 1) * N) = -Inf;
+ norm1 = log(sum(exp(norm1), 2));
% evaluation
- for i = 1 : 50
- th = 1 - 1 / (1.33 ^ i);
- sub = test(conf(test) > th);
+ for i = 1 : 1000
+ th = 5 - i/10;
+ sub = test(norm1(test) < th);
prec(ex, i) = mean(y(sub) == yp(sub));
recall(ex, i) = length(sub) / length(test);
end
@@ -110,13 +133,14 @@ end
prec(isnan(prec)) = 1;
hold on;
-plot(100 * mean(recall), 100 * mean(prec), '-ro', ...
- 'LineWidth', 1, 'MarkerSize', 4, 'MarkerFaceColor', 'w')
+plot(100 * mean(recall), 100 * mean(prec), '-', ...
+ 'LineWidth', 1, 'MarkerSize', 4, 'MarkerFaceColor', 'w');
xlabel('Recall [%]');
ylabel('Precision [%]');
hold off;
pause
-
+pr = [mean(recall)',mean(prec)'];
+save pr.mat pr;
% A = X - Pxy(y, :);
% for k = 1 : 9
% subplot(3, 3, k);
|
