diff options
Diffstat (limited to 'python/utils/db.py')
| -rw-r--r-- | python/utils/db.py | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/python/utils/db.py b/python/utils/db.py new file mode 100644 index 00000000..8ede3756 --- /dev/null +++ b/python/utils/db.py @@ -0,0 +1,102 @@ +import datetime +import os +import psycopg2 +from psycopg2.extras import DictCursor, NamedTupleCursor +from psycopg2 import IntegrityError, DataError +from psycopg2.extensions import DateFromPy, register_adapter, AsIs +from psycopg2.pool import ThreadedConnectionPool +from sqlalchemy import create_engine +from sqlalchemy.engine.url import URL +import numpy as np +import atexit + +class InfDateAdapter: + def __init__(self, wrapped): + self.wrapped = wrapped + + def getquoted(self): + if self.wrapped == datetime.date.max: + return b"'infinity'::date" + elif self.wrapped == datetime.date.min: + return b"'-infinity'::date" + else: + return psycopg2.extensions.DateFromPy(self.wrapped).getquoted() + + +def nan_to_null(f, _NULL=AsIs('NULL'), + _Float=psycopg2.extensions.Float): + if not np.isnan(f): + return _Float(f) + return _NULL + +register_adapter(datetime.date, InfDateAdapter) +register_adapter(np.int64, lambda x: AsIs(x)) +register_adapter(np.float, nan_to_null) + +def dbconn(dbname, cursor_factory=NamedTupleCursor): + if dbname == 'etdb': + dbname = 'ET' + user_name = 'et_user' + else: + user_name = dbname[:-2] + '_user' + return psycopg2.connect(database=dbname, + user=user_name, + host=os.environ.get("PGHOST", "debian"), + cursor_factory=cursor_factory, + options="-c extra_float_digits=3") + +def dbengine(dbname, cursor_factory=NamedTupleCursor): + if dbname in ['rmbs_model', 'corelogic']: + uri = URL(drivername="mysql+mysqlconnector", + host="debian", database=dbname, + query={'option_files': os.path.expanduser('~/.my.cnf')}) + return create_engine(uri, paramstyle="format") + else: + if dbname == 'etdb': + dbname= 'ET' + user_name = 'et_user' + else: + user_name = dbname[:-2] + '_user' + uri = URL(drivername="postgresql", + host=os.environ.get("PGHOST", "debian"), + username=user_name, + database=dbname, + query={"options": "-c extra_float_digits=3"}) + return create_engine(uri, paramstyle="format", + connect_args={'cursor_factory': cursor_factory}) + +def with_connection(dbname): + def decorator(f): + conn = dbconn(dbname) + def with_connection_(*args, **kwargs): + # or use a pool, or a factory function... + try: + rv = f(conn, *args, **kwargs) + except Exception as e: + print(e) + conn.rollback() + else: + return rv + return with_connection_ + return decorator + +def query_db(conn, sqlstr, params=None, one=True): + with conn.cursor() as c: + if params: + c.execute(sqlstr, params) + else: + c.execute(sqlstr) + conn.commit() + r = c.fetchone() if one else c.fetchall() + return r + +serenitas_pool = ThreadedConnectionPool(0, 5, database='serenitasdb', + user='serenitas_user', + host=os.environ.get("PGHOST", "debian"), + cursor_factory=DictCursor) +@atexit.register +def close_db(): + serenitas_pool.closeall() + +serenitas_engine = dbengine('serenitasdb') +dawn_engine = dbengine('dawndb') |
