diff options
| -rw-r--r-- | pyisda/curve.pxd | 2 | ||||
| -rw-r--r-- | pyisda/curve.pyx | 20 |
2 files changed, 14 insertions, 8 deletions
diff --git a/pyisda/curve.pxd b/pyisda/curve.pxd index 54193a8..db247cb 100644 --- a/pyisda/curve.pxd +++ b/pyisda/curve.pxd @@ -238,7 +238,7 @@ cpdef enum Fmt: LZ4 = 1 Packed = 2 -cdef void forward_rates(const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil +cdef void forward_rates(double basis, const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil cdef class CurveShock: cdef vector[int] t diff --git a/pyisda/curve.pyx b/pyisda/curve.pyx index 813dbd7..e467d9e 100644 --- a/pyisda/curve.pyx +++ b/pyisda/curve.pyx @@ -61,15 +61,21 @@ cdef double survival_prob(const TCurve* curve, TDate start_date, TDate maturity_ else: return exp(u) -cdef void forward_rates(const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil: +cdef void forward_rates(double basis, const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil: cdef: size_t i t[0] = arr[0].fDate - base_date - f[0] = arr[0].fRate - for i in range(1, n): - t[i] = arr[i].fDate - base_date - f[i] = (arr[i].fRate * t[i] - arr[i-1].fRate * t[i-1]) / (t[i] - t[i-1]) + if basis == <double>Basis.CONTINUOUS: + f[0] = arr[0].fRate + for i in range(1, n): + t[i] = arr[i].fDate - base_date + f[i] = (arr[i].fRate * t[i] - arr[i-1].fRate * t[i-1]) / (t[i] - t[i-1]) + elif basis == <double>Basis.ANNUAL_BASIS: + f[0] = log1p(arr[0].fRate) + for i in range(1, n): + t[i] = arr[i].fDate - base_date + f[i] = (log1p(arr[i].fRate) * t[i] - log1p(arr[i-1].fRate) * t[i-1]) / (t[i] - t[i-1]) cdef class Curve(object): @@ -646,8 +652,8 @@ cdef class YieldCurve(Curve): f2 = <double*>malloc(n2 * sizeof(double)) if curve2.fBaseDate >= curve1.fBaseDate: raise ValueError() - forward_rates(curve1.fArray, curve1.fBaseDate, n1, f1, t1) - forward_rates(curve2.fArray, curve2.fBaseDate, n2, f2, t2) + forward_rates(curve1.fBasis, curve1.fArray, curve1.fBaseDate, n1, f1, t1) + forward_rates(curve2.fBasis, curve2.fArray, curve2.fBaseDate, n2, f2, t2) i = 0 j = 0 |
