aboutsummaryrefslogtreecommitdiffstats
path: root/python/utils
diff options
context:
space:
mode:
Diffstat (limited to 'python/utils')
-rw-r--r--python/utils/db.py88
1 files changed, 56 insertions, 32 deletions
diff --git a/python/utils/db.py b/python/utils/db.py
index 8ede3756..fcc3ab0d 100644
--- a/python/utils/db.py
+++ b/python/utils/db.py
@@ -10,6 +10,7 @@ from sqlalchemy.engine.url import URL
import numpy as np
import atexit
+
class InfDateAdapter:
def __init__(self, wrapped):
self.wrapped = wrapped
@@ -23,51 +24,63 @@ class InfDateAdapter:
return psycopg2.extensions.DateFromPy(self.wrapped).getquoted()
-def nan_to_null(f, _NULL=AsIs('NULL'),
- _Float=psycopg2.extensions.Float):
+def nan_to_null(f, _NULL=AsIs("NULL"), _Float=psycopg2.extensions.Float):
if not np.isnan(f):
return _Float(f)
return _NULL
+
register_adapter(datetime.date, InfDateAdapter)
register_adapter(np.int64, lambda x: AsIs(x))
register_adapter(np.float, nan_to_null)
+
def dbconn(dbname, cursor_factory=NamedTupleCursor):
- if dbname == 'etdb':
- dbname = 'ET'
- user_name = 'et_user'
+ if dbname == "etdb":
+ dbname = "ET"
+ user_name = "et_user"
else:
- user_name = dbname[:-2] + '_user'
- return psycopg2.connect(database=dbname,
- user=user_name,
- host=os.environ.get("PGHOST", "debian"),
- cursor_factory=cursor_factory,
- options="-c extra_float_digits=3")
+ user_name = dbname[:-2] + "_user"
+ return psycopg2.connect(
+ database=dbname,
+ user=user_name,
+ host=os.environ.get("PGHOST", "debian"),
+ cursor_factory=cursor_factory,
+ options="-c extra_float_digits=3",
+ )
+
def dbengine(dbname, cursor_factory=NamedTupleCursor):
- if dbname in ['rmbs_model', 'corelogic']:
- uri = URL(drivername="mysql+mysqlconnector",
- host="debian", database=dbname,
- query={'option_files': os.path.expanduser('~/.my.cnf')})
+ if dbname in ["rmbs_model", "corelogic", "crt"]:
+ uri = URL(
+ drivername="mysql+mysqlconnector",
+ host="debian",
+ database=dbname,
+ query={"option_files": os.path.expanduser("~/.my.cnf")},
+ )
return create_engine(uri, paramstyle="format")
else:
- if dbname == 'etdb':
- dbname= 'ET'
- user_name = 'et_user'
+ if dbname == "etdb":
+ dbname = "ET"
+ user_name = "et_user"
else:
- user_name = dbname[:-2] + '_user'
- uri = URL(drivername="postgresql",
- host=os.environ.get("PGHOST", "debian"),
- username=user_name,
- database=dbname,
- query={"options": "-c extra_float_digits=3"})
- return create_engine(uri, paramstyle="format",
- connect_args={'cursor_factory': cursor_factory})
+ user_name = dbname[:-2] + "_user"
+ uri = URL(
+ drivername="postgresql",
+ host=os.environ.get("PGHOST", "debian"),
+ username=user_name,
+ database=dbname,
+ query={"options": "-c extra_float_digits=3"},
+ )
+ return create_engine(
+ uri, paramstyle="format", connect_args={"cursor_factory": cursor_factory}
+ )
+
def with_connection(dbname):
def decorator(f):
conn = dbconn(dbname)
+
def with_connection_(*args, **kwargs):
# or use a pool, or a factory function...
try:
@@ -77,9 +90,12 @@ def with_connection(dbname):
conn.rollback()
else:
return rv
+
return with_connection_
+
return decorator
+
def query_db(conn, sqlstr, params=None, one=True):
with conn.cursor() as c:
if params:
@@ -90,13 +106,21 @@ def query_db(conn, sqlstr, params=None, one=True):
r = c.fetchone() if one else c.fetchall()
return r
-serenitas_pool = ThreadedConnectionPool(0, 5, database='serenitasdb',
- user='serenitas_user',
- host=os.environ.get("PGHOST", "debian"),
- cursor_factory=DictCursor)
+
+serenitas_pool = ThreadedConnectionPool(
+ 0,
+ 5,
+ database="serenitasdb",
+ user="serenitas_user",
+ host=os.environ.get("PGHOST", "debian"),
+ cursor_factory=DictCursor,
+)
+
+
@atexit.register
def close_db():
serenitas_pool.closeall()
-serenitas_engine = dbengine('serenitasdb')
-dawn_engine = dbengine('dawndb')
+
+serenitas_engine = dbengine("serenitasdb")
+dawn_engine = dbengine("dawndb")