import atexit
import datetime
import functools
import os

import numpy as np
import psycopg2
from psycopg2 import DataError, IntegrityError
from psycopg2.extensions import AsIs, DateFromPy, register_adapter
from psycopg2.extras import DictCursor, NamedTupleCursor
from psycopg2.pool import ThreadedConnectionPool
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL


class InfDateAdapter:
    """psycopg2 adapter for ``datetime.date`` with infinite-date support.

    ``datetime.date.max`` is sent as ``'infinity'::date`` and
    ``datetime.date.min`` as ``'-infinity'::date``; any other date is quoted
    by psycopg2's stock ``DateFromPy`` adapter.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def getquoted(self):
        """Return the SQL literal (bytes) for the wrapped date."""
        if self.wrapped == datetime.date.max:
            return b"'infinity'::date"
        if self.wrapped == datetime.date.min:
            return b"'-infinity'::date"
        return DateFromPy(self.wrapped).getquoted()


def nan_to_null(f, _NULL=AsIs("NULL"), _Float=psycopg2.extensions.Float):
    """Adapt a float for psycopg2, sending NaN as SQL ``NULL``.

    The default arguments bind the NULL literal and psycopg2's stock float
    adapter once, at definition time.
    """
    if np.isnan(f):
        return _NULL
    return _Float(f)


register_adapter(datetime.date, InfDateAdapter)
register_adapter(np.int64, AsIs)
# ``np.float`` was only an alias of the builtin ``float``; NumPy deprecated it
# in 1.20 and removed it in 1.24 (AttributeError at import). Registering
# ``float`` itself is the identical registration.
register_adapter(float, nan_to_null)


def _resolve_db(dbname):
    """Map a logical database name to its ``(database, user_name)`` pair.

    ``"etdb"`` is the ``ET`` database accessed as ``et_user``; any other name
    is used as-is, with the user being the name minus its last two characters
    plus ``"_user"`` (e.g. ``"serenitasdb"`` -> ``"serenitas_user"``).
    """
    if dbname == "etdb":
        return "ET", "et_user"
    return dbname, dbname[:-2] + "_user"


def dbconn(dbname, cursor_factory=NamedTupleCursor):
    """Open a new psycopg2 connection to the PostgreSQL database *dbname*.

    Args:
        dbname: logical database name (see ``_resolve_db`` for the mapping to
            the actual database and user).
        cursor_factory: cursor class used by the connection's cursors.

    The host is taken from ``$PGHOST`` (default ``"debian"``) and the session
    is started with ``extra_float_digits=3``.
    """
    database, user_name = _resolve_db(dbname)
    return psycopg2.connect(
        database=database,
        user=user_name,
        host=os.environ.get("PGHOST", "debian"),
        cursor_factory=cursor_factory,
        options="-c extra_float_digits=3",
    )


def _make_url(**parts):
    """Build a SQLAlchemy ``URL`` in a version-tolerant way.

    SQLAlchemy 1.4 deprecated calling ``URL(...)`` directly in favour of
    ``URL.create(...)``, and 2.0 no longer accepts the partial keyword form;
    releases before 1.4 only provide the constructor.
    """
    factory = getattr(URL, "create", URL)
    return factory(**parts)


def dbengine(dbname, cursor_factory=NamedTupleCursor):
    """Create a SQLAlchemy engine for *dbname*.

    ``"rmbs_model"``, ``"corelogic"`` and ``"crt"`` are MySQL databases on host
    ``"debian"`` reached through ``mysql+mysqlconnector``, with connection
    options read from ``~/.my.cnf``; *cursor_factory* is ignored for them.
    Any other name is treated as PostgreSQL, using the same database/user
    mapping, ``$PGHOST`` host and ``extra_float_digits=3`` option as
    ``dbconn``, and passing *cursor_factory* through to psycopg2.

    Both kinds of engine use the ``format`` paramstyle.
    """
    if dbname in ("rmbs_model", "corelogic", "crt"):
        uri = _make_url(
            drivername="mysql+mysqlconnector",
            host="debian",
            database=dbname,
            query={"option_files": os.path.expanduser("~/.my.cnf")},
        )
        return create_engine(uri, paramstyle="format")

    database, user_name = _resolve_db(dbname)
    uri = _make_url(
        drivername="postgresql",
        host=os.environ.get("PGHOST", "debian"),
        username=user_name,
        database=database,
        query={"options": "-c extra_float_digits=3"},
    )
    return create_engine(
        uri, paramstyle="format", connect_args={"cursor_factory": cursor_factory}
    )


def with_connection(dbname):
    """Decorator factory that passes a psycopg2 connection as first argument.

    One connection to *dbname* is opened when the decorator is applied and is
    reused by every call of the decorated function. If a call raises, the
    exception is printed, the connection is rolled back and ``None`` is
    returned instead of propagating the error.
    """

    def decorator(f):
        # NOTE(review): opened at decoration time and never closed; a pool or
        # per-call connection may be preferable.
        conn = dbconn(dbname)

        @functools.wraps(f)
        def with_connection_(*args, **kwargs):
            try:
                return f(conn, *args, **kwargs)
            except Exception as e:
                # Best-effort: report, reset the aborted transaction, return None.
                print(e)
                conn.rollback()
                return None

        return with_connection_

    return decorator


def query_db(conn, sqlstr, params=None, one=True):
    """Execute *sqlstr* on *conn*, commit, and return the result rows.

    Args:
        conn: an open psycopg2 connection.
        sqlstr: SQL statement, optionally containing psycopg2 placeholders.
        params: placeholder values; when falsy the statement is executed
            without parameter substitution.
        one: if true return ``fetchone()`` (``None`` when there are no rows),
            otherwise ``fetchall()``.

    The commit happens before fetching; an unnamed psycopg2 cursor has already
    buffered the result client-side, so the rows are still readable. A
    statement that produces no result set makes the fetch raise
    ``psycopg2.ProgrammingError``.
    """
    with conn.cursor() as c:
        # ``execute(sql, None)`` is the same as ``execute(sql)``: no %-substitution.
        c.execute(sqlstr, params or None)
        conn.commit()
        return c.fetchone() if one else c.fetchall()


# Shared pool for serenitasdb: up to 5 connections, rows returned as DictRow.
serenitas_pool = ThreadedConnectionPool(
    0,
    5,
    database="serenitasdb",
    user="serenitas_user",
    host=os.environ.get("PGHOST", "debian"),
    cursor_factory=DictCursor,
)


@atexit.register
def close_db():
    """Close every connection held by ``serenitas_pool`` at interpreter exit."""
    serenitas_pool.closeall()


serenitas_engine = dbengine("serenitasdb")
dawn_engine = dbengine("dawndb")