diff options
| author | Guillaume Horel <guillaume.horel@gmail.com> | 2019-02-14 16:31:49 -0500 |
|---|---|---|
| committer | Guillaume Horel <guillaume.horel@gmail.com> | 2019-02-14 16:31:49 -0500 |
| commit | f7cd1a424daa2cef75fc028931599e51fe707536 (patch) | |
| tree | 3b923ae22df9b25090b5ae4f2afbdd1cde99b250 | |
| parent | 5643fadff431acefd9f34f4997da5ec9ec05dae3 (diff) | |
| download | pyisda-f7cd1a424daa2cef75fc028931599e51fe707536.tar.gz | |
get rid of the separate weights parameter
| -rw-r--r-- | pyisda/credit_index.pyx | 31 |
1 files changed, 13 insertions, 18 deletions
diff --git a/pyisda/credit_index.pyx b/pyisda/credit_index.pyx index 207a8e0..b2c8913 100644 --- a/pyisda/credit_index.pyx +++ b/pyisda/credit_index.pyx @@ -50,17 +50,15 @@ cdef TContingentLeg* copyContingentLeg(TContingentLeg* leg) nogil: cdef class CurveList: @cython.initializedcheck(False) - def __init__(self, list curves not None, double[:] weights=None, value_date=None): + def __init__(self, list curves not None, value_date=None): cdef: SpreadCurve sc size_t i map[string, size_t].iterator it size_t n = len(curves) + double w - if isinstance(curves[0], SpreadCurve): - sc = <SpreadCurve>curves[0] - else: - raise TypeError("curves need to be a list of SpreadCurve") + w, sc = curves[0] if value_date is not None: self.base_date = pydate_to_TDate(value_date) @@ -69,30 +67,28 @@ cdef class CurveList: i = 0 cdef int n_skipped = 0 - for sc in curves: + cdef double total_weight = 0. + for w, sc in curves: if sc is not None: it = self.tickers.find(sc.ticker) if it == self.tickers.end(): self.tickers[sc.ticker] = i self._curves.push_back(sc._thisptr) self.recovery_rates.push_back(sc.recovery_rates) - self._weights.push_back(1.) + self._weights.push_back(w) self.defaulted.push_back(sc.defaulted) i += 1 else: - self._weights[deref(it).second] += 1 + self._weights[deref(it).second] += w + total_weight += w else: n_skipped += 1 - if weights is not None: - for i in range(weights.shape[0]): - self._weights[i] = weights[i] - else: - for i in range(self._curves.size()): - self._weights[i] /= (n - n_skipped) - if n_skipped > 0: warnings.warn(f"skipped {n_skipped} empty curves") + # we rescale the weights + for i in range(self._weights.size()): + self._weights[i] *= total_weight def __getitem__(self, str ticker): cdef: @@ -208,9 +204,8 @@ cdef class CurveList: @cython.auto_pickle(False) cdef class CreditIndex(CurveList): - def __init__(self, start_date, maturities, list curves, double[:] weights=None, - value_date=None): - CurveList.__init__(self, curves, weights, value_date) + def __init__(self, start_date, maturities, list curves, value_date=None): + CurveList.__init__(self, curves, value_date) self.start_date = pydate_to_TDate(start_date) for d in maturities: self._maturities.push_back(pydate_to_TDate(d)) |
