diff options
| author | Thibaut Horel <thibaut.horel@gmail.com> | 2015-09-14 19:10:35 -0400 |
|---|---|---|
| committer | Thibaut Horel <thibaut.horel@gmail.com> | 2015-09-14 19:10:35 -0400 |
| commit | 0cd7df0103230be0b73f888a6b5553440ef40d03 (patch) | |
| tree | 081e1e4a4682656a52eb739549e71a24b0c2073c | |
| parent | ee158d02d92c597a35bc2e8704a85989a547f0e4 (diff) | |
| download | criminal_cascades-0cd7df0103230be0b73f888a6b5553440ef40d03.tar.gz | |
Small tweak which might help with numerical precision
| -rw-r--r-- | hawkes/data.py | 2 | ||||
| -rw-r--r-- | hawkes/main.py | 49 | ||||
| -rw-r--r-- | hawkes/refine.py | 2 |
3 files changed, 30 insertions, 23 deletions
diff --git a/hawkes/data.py b/hawkes/data.py index e5c33f8..5f68dd1 100644 --- a/hawkes/data.py +++ b/hawkes/data.py @@ -14,7 +14,7 @@ def parse(s): def fluctuation_int(t): if t is None: t = MAX_TIME - return t + 0.43 / 0.0172 * (cos(4.36) - cos(0.0172 * t + 4.36)) + return (t, t + 0.43 / 0.0172 * (cos(4.36) - cos(0.0172 * t + 4.36))) def load_nodes(filename): diff --git a/hawkes/main.py b/hawkes/main.py index bc319cc..8f30dcf 100644 --- a/hawkes/main.py +++ b/hawkes/main.py @@ -34,19 +34,28 @@ def iter_events(events): yield (n, t) +def approx(x): + if x > 1e-10: + return 1 - exp(-x) + else: + return x + + def ll(lamb, alpha, mu): r1 = sum(log(lamb * (1 + 0.43 * sin(0.0172 * t1 + 4.36)) + sum(alpha / d * mu * exp(-mu * (t1 - t2)) for (n2, t2, d) in s)) for ((n1, t1), s) in event_edges.iteritems()) - r2 = sum(sum(alpha / d * (1 - exp(-mu * (nodes[n2] - t1))) - for n2, d in edges[n1].iteritems() if nodes[n2] > t1) + r2 = sum(sum(alpha / d * approx(mu * (nodes[n2][0] - t1)) + for n2, d in edges[n1].iteritems() + if nodes[n2][0] > t1) for (n1, t1) in iter_events(events)) - r3 = lamb * sum(nodes.itervalues()) + r3 = lamb * sum(node[1] for node in nodes.itervalues()) + print r1, r2, r3 return -(r1 - r2 - r3) -def sa(x, y, z, sigma=0.5, niter=70, fc=None): +def sa(x, y, z, sigma=0.5, niter=1000, fc=None): T = 0.1 e = 1.1 if fc: @@ -57,8 +66,8 @@ def sa(x, y, z, sigma=0.5, niter=70, fc=None): sys.stderr.write("sa: " + " ".join(map(str, [T, sigma, x, y, z, fo])) + "\n") sys.stderr.flush() - yn = max(y + gauss(0, sigma * y + 1e-10), 0) - zn = max(z + gauss(0, sigma * z + 1e-10), 0) + yn = max(y + gauss(0, sigma * y), 0) + zn = max(z + gauss(0, sigma * z), 0) fn = ll(x, yn, zn) if fn < fo or exp((fo - fn) / T) > random(): y = yn @@ -73,7 +82,7 @@ def sa(x, y, z, sigma=0.5, niter=70, fc=None): return y, z, fo -def optimize_with_sa(x, y, z, niter=200): +def optimize_with_sa(x, y, z, niter=10): def f(x): return ll(x, y, z) @@ -87,7 +96,7 @@ def optimize_with_sa(x, y, z, niter=200): sys.stdout.flush() -def optimize_with_gss(x, y, z, niter=100): +def optimize_with_gss(x, y, z, niter=5): def f(x): return ll(x, y, z) @@ -99,8 +108,8 @@ def optimize_with_gss(x, y, z, niter=100): return ll(x, y, z) for _ in xrange(niter): - y, fc = gss(g, 0, 100, tol=1e-10) - z, fc = gss(h, 0, 100, tol=1e-10) + y, fc = gss(g, 0, 1e50, tol=1e-10) + z, fc = gss(h, 0, 1e50, tol=1e-10) x, fc = gss(f, 0, 1e-3, tol=1e-10) print x, y, z, fc sys.stdout.flush() @@ -127,18 +136,16 @@ def coarse_search(): if __name__ == "__main__": - nodes, edges, events, event_edges = load(open("data2.pickle", "rb")) - - x = 1.2e-5 - y = 0.002 - z = 0.004 - # sa(x, y, z) + nodes, edges, events, event_edges = load(open("data-dist1.pickle", "rb")) + x = 1.25e-5 + y = 1.2e16 + z = 1.5e-20 + sa(x, y, z) - with open(sys.argv[1]) as fh: - l = [map(float, line.strip().split()[:3]) for line in fh] - for e in l: - optimize_with_sa(*e) + # with open(sys.argv[1]) as fh: + # l = [map(float, line.strip().split()[:3]) for line in fh] + # for e in l: + # optimize_with_gss(*e) - # optimize_with_gss(x, y, z) # print ll(x, y, z) # coarse_search() diff --git a/hawkes/refine.py b/hawkes/refine.py index 3dfcfcc..0c171bc 100644 --- a/hawkes/refine.py +++ b/hawkes/refine.py @@ -58,5 +58,5 @@ def refine(): if __name__ == "__main__": - nodes, edges, events, event_edges = load(open("data2.pickle", "rb")) + nodes, edges, events, event_edges = load(open("data-dist1.pickle", "rb")) refine() |
