aboutsummaryrefslogtreecommitdiffstats
path: root/python/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/optimization.py')
-rw-r--r--python/optimization.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/python/optimization.py b/python/optimization.py
index 65ba02c4..6eda84fa 100644
--- a/python/optimization.py
+++ b/python/optimization.py
@@ -1,9 +1,10 @@
import numpy as np
+import numpy.linalg as la
def decr_fun(a, b):
return np.sqrt(np.square(a).sum() + np.square(b).sum())
-def KLfit(P, q, b = np.ones(P.shape[0])):
+def KLfit(P, q, b):
"""
solves for x, min KL(x, q)
s.t. Px = b
@@ -12,7 +13,10 @@ def KLfit(P, q, b = np.ones(P.shape[0])):
"""
alpha = 0.4
beta = 0.8
- P = np.vstack(P, np.ones(P.shape[1]))
+ if(len(P.shape)==1):
+ P = P[None]
+ P = np.vstack([P, np.ones(P.shape[1])])
+ b = np.hstack([b, 1])
#init x and nu
x = np.ones(P.shape[1])/P.shape[1]
nu = np.ones(P.shape[0])
@@ -24,13 +28,13 @@ def KLfit(P, q, b = np.ones(P.shape[0])):
rprimal = P.dot(x) - b
S = -P.dot((x*P).T)
Dnu_pd = la.solve(S, (P*x).dot(rdual) - rprimal)
- Dx_pd = -x[:,np.newaxis]*(P.T.dot(Dnu_pd)+rdual)
+ Dx_pd = -x*(P.T.dot(Dnu_pd)+rdual)
#backtracking search
phi = 1
newx = x + phi * Dx_pd
newnu = nu + phi * Dnu_pd
- while newx[newx>0].sum() > 0:
+ while sum(newx<0) > 0:
phi *= beta
newx = x + phi * Dx_pd
newnu = nu + phi * Dnu_pd
@@ -54,3 +58,6 @@ def KLfit(P, q, b = np.ones(P.shape[0])):
if __name__=="__main__":
#write some small test
+ x = np.array([5, 4, 3, 2, 1])
+ w = np.ones(5)/5
+ test = KLfit(x, w, 2.5)