diff options
Diffstat (limited to 'python/trade_dataclasses.py')
| -rw-r--r-- | python/trade_dataclasses.py | 68 |
1 files changed, 36 insertions, 32 deletions
diff --git a/python/trade_dataclasses.py b/python/trade_dataclasses.py index 7292493e..099286c5 100644 --- a/python/trade_dataclasses.py +++ b/python/trade_dataclasses.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field, fields, Field from typing import ClassVar from decimal import Decimal from typing import Literal @@ -194,6 +194,14 @@ class DealType(Enum): Termination = "TERM" +def is_default_init_field(cls, attr): + if hasattr(cls, attr): + default = getattr(cls, attr) + return isinstance(default, Field) and default.init + else: + return False + + class Deal: _conn: ClassVar = dbconn("dawndb", application_name="autobooker") _registry = {} @@ -206,19 +214,19 @@ class Deal: def __class_getitem__(cls, deal_type: DealType): return cls._registry[deal_type] - def __init_subclass__( - cls, deal_type: DealType, table_name: str, insert_ignore=(), select_ignore=() - ): + def __init_subclass__(cls, deal_type: DealType, table_name: str, insert_ignore=()): super().__init_subclass__() cls._registry[deal_type] = cls cls._table_name = table_name insert_columns = [c for c in cls.__annotations__ if c not in insert_ignore] - insert_place_holders = ",".join(["%s"] * len(insert_columns)) - select_columns = [c for c in cls.__annotations__ if c not in select_ignore] + place_holders = ",".join(["%s"] * len(insert_columns)) + select_columns = [ + c for c in cls.__annotations__ if is_default_init_field(cls, c) + ] - cls._sql_insert = f"INSERT INTO {cls._table_name}({','.join(insert_columns)}) VALUES({insert_place_holders})" + cls._sql_insert = f"INSERT INTO {table_name}({','.join(insert_columns)}) VALUES({place_holders})" cls._sql_select = ( - f"SELECT {','.join(select_columns)} FROM {cls._table_name} WHERE id=%s" + f"SELECT {','.join(select_columns)} FROM {table_name} WHERE id=%s" ) def stage(self): @@ -229,13 +237,6 @@ class Deal: if f.metadata.get("insert", True) ] ) - self._select_queue.append( - [ - getattr(self, f.name) - for f in fields(self) - if f.metadata.get("insert", True) - ] - ) @classmethod def commit(cls): @@ -518,7 +519,6 @@ class TerminationDeal( deal_type=DealType.Termination, table_name="terminations", insert_ignore=("id", "dealid", "orig_cp", "currency", "fund", "product_type"), - select_ignore=("orig_cp", "currency", "fund", "product_type"), ): termination_fee: float = field(metadata={"mtm": "Initial Payment"}) fee_payment_date: datetime.date = field(metadata={"mtm": "Settle Date"}) @@ -530,30 +530,36 @@ class TerminationDeal( id: int = field(default=None, metadata={"insert": False}) dealid: str = field(default=None, metadata={"insert": False, "mtm": "Swap ID"}) orig_cp: str = field( - default=None, - metadata={"mtm": "Remaining Party", "insert": False, "select": False}, + init=False, + metadata={"mtm": "Remaining Party", "insert": False}, ) currency: str = field( - default=None, - metadata={"mtm": "Currency Code", "insert": False, "select": False}, + init=False, + metadata={"mtm": "Currency Code", "insert": False}, ) fund: str = field( - default=None, - metadata={"mtm": "Account Abbreviation", "insert": False, "select": False}, + init=False, + metadata={"mtm": "Account Abbreviation", "insert": False}, ) product_type: str = field( - default=None, metadata={"mtm": "Product Type", "insert": False, "select": False} + init=False, metadata={"mtm": "Product Type", "insert": False} ) def __post_init__(self): + if self.dealid.startswith("SWPTN"): + self.product_type = "CDISW" + table_name = "swaptions" + elif self.dealid.startwith("SCCDS"): + self.product_type = "TRN" + table_name = "cds" + sql_str = ( + "SELECT cp_code, currency, fund FROM terminations " + f"LEFT JOIN {table_name} USING (dealid) " + "WHERE terminations.id = %s" + ) with self._conn.cursor() as c: - termination_information = ( - """SELECT coalesce(cds.cp_code, swaptions.cp_code) AS orig_cp, COALESCE (cds.currency, swaptions.currency) AS currency, """ - """COALESCE (cds.swap_type, 'SWAPTION') as product_type, COALESCE (cds.fund, swaptions.fund) as fund FROM terminations LEFT JOIN cds USING (dealid) """ - """LEFT JOIN swaptions USING (dealid) WHERE terminations.id=%s;""" - ) - c.execute(termination_information, (self.id,)) - self.orig_cp, self.currency, self.product_type, self.fund = c.fetchone() + c.execute(sql_str, (self.id,)) + self.orig_cp, self.currency, self.fund = c.fetchone() def to_markit(self): obj = self.serialize("mtm") @@ -562,8 +568,6 @@ class TerminationDeal( else: obj["Initial Payment"] = abs(obj["Initial Payment"]) obj["Transaction Code"] = "Pay" - swap_type = {"CD_INDEX_TRANCHE": "TRN", "SWAPTION": "CDISW"} - obj["Product Type"] = swap_type[obj["Product Type"]] obj["Trade ID"] = obj["Swap ID"] + "-" + str(obj["id"]) obj["Transaction Type"] = ( "Termination" |
