summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pyisda/curve.pxd2
-rw-r--r--pyisda/curve.pyx20
2 files changed, 14 insertions, 8 deletions
diff --git a/pyisda/curve.pxd b/pyisda/curve.pxd
index 54193a8..db247cb 100644
--- a/pyisda/curve.pxd
+++ b/pyisda/curve.pxd
@@ -238,7 +238,7 @@ cpdef enum Fmt:
LZ4 = 1
Packed = 2
-cdef void forward_rates(const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil
+cdef void forward_rates(double basis, const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil
cdef class CurveShock:
cdef vector[int] t
diff --git a/pyisda/curve.pyx b/pyisda/curve.pyx
index 813dbd7..e467d9e 100644
--- a/pyisda/curve.pyx
+++ b/pyisda/curve.pyx
@@ -61,15 +61,21 @@ cdef double survival_prob(const TCurve* curve, TDate start_date, TDate maturity_
else:
return exp(u)
-cdef void forward_rates(const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil:
+cdef void forward_rates(double basis, const TRatePt* arr, const TDate base_date, const int n, double* f, int *t) noexcept nogil:
cdef:
size_t i
t[0] = arr[0].fDate - base_date
- f[0] = arr[0].fRate
- for i in range(1, n):
- t[i] = arr[i].fDate - base_date
- f[i] = (arr[i].fRate * t[i] - arr[i-1].fRate * t[i-1]) / (t[i] - t[i-1])
+ if basis == <double>Basis.CONTINUOUS:
+ f[0] = arr[0].fRate
+ for i in range(1, n):
+ t[i] = arr[i].fDate - base_date
+ f[i] = (arr[i].fRate * t[i] - arr[i-1].fRate * t[i-1]) / (t[i] - t[i-1])
+ elif basis == <double>Basis.ANNUAL_BASIS:
+ f[0] = log1p(arr[0].fRate)
+ for i in range(1, n):
+ t[i] = arr[i].fDate - base_date
+ f[i] = (log1p(arr[i].fRate) * t[i] - log1p(arr[i-1].fRate) * t[i-1]) / (t[i] - t[i-1])
cdef class Curve(object):
@@ -646,8 +652,8 @@ cdef class YieldCurve(Curve):
f2 = <double*>malloc(n2 * sizeof(double))
if curve2.fBaseDate >= curve1.fBaseDate:
raise ValueError()
- forward_rates(curve1.fArray, curve1.fBaseDate, n1, f1, t1)
- forward_rates(curve2.fArray, curve2.fBaseDate, n2, f2, t2)
+ forward_rates(curve1.fBasis, curve1.fArray, curve1.fBaseDate, n1, f1, t1)
+ forward_rates(curve2.fBasis, curve2.fArray, curve2.fBaseDate, n2, f2, t2)
i = 0
j = 0