diff options
| -rw-r--r-- | python/db.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/python/db.py b/python/db.py index 0ba5f7b3..8a7a21f2 100644 --- a/python/db.py +++ b/python/db.py @@ -2,7 +2,7 @@ import datetime import os import psycopg2 from psycopg2.extras import DictCursor -from psycopg2 import IntegrityError +from psycopg2 import IntegrityError, DataError from psycopg2.extensions import register_adapter, AsIs from sqlalchemy import create_engine from sqlalchemy.engine.url import URL @@ -19,8 +19,15 @@ class InfDateAdapter: else: return psycopg2.extensions.DateFromPy(self.wrapped).getquoted() +def nan_to_null(f, _NULL=AsIs('NULL'), + _Float=psycopg2.extensions.Float): + if not np.isnan(f): + return _Float(f) + return _NULL + register_adapter(datetime.date, InfDateAdapter) register_adapter(np.int64, lambda x: AsIs(x)) +register_adapter(np.float, nan_to_null) def dbconn(dbname): if dbname == 'etdb': @@ -33,13 +40,6 @@ def dbconn(dbname): host="debian", cursor_factory=DictCursor) - -def nan_to_null(f, _NULL=psycopg2.extensions.AsIs('NULL'), - _Float=psycopg2.extensions.Float): - if not np.isnan(f): - return _Float(f) - return _NULL - def dbengine(dbname): if dbname == 'rmbs_model': uri = URL(drivername="mysql+mysqlconnector", @@ -48,7 +48,7 @@ def dbengine(dbname): else: uri = "postgresql://{0}@debian/{1}".format(dbname[:-2] + '_user', dbname) - return create_engine(uri) + return create_engine(uri, paramstyle="format") def with_connection(dbname): def decorator(f): |
