aboutsummaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/trade_dataclasses.py55
1 files changed, 50 insertions, 5 deletions
diff --git a/python/trade_dataclasses.py b/python/trade_dataclasses.py
index e764ec85..5f4a2eb3 100644
--- a/python/trade_dataclasses.py
+++ b/python/trade_dataclasses.py
@@ -1,14 +1,18 @@
from dataclasses import dataclass, field, fields, Field
+from enum import Enum
+from io import StringIO
+from headers import DealType, MTM_HEADERS
from typing import ClassVar
from decimal import Decimal
from typing import Literal
import csv
import datetime
-from enum import Enum
from psycopg.types.numeric import Int2BinaryDumper
from psycopg import adapters
from serenitas.analytics.dates import next_business_day, previous_twentieth
from serenitas.utils.db2 import dbconn
+from serenitas.utils.env import DAILY_DIR
+from serenitas.utils.remote import SftpClient
from lru import LRU
from psycopg.errors import UniqueViolation
import logging
@@ -225,13 +229,13 @@ class Deal:
cls._table_name = table_name
insert_columns = [c for c in cls.__annotations__ if c not in insert_ignore]
place_holders = ",".join(["%s"] * len(insert_columns))
- select_columns = [
- c for c in cls.__annotations__ if is_default_init_field(cls, c)
- ]
+ cls._sql_fields = set(
+ [c for c in cls.__annotations__ if is_default_init_field(cls, c)]
+ )
cls._sql_insert = f"INSERT INTO {table_name}({','.join(insert_columns)}) VALUES({place_holders})"
cls._sql_select = (
- f"SELECT {','.join(select_columns)} FROM {table_name} WHERE id=%s"
+ f"SELECT {','.join(cls._sql_fields)} FROM {table_name} WHERE id=%s"
)
def stage(self):
@@ -303,9 +307,49 @@ class BbgDeal:
cls.commit()
+class MTMDeal:
+ _mtm_queue: ClassVar[list] = []
+ _mtm_headers = None
+ _sftp = SftpClient.from_creds("mtm")
+ product_type: str
+
+ def __init_subclass__(cls, deal_type, **kwargs):
+ super().__init_subclass__(deal_type, **kwargs)
+ cls._mtm_headers = MTM_HEADERS[deal_type]
+ if deal_type == DealType.Swaption:
+ cls.product_type = "CDISW"
+ elif deal_type == DealType.CDS:
+ cls.product_type = "TRN"
+
+ @classmethod
+ def mtm_upload(cls):
+ if not cls._mtm_queue: # early exit
+ return
+ buf = StringIO()
+ csvwriter = csv.writer(buf)
+ csvwriter.writerow(cls._mtm_headers)
+ csvwriter.writerows(
+ [row.get(h, None) for h in cls._mtm_headers for row in cls._mtm_queue]
+ )
+ buf = buf.getvalue().encode()
+ fname = f"MTM.{datetime.datetime.now():%Y%m%d.%H%M%S}.{cls.product_type.capitalize()}.csv"
+ cls._sftp.put(buf, fname)
+ dest = DAILY_DIR / str(datetime.date.today()) / fname
+ dest.write_bytes(buf)
+ cls._mtm_queue.clear()
+
+ def mtm_stage(self):
+ self._mtm_queue.append(self.to_markit())
+
+ @classmethod
+ def from_dict(cls, **kwargs):
+ return cls(**{k: v for k, v in kwargs.items() if k in cls._sql_fields})
+
+
@dataclass
class CDSDeal(
BbgDeal,
+ MTMDeal,
Deal,
deal_type=DealType.CDS,
table_name="cds",
@@ -456,6 +500,7 @@ class BondDeal(BbgDeal, Deal, deal_type=DealType.Bond, table_name="bonds"):
@dataclass
class SwaptionDeal(
+ MTMDeal,
Deal,
deal_type=DealType.Swaption,
table_name="swaptions",