aboutsummaryrefslogtreecommitdiffstats
path: root/python/experiments/test_matrix.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/experiments/test_matrix.py')
-rw-r--r--python/experiments/test_matrix.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/python/experiments/test_matrix.py b/python/experiments/test_matrix.py
new file mode 100644
index 00000000..ae7e9d0b
--- /dev/null
+++ b/python/experiments/test_matrix.py
@@ -0,0 +1,60 @@
+import numpy as np
+import scipy.linalg as splinalg
+import dask.array as da
+import timeit
+import os
+import ctypes
+from ctypes.util import find_library
+openblas_lib = ctypes.cdll.LoadLibrary(find_library('openblas'))
+
+def get_num_threads():
+ return openblas_lib.openblas_get_num_threads()
+
+def set_num_threads(n):
+ openblas_lib.openblas_set_num_threads(int(n))
+
+seed = 1234
+np.random.seed(seed)
+
+N = 1000000
+p = 100
+X = np.random.random(N * p).reshape((N, p), order='F')
+XT = X.T.copy()
+true_value=33334547.40257686
+#X = da.from_array(X, chunks=(N/4, p))
+old_num_threads = get_num_threads()
+def test():
+ if not np.isclose(np.trace(X.T.dot(X)), true_value):
+ raise ValueError()
+
+def test2():
+ if not np.isclose(np.trace(splinalg.blas.dsyrk(1., X, trans=1)), true_value):
+ raise ValueError()
+
+def test3():
+ if not np.isclose(np.trace(XT.dot(X)), true_value):
+ raise ValueError()
+
+t = timeit.timeit(test, number=5)
+print("Multi threaded computation with {} threads: {}".format(old_num_threads, t))
+
+t = timeit.timeit(test2, number=5)
+print("Multi threaded computation dsyrk with {} threads: {}".format(old_num_threads, t))
+
+t = timeit.timeit(test3, number=5)
+print("Multi threaded computation dgemv with {} threads: {}".format(old_num_threads, t))
+
+set_num_threads(1)
+t = timeit.timeit(test, number=5)
+print("Non multi-threaded computation:{}".format(t))
+#set_num_threads(old_num_threads)
+
+print("using dask array")
+for Nchunk in [1, 2, 4]:
+ X = da.from_array(X, chunks=(N / Nchunk, p))
+ def test_dask():
+ if not np.isclose(np.trace(X.T.dot(X).compute()), true_value, Nchunk):
+ raise ValueError()
+ t = timeit.timeit(test_dask, number=5)
+ print("Dask computation {} chunk: {}".format(Nchunk, t))
+