diff options
Diffstat (limited to 'python/analytics/option.py')
| -rw-r--r-- | python/analytics/option.py | 34 |
1 files changed, 14 insertions, 20 deletions
diff --git a/python/analytics/option.py b/python/analytics/option.py index 0e92f6c7..ac7e8230 100644 --- a/python/analytics/option.py +++ b/python/analytics/option.py @@ -737,7 +737,7 @@ def _compute_vol(option, strike, mid): def _calibrate_model( - index, quotes, option_type, option_model, interp_method="bivariate_spline" + option_class, index, quotes, option_type, interp_method="bivariate_spline" ): """ interp_method : one of 'bivariate_spline', 'bivariate_linear' @@ -748,7 +748,7 @@ def _calibrate_model( quotes = quotes.sort_values("strike", ascending=False) with Pool(4) as p: for expiry, df in quotes.groupby(["expiry"]): - option = option_model(index, expiry.date(), 100, option_type) + option = option_class(index, expiry.date(), 100, option_type) T.append(option.T) r.append( np.stack( @@ -796,7 +796,7 @@ class ModelBasedVolSurface(VolSurface): tenor="5yr", value_date=datetime.date.today(), interp_method="bivariate_spline", - **kwargs + **kwargs, ): super().__init__(index_type, series, tenor, value_date) self._index = CreditIndex(index_type, series, tenor, value_date, notional=1.0) @@ -806,17 +806,11 @@ class ModelBasedVolSurface(VolSurface): pay_mid=self._quotes[["pay_bid", "pay_offer"]].mean(1) * 1e-4, rec_mid=self._quotes[["rec_bid", "rec_offer"]].mean(1) * 1e-4, ) - if type(self) is BlackSwaptionVolSurface: - self._opts = {"option_model": BlackSwaption, "interp_method": interp_method} - elif type(self) is SwaptionVolSurface: - self._opts = {"option_model": Swaption} - elif type(self) is SABRVolSurface: - self._opts = {"beta": 3.19 if index_type == "HY" else 1.84} - else: - raise TypeError( - "class needs to be SwaptionVolSurface, " - "BlackSwaptionVolSurface or SABRVolSurface" - ) + self._calibrator = partial(self._calibrator, interp_method=interp_method) + + def __init_subclass__(cls, /, option_model, **kwargs): + cls._calibrator = partial(_calibrate_model, option_model) + super().__init_subclass__(**kwargs) def list(self, source=None, option_type=None): """returns list of vol surfaces""" @@ -838,8 +832,8 @@ class ModelBasedVolSurface(VolSurface): ) self._index.ref = quotes.ref.iat[0] self._index_refs[surface_id] = quotes.ref.iat[0] - self._surfaces[surface_id] = _calibrate( - self._index, quotes, option_type, **self._opts + self._surfaces[surface_id] = self._calibrator( + self._index, quotes, option_type ) return self._surfaces[surface_id] else: @@ -866,16 +860,16 @@ class ModelBasedVolSurface(VolSurface): ax.set_zlabel("Volatility") -class BlackSwaptionVolSurface(ModelBasedVolSurface): +class BlackSwaptionVolSurface(ModelBasedVolSurface, option_model=BlackSwaption): pass -class SwaptionVolSurface(ModelBasedVolSurface): +class SwaptionVolSurface(ModelBasedVolSurface, option_model=Swaption): pass -class SABRVolSurface(ModelBasedVolSurface): - pass +# class SABRVolSurface(ModelBasedVolSurface, opts={"beta": 3.19 if index_type == "HY" else 1.84}): +# pass @lru_cache(maxsize=32) |
