diff options
Diffstat (limited to 'python/utils')
| -rw-r--r-- | python/utils/db.py | 88 |
1 files changed, 56 insertions, 32 deletions
diff --git a/python/utils/db.py b/python/utils/db.py index 8ede3756..fcc3ab0d 100644 --- a/python/utils/db.py +++ b/python/utils/db.py @@ -10,6 +10,7 @@ from sqlalchemy.engine.url import URL import numpy as np import atexit + class InfDateAdapter: def __init__(self, wrapped): self.wrapped = wrapped @@ -23,51 +24,63 @@ class InfDateAdapter: return psycopg2.extensions.DateFromPy(self.wrapped).getquoted() -def nan_to_null(f, _NULL=AsIs('NULL'), - _Float=psycopg2.extensions.Float): +def nan_to_null(f, _NULL=AsIs("NULL"), _Float=psycopg2.extensions.Float): if not np.isnan(f): return _Float(f) return _NULL + register_adapter(datetime.date, InfDateAdapter) register_adapter(np.int64, lambda x: AsIs(x)) register_adapter(np.float, nan_to_null) + def dbconn(dbname, cursor_factory=NamedTupleCursor): - if dbname == 'etdb': - dbname = 'ET' - user_name = 'et_user' + if dbname == "etdb": + dbname = "ET" + user_name = "et_user" else: - user_name = dbname[:-2] + '_user' - return psycopg2.connect(database=dbname, - user=user_name, - host=os.environ.get("PGHOST", "debian"), - cursor_factory=cursor_factory, - options="-c extra_float_digits=3") + user_name = dbname[:-2] + "_user" + return psycopg2.connect( + database=dbname, + user=user_name, + host=os.environ.get("PGHOST", "debian"), + cursor_factory=cursor_factory, + options="-c extra_float_digits=3", + ) + def dbengine(dbname, cursor_factory=NamedTupleCursor): - if dbname in ['rmbs_model', 'corelogic']: - uri = URL(drivername="mysql+mysqlconnector", - host="debian", database=dbname, - query={'option_files': os.path.expanduser('~/.my.cnf')}) + if dbname in ["rmbs_model", "corelogic", "crt"]: + uri = URL( + drivername="mysql+mysqlconnector", + host="debian", + database=dbname, + query={"option_files": os.path.expanduser("~/.my.cnf")}, + ) return create_engine(uri, paramstyle="format") else: - if dbname == 'etdb': - dbname= 'ET' - user_name = 'et_user' + if dbname == "etdb": + dbname = "ET" + user_name = "et_user" else: - user_name = dbname[:-2] + '_user' - uri = URL(drivername="postgresql", - host=os.environ.get("PGHOST", "debian"), - username=user_name, - database=dbname, - query={"options": "-c extra_float_digits=3"}) - return create_engine(uri, paramstyle="format", - connect_args={'cursor_factory': cursor_factory}) + user_name = dbname[:-2] + "_user" + uri = URL( + drivername="postgresql", + host=os.environ.get("PGHOST", "debian"), + username=user_name, + database=dbname, + query={"options": "-c extra_float_digits=3"}, + ) + return create_engine( + uri, paramstyle="format", connect_args={"cursor_factory": cursor_factory} + ) + def with_connection(dbname): def decorator(f): conn = dbconn(dbname) + def with_connection_(*args, **kwargs): # or use a pool, or a factory function... try: @@ -77,9 +90,12 @@ def with_connection(dbname): conn.rollback() else: return rv + return with_connection_ + return decorator + def query_db(conn, sqlstr, params=None, one=True): with conn.cursor() as c: if params: @@ -90,13 +106,21 @@ def query_db(conn, sqlstr, params=None, one=True): r = c.fetchone() if one else c.fetchall() return r -serenitas_pool = ThreadedConnectionPool(0, 5, database='serenitasdb', - user='serenitas_user', - host=os.environ.get("PGHOST", "debian"), - cursor_factory=DictCursor) + +serenitas_pool = ThreadedConnectionPool( + 0, + 5, + database="serenitasdb", + user="serenitas_user", + host=os.environ.get("PGHOST", "debian"), + cursor_factory=DictCursor, +) + + @atexit.register def close_db(): serenitas_pool.closeall() -serenitas_engine = dbengine('serenitasdb') -dawn_engine = dbengine('dawndb') + +serenitas_engine = dbengine("serenitasdb") +dawn_engine = dbengine("dawndb") |
