aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/utils.py
diff options
context:
space:
mode:
authorThibaut Horel <thibaut.horel@gmail.com>2015-11-30 20:01:39 -0500
committerThibaut Horel <thibaut.horel@gmail.com>2015-11-30 20:01:39 -0500
commitf457b3480d53920a0d8b0e3b8cdb2b18601088ee (patch)
treee68d25279027014bc1c9bf7229464d7ae1f3c6c4 /simulation/utils.py
parentf1762904c648b2089031ba6ce46ccaaac4f3514c (diff)
downloadcascades-f457b3480d53920a0d8b0e3b8cdb2b18601088ee.tar.gz
Add constant source
Diffstat (limited to 'simulation/utils.py')
-rw-r--r--simulation/utils.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/simulation/utils.py b/simulation/utils.py
index aad7771..37cabf0 100644
--- a/simulation/utils.py
+++ b/simulation/utils.py
@@ -49,6 +49,15 @@ def random_source(graph, node_p=None):
return x0
+def constant_source(graph, source):
+ if type(source) == int:
+ x = np.zeros(graph.shape[0], dtype=bool)
+ x[source] = True
+ return x
+ else:
+ return source
+
+
def simulate_cascades(n_obs, graph, source=random_source):
n_nodes = graph.shape[0]
x_obs = np.zeros((n_obs, n_nodes), dtype=bool)