diff options
Diffstat (limited to 'python/optimization.py')
| -rw-r--r-- | python/optimization.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/python/optimization.py b/python/optimization.py index 65ba02c4..6eda84fa 100644 --- a/python/optimization.py +++ b/python/optimization.py @@ -1,9 +1,10 @@ import numpy as np +import numpy.linalg as la def decr_fun(a, b): return np.sqrt(np.square(a).sum() + np.square(b).sum()) -def KLfit(P, q, b = np.ones(P.shape[0])): +def KLfit(P, q, b): """ solves for x, min KL(x, q) s.t. Px = b @@ -12,7 +13,10 @@ def KLfit(P, q, b = np.ones(P.shape[0])): """ alpha = 0.4 beta = 0.8 - P = np.vstack(P, np.ones(P.shape[1])) + if(len(P.shape)==1): + P = P[None] + P = np.vstack([P, np.ones(P.shape[1])]) + b = np.hstack([b, 1]) #init x and nu x = np.ones(P.shape[1])/P.shape[1] nu = np.ones(P.shape[0]) @@ -24,13 +28,13 @@ def KLfit(P, q, b = np.ones(P.shape[0])): rprimal = P.dot(x) - b S = -P.dot((x*P).T) Dnu_pd = la.solve(S, (P*x).dot(rdual) - rprimal) - Dx_pd = -x[:,np.newaxis]*(P.T.dot(Dnu_pd)+rdual) + Dx_pd = -x*(P.T.dot(Dnu_pd)+rdual) #backtracking search phi = 1 newx = x + phi * Dx_pd newnu = nu + phi * Dnu_pd - while newx[newx>0].sum() > 0: + while sum(newx<0) > 0: phi *= beta newx = x + phi * Dx_pd newnu = nu + phi * Dnu_pd @@ -54,3 +58,6 @@ def KLfit(P, q, b = np.ones(P.shape[0])): if __name__=="__main__": #write some small test + x = np.array([5, 4, 3, 2, 1]) + w = np.ones(5)/5 + test = KLfit(x, w, 2.5) |
