diff options
| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-28 15:14:52 -0500 |
|---|---|---|
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-28 15:14:52 -0500 |
| commit | 041aa021657a3c290952b222e3141449638bad19 (patch) | |
| tree | 087dbbbfc1f08d207e6a443fc91e900b3a588d19 /simulation/vi_theano.py | |
| parent | e0152021025e07788e5ef928af6c9160c98c5452 (diff) | |
| download | cascades-041aa021657a3c290952b222e3141449638bad19.tar.gz | |
switch to python 3 + blocks version of MLE implemented
Diffstat (limited to 'simulation/vi_theano.py')
| -rw-r--r-- | simulation/vi_theano.py | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/simulation/vi_theano.py b/simulation/vi_theano.py index 9e8fdb0..cfd1dcf 100644 --- a/simulation/vi_theano.py +++ b/simulation/vi_theano.py @@ -5,11 +5,11 @@ import theano.tensor.shared_randomstreams import numpy as np n_cascades = 1000 -n_nodes = 4 +n_nodes = 3 n_samples = 100 srng = tsr.shared_randomstreams.RandomStreams(seed=123) -lr = 5*1e-3 -n_epochs = 20 +lr = 1e-1 +n_epochs = 1 ###############Variational Inference#################### @@ -70,28 +70,36 @@ if __name__ == "__main__": graph[k, k] = 0 p = 0.5 graph = np.log(1. / (1 - p * graph)) + + graph = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]]) + p = 0.5 + graph = np.log(1. / (1 - p * graph)) + cascades = mn.build_cascade_list(mn.simulate_cascades(n_cascades, graph), collapse=True) x_obs, s_obs = cascades[0], cascades[1] #mle lkl_plot = [] - if 0: + if 1: for i in range(n_epochs): - for xt, st in zip(x_obs, s_obs): + for k in xrange(len(x_obs)/100): + xt = x_obs[k*100:(k+1)*100] + st = s_obs[k*100:(k+1)*100] lkl = train_mle(xt, st) lkl_plot.append(lkl) + print(params.get_value()) print(graph) w = params.get_value() for k in range(len(w)): w[k, k] = 0 print(w) - import matplotlib.pyplot as plt - plt.plot(lkl_plot) - plt.show() + #import matplotlib.pyplot as plt + #plt.plot(lkl_plot) + #plt.show() #variational inference - if 1: + if 0: for i in range(n_epochs): train_kl() for k in xrange(len(x_obs)/100): |
