import numpy as np
import numpy.linalg as la


def decr_fun(a, b):
    """Return the Euclidean norm of the two residual vectors stacked together."""
    return np.sqrt(np.square(a).sum() + np.square(b).sum())


def _kl_residuals(P, b, log_q, x, nu):
    """Return the (dual, primal) KKT residuals of the KL problem at (x, nu)."""
    rdual = 1 + np.log(x) - log_q + P.T.dot(nu)
    rprimal = P.dot(x) - b
    return rdual, rprimal


def KLfit(P, q, b):
    """Solve ``min_x KL(x, q)  s.t.  P x = b,  x >= 0,  sum(x) = 1``.

    The problem is solved with an infeasible-start Newton method on the KKT
    system, using a backtracking line search on the norm of the stacked
    (dual, primal) residual.  The normalisation constraint ``sum(x) = 1`` is
    appended internally as an extra row of ones in ``P`` and an extra ``1``
    in ``b``.

    Parameters
    ----------
    P : array_like
        Constraint matrix with one column per entry of ``x``.  A 1-D input is
        treated as a single constraint row.
    q : array_like
        Reference weights, one per column of ``P``.  Its logarithm is taken,
        so the entries must be strictly positive.
    b : array_like or scalar
        Right-hand side of the equality constraints, one value per row of
        ``P``.

    Returns
    -------
    dict
        ``"obj"``    -- objective value ``sum(x * (log(x) - log(q)))``,
        ``"weight"`` -- the final iterate ``x``,
        ``"status"`` -- norm of the KKT residual at the final iterate (the
        iteration stops once it is <= 1e-12 or after 499 Newton steps).

    Notes
    -----
    ``numpy.linalg.solve`` raises ``LinAlgError`` if the reduced Newton
    system is singular (e.g. linearly dependent constraint rows).
    """
    alpha = 0.4  # required relative decrease per unit step in the line search
    beta = 0.8   # step-size shrink factor while backtracking

    P = np.asarray(P, dtype=float)
    if P.ndim == 1:
        P = P[None]
    # Fold the normalisation constraint sum(x) = 1 into the linear system.
    P = np.vstack([P, np.ones(P.shape[1])])
    b = np.concatenate([np.atleast_1d(np.asarray(b, dtype=float)).ravel(), [1.0]])

    # log(q) never changes, so compute it once instead of on every residual.
    log_q = np.log(np.asarray(q, dtype=float))

    # init x (uniform weights) and nu (multipliers of the equality constraints)
    x = np.ones(P.shape[1]) / P.shape[1]
    nu = np.ones(P.shape[0])

    decr = np.inf  # np.Inf was removed in NumPy 2.0
    eps = 1e-12
    niter = 1
    while decr > eps and niter < 500:
        rdual, rprimal = _kl_residuals(P, b, log_q, x, nu)
        current = decr_fun(rdual, rprimal)

        # Newton step with x eliminated: the Hessian is diag(1/x), so the
        # reduced system for the multiplier step is -(P diag(x) P^T) dnu = rhs.
        S = -P.dot((x * P).T)
        Dnu_pd = la.solve(S, (P * x).dot(rdual) - rprimal)
        Dx_pd = -x * (P.T.dot(Dnu_pd) + rdual)

        # backtracking search, stage 1: shrink until x stays strictly
        # positive so that log(x) is finite.
        phi = 1
        newx = x + phi * Dx_pd
        while np.any(newx <= 0):
            phi *= beta
            newx = x + phi * Dx_pd
        newnu = nu + phi * Dnu_pd
        newrdual, newrprimal = _kl_residuals(P, b, log_q, newx, newnu)

        # stage 2: shrink until the residual norm decreases sufficiently.
        while decr_fun(newrdual, newrprimal) > (1 - alpha * phi) * current:
            phi *= beta
            newx = x + phi * Dx_pd
            newnu = nu + phi * Dnu_pd
            newrdual, newrprimal = _kl_residuals(P, b, log_q, newx, newnu)

        x = newx
        nu = newnu
        decr = decr_fun(newrdual, newrprimal)
        niter += 1

    return {"obj": np.sum(x * (np.log(x) - log_q)),
            "weight": x,
            "status": decr}


if __name__ == "__main__":
    # write some small test
    P = np.array([[5, 4, 3, 2, 1], [3, 3, 4, 3, 3]])
    w = np.ones(5) / 5
    b = np.array([2.5, 3.3])
    test = KLfit(P, w, b)