diff options
Diffstat (limited to 'simulation/utils.py')
| -rw-r--r-- | simulation/utils.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/simulation/utils.py b/simulation/utils.py index aad7771..37cabf0 100644 --- a/simulation/utils.py +++ b/simulation/utils.py @@ -49,6 +49,15 @@ def random_source(graph, node_p=None): return x0 +def constant_source(graph, source): + if type(source) == int: + x = np.zeros(graph.shape[0], dtype=bool) + x[source] = True + return x + else: + return source + + def simulate_cascades(n_obs, graph, source=random_source): n_nodes = graph.shape[0] x_obs = np.zeros((n_obs, n_nodes), dtype=bool) |
