aboutsummaryrefslogtreecommitdiffstats
path: root/python/utils/db.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/utils/db.py')
-rw-r--r--python/utils/db.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/python/utils/db.py b/python/utils/db.py
new file mode 100644
index 00000000..8ede3756
--- /dev/null
+++ b/python/utils/db.py
@@ -0,0 +1,102 @@
+import datetime
+import os
+import psycopg2
+from psycopg2.extras import DictCursor, NamedTupleCursor
+from psycopg2 import IntegrityError, DataError
+from psycopg2.extensions import DateFromPy, register_adapter, AsIs
+from psycopg2.pool import ThreadedConnectionPool
+from sqlalchemy import create_engine
+from sqlalchemy.engine.url import URL
+import numpy as np
+import atexit
+
+class InfDateAdapter:
+ def __init__(self, wrapped):
+ self.wrapped = wrapped
+
+ def getquoted(self):
+ if self.wrapped == datetime.date.max:
+ return b"'infinity'::date"
+ elif self.wrapped == datetime.date.min:
+ return b"'-infinity'::date"
+ else:
+ return psycopg2.extensions.DateFromPy(self.wrapped).getquoted()
+
+
+def nan_to_null(f, _NULL=AsIs('NULL'),
+ _Float=psycopg2.extensions.Float):
+ if not np.isnan(f):
+ return _Float(f)
+ return _NULL
+
+register_adapter(datetime.date, InfDateAdapter)
+register_adapter(np.int64, lambda x: AsIs(x))
+register_adapter(np.float, nan_to_null)
+
+def dbconn(dbname, cursor_factory=NamedTupleCursor):
+ if dbname == 'etdb':
+ dbname = 'ET'
+ user_name = 'et_user'
+ else:
+ user_name = dbname[:-2] + '_user'
+ return psycopg2.connect(database=dbname,
+ user=user_name,
+ host=os.environ.get("PGHOST", "debian"),
+ cursor_factory=cursor_factory,
+ options="-c extra_float_digits=3")
+
+def dbengine(dbname, cursor_factory=NamedTupleCursor):
+ if dbname in ['rmbs_model', 'corelogic']:
+ uri = URL(drivername="mysql+mysqlconnector",
+ host="debian", database=dbname,
+ query={'option_files': os.path.expanduser('~/.my.cnf')})
+ return create_engine(uri, paramstyle="format")
+ else:
+ if dbname == 'etdb':
+ dbname= 'ET'
+ user_name = 'et_user'
+ else:
+ user_name = dbname[:-2] + '_user'
+ uri = URL(drivername="postgresql",
+ host=os.environ.get("PGHOST", "debian"),
+ username=user_name,
+ database=dbname,
+ query={"options": "-c extra_float_digits=3"})
+ return create_engine(uri, paramstyle="format",
+ connect_args={'cursor_factory': cursor_factory})
+
+def with_connection(dbname):
+ def decorator(f):
+ conn = dbconn(dbname)
+ def with_connection_(*args, **kwargs):
+ # or use a pool, or a factory function...
+ try:
+ rv = f(conn, *args, **kwargs)
+ except Exception as e:
+ print(e)
+ conn.rollback()
+ else:
+ return rv
+ return with_connection_
+ return decorator
+
+def query_db(conn, sqlstr, params=None, one=True):
+ with conn.cursor() as c:
+ if params:
+ c.execute(sqlstr, params)
+ else:
+ c.execute(sqlstr)
+ conn.commit()
+ r = c.fetchone() if one else c.fetchall()
+ return r
+
+serenitas_pool = ThreadedConnectionPool(0, 5, database='serenitasdb',
+ user='serenitas_user',
+ host=os.environ.get("PGHOST", "debian"),
+ cursor_factory=DictCursor)
+@atexit.register
+def close_db():
+ serenitas_pool.closeall()
+
+serenitas_engine = dbengine('serenitasdb')
+dawn_engine = dbengine('dawndb')