summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pyisda/credit_index.pyx31
1 files changed, 13 insertions, 18 deletions
diff --git a/pyisda/credit_index.pyx b/pyisda/credit_index.pyx
index 207a8e0..b2c8913 100644
--- a/pyisda/credit_index.pyx
+++ b/pyisda/credit_index.pyx
@@ -50,17 +50,15 @@ cdef TContingentLeg* copyContingentLeg(TContingentLeg* leg) nogil:
cdef class CurveList:
@cython.initializedcheck(False)
- def __init__(self, list curves not None, double[:] weights=None, value_date=None):
+ def __init__(self, list curves not None, value_date=None):
cdef:
SpreadCurve sc
size_t i
map[string, size_t].iterator it
size_t n = len(curves)
+ double w
- if isinstance(curves[0], SpreadCurve):
- sc = <SpreadCurve>curves[0]
- else:
- raise TypeError("curves need to be a list of SpreadCurve")
+ w, sc = curves[0]
if value_date is not None:
self.base_date = pydate_to_TDate(value_date)
@@ -69,30 +67,28 @@ cdef class CurveList:
i = 0
cdef int n_skipped = 0
- for sc in curves:
+ cdef double total_weight = 0.
+ for w, sc in curves:
if sc is not None:
it = self.tickers.find(sc.ticker)
if it == self.tickers.end():
self.tickers[sc.ticker] = i
self._curves.push_back(sc._thisptr)
self.recovery_rates.push_back(sc.recovery_rates)
- self._weights.push_back(1.)
+ self._weights.push_back(w)
self.defaulted.push_back(sc.defaulted)
i += 1
else:
- self._weights[deref(it).second] += 1
+ self._weights[deref(it).second] += w
+ total_weight += w
else:
n_skipped += 1
- if weights is not None:
- for i in range(weights.shape[0]):
- self._weights[i] = weights[i]
- else:
- for i in range(self._curves.size()):
- self._weights[i] /= (n - n_skipped)
-
if n_skipped > 0:
warnings.warn(f"skipped {n_skipped} empty curves")
+ # we rescale the weights
+ for i in range(self._weights.size()):
+ self._weights[i] *= total_weight
def __getitem__(self, str ticker):
cdef:
@@ -208,9 +204,8 @@ cdef class CurveList:
@cython.auto_pickle(False)
cdef class CreditIndex(CurveList):
- def __init__(self, start_date, maturities, list curves, double[:] weights=None,
- value_date=None):
- CurveList.__init__(self, curves, weights, value_date)
+ def __init__(self, start_date, maturities, list curves, value_date=None):
+ CurveList.__init__(self, curves, value_date)
self.start_date = pydate_to_TDate(start_date)
for d in maturities:
self._maturities.push_back(pydate_to_TDate(d))