summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pyisda/curve.pyx26
1 files changed, 17 insertions, 9 deletions
diff --git a/pyisda/curve.pyx b/pyisda/curve.pyx
index 24014c3..bc79413 100644
--- a/pyisda/curve.pyx
+++ b/pyisda/curve.pyx
@@ -16,6 +16,7 @@ import numpy as np
np.import_array()
import pandas as pd
from cpython cimport Py_buffer
+from numpy cimport npy_intp
cdef extern from "Python.h":
int PyMemoryView_Check(object)
@@ -23,6 +24,9 @@ cdef extern from "Python.h":
cdef extern from "numpy/arrayobject.h":
void PyArray_ENABLEFLAGS(np.ndarray arr, int flags)
+ int PyArray_CheckExact(object)
+ void* PyArray_DATA(object)
+ npy_intp PyArray_Size(object)
cdef int SUCCESS = 0
@@ -456,25 +460,29 @@ cdef class SpreadCurve(Curve):
cdef TDate step_in_date_c = pydate_to_TDate(step_in_date)
cdef TDate cash_settle_date_c = pydate_to_TDate(cash_settle_date)
cdef TDate start_date_c = pydate_to_TDate(start_date)
- cdef int n_dates = len(end_dates)
+ cdef int n_dates
cdef TDate* end_dates_c = NULL
- cdef TDate[:] end_dates_view
cdef TCurve* curve = NULL
cdef unsigned int includes = 0
cdef size_t i
+ cdef int freeup = 0
if cash_settle_date_c < yc._thisptr.get().fBaseDate:
raise ValueError("cash_settle_date: {0} is anterior to yc's base_date: {1}".
format(cash_settle_date, yc.base_date))
if isinstance(end_dates, list):
+ n_dates = len(end_dates)
end_dates_c = <TDate*>malloc(n_dates * sizeof(TDate))
- end_dates_view = <TDate[:n_dates]>end_dates_c
- for i, d in enumerate(end_dates):
- end_dates_view[i] = pydate_to_TDate(d)
+ freeup = 1
+ i = 0
+ for d in end_dates:
+ end_dates_c[i] = pydate_to_TDate(d)
if upfront_rates[i] == upfront_rates[i]:
includes |= 1 << i
- else:
- end_dates_view = memoryview(end_dates)
+ i += 1
+ elif PyArray_CheckExact(end_dates):
+ end_dates_c = <TDate*>PyArray_DATA(end_dates)
+ n_dates = PyArray_Size(end_dates)
for i in range(n_dates):
if upfront_rates[i] == upfront_rates[i]:
includes |= 1 << i
@@ -489,7 +497,7 @@ cdef class SpreadCurve(Curve):
step_in_date_c,
cash_settle_date_c,
n_dates,
- &end_dates_view[0],
+ end_dates_c,
&coupon_rates[0],
&upfront_rates[0],
includes,
@@ -500,7 +508,7 @@ cdef class SpreadCurve(Curve):
&stub_type,
<long>'M',
b'NONE')
- if end_dates_c:
+ if freeup:
free(end_dates_c)
if curve == NULL:
raise ValueError("Didn't init the survival curve properly")