From 0482865b3fc128964584db1af66be9fb0783a4af Mon Sep 17 00:00:00 2001 From: jeanpouget-abadie Date: Sun, 1 Feb 2015 19:21:59 -0500 Subject: small changes to optimization functions --- src/convex_optimization.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) (limited to 'src/convex_optimization.py') diff --git a/src/convex_optimization.py b/src/convex_optimization.py index 02787d3..1d84db2 100644 --- a/src/convex_optimization.py +++ b/src/convex_optimization.py @@ -6,9 +6,44 @@ import timeout import cvxopt - @timeout.timeout(20) def sparse_recovery(M_val, w_val, lbda): + """ + Solves: + min lbda * |theta|_1 - sum b^i log(exp(-) -1) - M*theta + s.t theta_j <= 0 + """ + assert len(M_val) == len(w_val) + + if M_val.dtype == bool: + M_val = M_val.astype('float32') + + if type(lbda) == int: + lbda = np.array(lbda) + + theta = tensor.row().T + theta_ = theta.flatten() + + M = theano.shared(M_val.astype(theano.config.floatX)) + w = theano.shared(w_val.astype(theano.config.floatX)) + lbda = theano.shared(lbda.astype(theano.config.floatX)) + + y_tmp = lbda * (theta_).norm(1) - tensor.sum(tensor.dot(M, theta_ + )) + w.dot(tensor.log(tensor.exp(-M.dot(theta_))-1)) + y_diff = tensor.grad(y_tmp, theta_) + y_hess = theano.gradient.hessian(y_tmp, theta_) + f = function([theta_], y_hess) + + print(f(-1*np.random.rand(M_val.shape[1]).astype("float32"))) + + y = lbda * (theta_).norm(1) - tensor.sum(tensor.dot(M, theta_ + )) - w.dot(tensor.log(tensor.exp(-M.dot(theta_))-1)) + + return diff_and_opt(theta, theta_, M, M_val, w, lbda, y) + + +@timeout.timeout(20) +def type_lasso(M_val, w_val, lbda): """ Solves: min - sum_j theta_j + lbda*|e^{M*theta} - (1 - w)|_2 @@ -63,7 +98,7 @@ def diff_and_opt(theta, theta_, M, M_val, w, lbda, y): G = cvxopt.spdiag([1 for i in range(n)]) h = cvxopt.matrix(0.0, (n,1)) - cvxopt.solvers.options['show_progress'] = False + cvxopt.solvers.options['show_progress'] = True try: theta = cvxopt.solvers.cp(F, G, h)['x'] except ArithmeticError: -- cgit v1.2.3-70-g09d2