diff options
Diffstat (limited to 'python/experiments/test_matrix.py')
| -rw-r--r-- | python/experiments/test_matrix.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/python/experiments/test_matrix.py b/python/experiments/test_matrix.py new file mode 100644 index 00000000..ae7e9d0b --- /dev/null +++ b/python/experiments/test_matrix.py @@ -0,0 +1,60 @@ +import numpy as np +import scipy.linalg as splinalg +import dask.array as da +import timeit +import os +import ctypes +from ctypes.util import find_library +openblas_lib = ctypes.cdll.LoadLibrary(find_library('openblas')) + +def get_num_threads(): + return openblas_lib.openblas_get_num_threads() + +def set_num_threads(n): + openblas_lib.openblas_set_num_threads(int(n)) + +seed = 1234 +np.random.seed(seed) + +N = 1000000 +p = 100 +X = np.random.random(N * p).reshape((N, p), order='F') +XT = X.T.copy() +true_value=33334547.40257686 +#X = da.from_array(X, chunks=(N/4, p)) +old_num_threads = get_num_threads() +def test(): + if not np.isclose(np.trace(X.T.dot(X)), true_value): + raise ValueError() + +def test2(): + if not np.isclose(np.trace(splinalg.blas.dsyrk(1., X, trans=1)), true_value): + raise ValueError() + +def test3(): + if not np.isclose(np.trace(XT.dot(X)), true_value): + raise ValueError() + +t = timeit.timeit(test, number=5) +print("Multi threaded computation with {} threads: {}".format(old_num_threads, t)) + +t = timeit.timeit(test2, number=5) +print("Multi threaded computation dsyrk with {} threads: {}".format(old_num_threads, t)) + +t = timeit.timeit(test3, number=5) +print("Multi threaded computation dgemv with {} threads: {}".format(old_num_threads, t)) + +set_num_threads(1) +t = timeit.timeit(test, number=5) +print("Non multi-threaded computation:{}".format(t)) +#set_num_threads(old_num_threads) + +print("using dask array") +for Nchunk in [1, 2, 4]: + X = da.from_array(X, chunks=(N / Nchunk, p)) + def test_dask(): + if not np.isclose(np.trace(X.T.dot(X).compute()), true_value, Nchunk): + raise ValueError() + t = timeit.timeit(test_dask, number=5) + print("Dask computation {} chunk: {}".format(Nchunk, t)) + |
