% Precision-recall evaluation of a pooled-covariance Gaussian classifier
% ('NB': independent per-frame inference; 'SHT': sequential hypothesis
% testing that accumulates evidence within a session).
% Requires the Octave statistics package (tabulate, zscore, mvnpdf).
pkg load statistics;

% command-line arguments: data CSV, solver, scenario, classes to keep
mostFreqClasses = str2double(argv(){4});
numEx = 10;                  % number of experiments (folds)
% scenario = 'off';
% solver = 'SHT';
scenario = argv(){3};        % 'off' (cross-validation) or 'on' (online)
solver = argv(){2};          % 'NB' or 'SHT'

% load data
% trainD = dlmread('train_limbs.csv', ',', 1, 0);
% testD = dlmread('test_limbs.csv', ',', 1, 0);
% D = [trainD; testD];
D = dlmread(argv(){1}, ',', 1, 0);
X = D(:, 7 : end);           % features start at column 7
y = D(:, 2);                 % class label
session = D(:, 2);           % session id (read from the same column as y)
z = D(:, 5);

% remove missing values and outliers
active = all(X ~= -1, 2);
active = active & (z > 2) & (z < 3);

% keep only the mostFreqClasses most frequent classes
tab = tabulate(y);
[~, delc] = sort(tab(:, 2), 'descend');
delc = delc(mostFreqClasses + 1 : end);
for c = delc'
  active(y == c) = false;
end

% update data (labels are not remapped; empty classes are skipped below)
X = X(active, :);
y = y(active);
session = session(active);
z = z(active);
numClasses = max(y);
N = size(X, 1);
K = size(X, 2);

% normalization
X = zscore(X);

% training and testing splits, cut only at session boundaries
splits = [1; find(session(1 : end - 1) ~= session(2 : end)) + 1; N + 1];
ns = length(splits) - 1;
trains = cell(1, numEx);
tests = cell(1, numEx);
for ex = 1 : numEx
  switch (scenario)
    case 'off'
      % offline learning (cross-validation): test on one block of sessions
      test = splits(round(ns * (ex - 1) / numEx) + 1) : ...
             splits(round(ns * ex / numEx) + 1) - 1;
      train = setdiff(1 : N, test);
    case 'on'
      % online learning: train on the past, test on the next block
      train = 1 : splits(round(ns * ex / (numEx + 1)) + 1) - 1;
      test = splits(round(ns * ex / (numEx + 1)) + 1) : ...
             splits(round(ns * (ex + 1) / (numEx + 1)) + 1) - 1;
  end
  trains{ex} = train;
  tests{ex} = test;
end

prec = zeros(numEx, 1000);   % one column per confidence threshold
recall = zeros(numEx, 1000);
for ex = 1 : numEx
  train = trains{ex};
  test = tests{ex};

  % NB classifier with multivariate Gaussians sharing a pooled covariance
  Py = zeros(numClasses, 1);   % class priors
  Pxy = zeros(numClasses, K);  % class means
  Sigma = zeros(K, K);         % prior-weighted pooled covariance
  for c = 1 : numClasses
    sub = train(y(train) == c);
    if (~isempty(sub))
      Py(c) = length(sub) / length(train);
      Pxy(c, :) = mean(X(sub, :), 1);
      Sigma = Sigma + Py(c) * cov(X(sub, :));
    end
  end

  switch (solver)
    case 'NB'
      % NB inference: log posterior = log prior + Gaussian log likelihood
      logp = repmat(log(Py)', N, 1);
      for c = 1 : numClasses
        if (Py(c) > 0)
          logp(:, c) = log(Py(c)) + log(mvnpdf(X, Pxy(c, :), Sigma));
        end
      end
    case 'SHT'
      % sequential hypothesis testing: accumulate log likelihoods over
      % consecutive frames of the same session
      logp = zeros(N, numClasses);
      for c = 1 : numClasses
        if (Py(c) > 0)
          logp(:, c) = log(mvnpdf(X, Pxy(c, :), Sigma));
        end
      end
      nhyp = zeros(N, 1);
      for i = 1 : N
        if ((i == 1) || (session(i - 1) ~= session(i)))
          % session start: initialize with the prior
          logp(i, :) = logp(i, :) + log(Py');
          nhyp(i) = 2;
        else
          logp(i, :) = logp(i, :) + logp(i - 1, :);
          nhyp(i) = nhyp(i - 1) + 1;
        end
      end
  end

  % prediction
  [conf, yp] = max(logp, [], 2);
  % rejection score: log of the posterior mass of all but the best class,
  % relative to the best class (lower = more confident)
  norm1 = logp - repmat(conf, 1, numClasses);
  norm1((1 : N) + (yp' - 1) * N) = -Inf;
  norm1 = log(sum(exp(norm1), 2));

  % evaluation: sweep the rejection threshold from 4.9 down to -95
  for i = 1 : 1000
    th = 5 - i / 10;
    sub = test(norm1(test) < th);
    prec(ex, i) = mean(y(sub) == yp(sub));
    recall(ex, i) = length(sub) / length(test);
  end
end

% empty selections yield NaN precision; define it as 1
prec(isnan(prec)) = 1;
hold on;
plot(100 * mean(recall), 100 * mean(prec), '-', ...
     'LineWidth', 1, 'MarkerSize', 4, 'MarkerFaceColor', 'w');
xlabel('Recall [%]');
ylabel('Precision [%]');
hold off;
pause                        % keep the figure on screen

pr = [mean(recall)', mean(prec)'];
save pr.mat pr;

% residual analysis (diagnostic plots)
% A = X - Pxy(y, :);
% for k = 1 : 9
%   subplot(3, 3, k);
%   hist(A(:, k), -5 : 0.1 : 5);
%   h = findobj(gca, 'Type', 'patch');
%   set(h, 'FaceColor', [0.5, 1, 0.5], 'LineStyle', 'none');
%   axis([-5, 5, 0, Inf]);
%   title(sprintf('X_%i', k));
% end
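
% ----------------------------------------------------------------------
% Example invocation (a sketch: the script name pr.m is assumed, and
% train_limbs.csv is taken from the commented-out loading code above;
% the argument order follows the argv() reads at the top of the file:
% data CSV, solver, scenario, number of most frequent classes to keep):
%
%   octave pr.m train_limbs.csv SHT off 20
%
% This trains on session-based splits, sweeps the rejection threshold,
% plots the averaged precision-recall curve, and saves it to pr.mat.
% ----------------------------------------------------------------------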