aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/utils_blocks.py
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-12-02 15:43:34 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-12-02 15:43:34 -0500
commitde815c196ad03e5d76cc675696b9cd7c1b3b3fbb (patch)
tree3fb1cc6932e02a80c8955e31d900a0c7de268d70 /simulation/utils_blocks.py
parentccb192c4190701531094b46df85725158d4e9ffc (diff)
downloadcascades-de815c196ad03e5d76cc675696b9cd7c1b3b3fbb.tar.gz
changing active learning definition plus main loops of mle/vi_blocks
Diffstat (limited to 'simulation/utils_blocks.py')
-rw-r--r--simulation/utils_blocks.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py
index 3b29972..72a6881 100644
--- a/simulation/utils_blocks.py
+++ b/simulation/utils_blocks.py
@@ -31,16 +31,14 @@ class ActiveLearning(be.SimpleExtension):
Extension which updates the node_p array passed to the get_data method of
LearnedDataset
"""
- def __init__(self, dataset, **kwargs):
+ def __init__(self, dataset, params, **kwargs):
super(ActiveLearning, self).__init__(**kwargs)
self.dataset = dataset
+ self.params = params
def do(self, which_callback, *args):
- out_degree = np.sum(self.dataset.graph, axis=1)
- self.dataset.node_p = out_degree / np.sum(out_degree)
-
-# def do(self, which_callback, *args):
-
+ exp_out_par = np.exp(np.sum(self.params, axis=1))
+ self.dataset.node_p = exp_out_par / np.sum(exp_out_par)
class JSONDump(be.SimpleExtension):