aboutsummaryrefslogtreecommitdiffstats
path: root/python/trade_dataclasses.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/trade_dataclasses.py')
-rw-r--r--python/trade_dataclasses.py68
1 files changed, 36 insertions, 32 deletions
diff --git a/python/trade_dataclasses.py b/python/trade_dataclasses.py
index 7292493e..099286c5 100644
--- a/python/trade_dataclasses.py
+++ b/python/trade_dataclasses.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field, fields, Field
from typing import ClassVar
from decimal import Decimal
from typing import Literal
@@ -194,6 +194,14 @@ class DealType(Enum):
Termination = "TERM"
+def is_default_init_field(cls, attr):
+ if hasattr(cls, attr):
+ default = getattr(cls, attr)
+ return isinstance(default, Field) and default.init
+ else:
+ return False
+
+
class Deal:
_conn: ClassVar = dbconn("dawndb", application_name="autobooker")
_registry = {}
@@ -206,19 +214,19 @@ class Deal:
def __class_getitem__(cls, deal_type: DealType):
return cls._registry[deal_type]
- def __init_subclass__(
- cls, deal_type: DealType, table_name: str, insert_ignore=(), select_ignore=()
- ):
+ def __init_subclass__(cls, deal_type: DealType, table_name: str, insert_ignore=()):
super().__init_subclass__()
cls._registry[deal_type] = cls
cls._table_name = table_name
insert_columns = [c for c in cls.__annotations__ if c not in insert_ignore]
- insert_place_holders = ",".join(["%s"] * len(insert_columns))
- select_columns = [c for c in cls.__annotations__ if c not in select_ignore]
+ place_holders = ",".join(["%s"] * len(insert_columns))
+ select_columns = [
+ c for c in cls.__annotations__ if is_default_init_field(cls, c)
+ ]
- cls._sql_insert = f"INSERT INTO {cls._table_name}({','.join(insert_columns)}) VALUES({insert_place_holders})"
+ cls._sql_insert = f"INSERT INTO {table_name}({','.join(insert_columns)}) VALUES({place_holders})"
cls._sql_select = (
- f"SELECT {','.join(select_columns)} FROM {cls._table_name} WHERE id=%s"
+ f"SELECT {','.join(select_columns)} FROM {table_name} WHERE id=%s"
)
def stage(self):
@@ -229,13 +237,6 @@ class Deal:
if f.metadata.get("insert", True)
]
)
- self._select_queue.append(
- [
- getattr(self, f.name)
- for f in fields(self)
- if f.metadata.get("insert", True)
- ]
- )
@classmethod
def commit(cls):
@@ -518,7 +519,6 @@ class TerminationDeal(
deal_type=DealType.Termination,
table_name="terminations",
insert_ignore=("id", "dealid", "orig_cp", "currency", "fund", "product_type"),
- select_ignore=("orig_cp", "currency", "fund", "product_type"),
):
termination_fee: float = field(metadata={"mtm": "Initial Payment"})
fee_payment_date: datetime.date = field(metadata={"mtm": "Settle Date"})
@@ -530,30 +530,36 @@ class TerminationDeal(
id: int = field(default=None, metadata={"insert": False})
dealid: str = field(default=None, metadata={"insert": False, "mtm": "Swap ID"})
orig_cp: str = field(
- default=None,
- metadata={"mtm": "Remaining Party", "insert": False, "select": False},
+ init=False,
+ metadata={"mtm": "Remaining Party", "insert": False},
)
currency: str = field(
- default=None,
- metadata={"mtm": "Currency Code", "insert": False, "select": False},
+ init=False,
+ metadata={"mtm": "Currency Code", "insert": False},
)
fund: str = field(
- default=None,
- metadata={"mtm": "Account Abbreviation", "insert": False, "select": False},
+ init=False,
+ metadata={"mtm": "Account Abbreviation", "insert": False},
)
product_type: str = field(
- default=None, metadata={"mtm": "Product Type", "insert": False, "select": False}
+ init=False, metadata={"mtm": "Product Type", "insert": False}
)
def __post_init__(self):
+ if self.dealid.startswith("SWPTN"):
+ self.product_type = "CDISW"
+ table_name = "swaptions"
+ elif self.dealid.startwith("SCCDS"):
+ self.product_type = "TRN"
+ table_name = "cds"
+ sql_str = (
+ "SELECT cp_code, currency, fund FROM terminations "
+ f"LEFT JOIN {table_name} USING (dealid) "
+ "WHERE terminations.id = %s"
+ )
with self._conn.cursor() as c:
- termination_information = (
- """SELECT coalesce(cds.cp_code, swaptions.cp_code) AS orig_cp, COALESCE (cds.currency, swaptions.currency) AS currency, """
- """COALESCE (cds.swap_type, 'SWAPTION') as product_type, COALESCE (cds.fund, swaptions.fund) as fund FROM terminations LEFT JOIN cds USING (dealid) """
- """LEFT JOIN swaptions USING (dealid) WHERE terminations.id=%s;"""
- )
- c.execute(termination_information, (self.id,))
- self.orig_cp, self.currency, self.product_type, self.fund = c.fetchone()
+ c.execute(sql_str, (self.id,))
+ self.orig_cp, self.currency, self.fund = c.fetchone()
def to_markit(self):
obj = self.serialize("mtm")
@@ -562,8 +568,6 @@ class TerminationDeal(
else:
obj["Initial Payment"] = abs(obj["Initial Payment"])
obj["Transaction Code"] = "Pay"
- swap_type = {"CD_INDEX_TRANCHE": "TRN", "SWAPTION": "CDISW"}
- obj["Product Type"] = swap_type[obj["Product Type"]]
obj["Trade ID"] = obj["Swap ID"] + "-" + str(obj["id"])
obj["Transaction Type"] = (
"Termination"