aboutsummaryrefslogtreecommitdiffstats
path: root/python/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/optimization.py')
-rw-r--r--python/optimization.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/python/optimization.py b/python/optimization.py
new file mode 100644
index 00000000..65ba02c4
--- /dev/null
+++ b/python/optimization.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+def decr_fun(a, b):
+ return np.sqrt(np.square(a).sum() + np.square(b).sum())
+
+def KLfit(P, q, b = np.ones(P.shape[0])):
+ """
+ solves for x, min KL(x, q)
+ s.t. Px = b
+ x>=0
+ sum(x) = 1
+ """
+ alpha = 0.4
+ beta = 0.8
+ P = np.vstack(P, np.ones(P.shape[1]))
+ #init x and nu
+ x = np.ones(P.shape[1])/P.shape[1]
+ nu = np.ones(P.shape[0])
+ decr = np.Inf
+ eps = 1e-12
+ niter = 1
+ while decr > eps and niter < 500:
+ rdual = 1 + np.log(x) - np.log(q) + P.T.dot(nu)
+ rprimal = P.dot(x) - b
+ S = -P.dot((x*P).T)
+ Dnu_pd = la.solve(S, (P*x).dot(rdual) - rprimal)
+ Dx_pd = -x[:,np.newaxis]*(P.T.dot(Dnu_pd)+rdual)
+ #backtracking search
+ phi = 1
+ newx = x + phi * Dx_pd
+ newnu = nu + phi * Dnu_pd
+
+ while newx[newx>0].sum() > 0:
+ phi *= beta
+ newx = x + phi * Dx_pd
+ newnu = nu + phi * Dnu_pd
+
+ newrdual = 1 + np.log(newx) - np.log(q) + P.T.dot(newnu)
+ newrprimal = P.dot(newx) - b
+
+ while decr_fun(newrdual, newrprimal) > (1 - alpha * phi) * decr_fun(rdual, rprimal):
+ phi *= beta
+ newx = x + phi * Dx_pd
+ newnu = nu + phi * Dnu_pd
+ newrdual = 1 + np.log(newx) - np.log(q) + P.T.dot(newnu)
+ newrprimal = P.dot(newx) - b
+
+ x = newx
+ nu = newnu
+ decr = decr_fun(newrdual, newrprimal)
+ niter += 1
+
+ return {"obj": sum(x*(np.log(x) - np.log(q))), "weight":x, "status": decr}
+
+if __name__=="__main__":
+ #write some small test