diff options
| author | Guillaume Horel <guillaume.horel@gmail.com> | 2023-04-27 16:39:35 -0400 |
|---|---|---|
| committer | Guillaume Horel <guillaume.horel@gmail.com> | 2023-04-27 16:39:35 -0400 |
| commit | fa12e9399e4a7c33f4eb184a8b551e557a019e92 (patch) | |
| tree | 2894abed70ff0bda556fd3be691e0ac18f2cbf71 | |
| parent | 3b79c5de75370c6fd30adbfe72c1e70d9ffe791c (diff) | |
| download | pyisda-fa12e9399e4a7c33f4eb184a8b551e557a019e92.tar.gz | |
fix packed format for SpreadCurve
| -rw-r--r-- | pyisda/curve.pyx | 41 |
1 files changed, 16 insertions, 25 deletions
diff --git a/pyisda/curve.pyx b/pyisda/curve.pyx index 104a98d..c3263b0 100644 --- a/pyisda/curve.pyx +++ b/pyisda/curve.pyx @@ -108,32 +108,18 @@ cdef class Curve(object): return r def __getstate__(self): - return self.as_bytes(True) + return self.as_bytes(0) def __setstate__(self, bytes state not None): cdef: const char* src = state - int decomp_size = 512 - char* curve = NULL - int state_size size_t size = PyBytes_GET_SIZE(state) - int retry = 0 - bytes r + char* curve = <char*>malloc(size) + with nogil: - while True: - curve = <char*>realloc(curve, decomp_size) - state_size = LZ4_decompress_safe(src, curve, size, decomp_size) - if state_size < 0: - retry += 1 - if retry == 2: - free(curve) - raise MemoryError("something went wrong") - else: - decomp_size *= 2 - else: - break - self.buf_size = state_size + memcpy(curve, src, size) + self.buf_size = size self.buf.reset(curve, char_free) @classmethod @@ -143,6 +129,7 @@ cdef class Curve(object): Py_buffer* py_buf char* src char* curve = NULL + char* cursor int size, state_size int decomp_size = 512 int retry = 0 @@ -177,15 +164,19 @@ cdef class Curve(object): state_size = size elif fmt == 2: n = (<TCurve*>src).fNumItems - state_size = sizeof(TCurve) + n * (sizeof(TDate) + sizeof(double)) + state_size = size + n * (sizeof(TDate) - sizeof(uint16_t)) curve = <char*>malloc(state_size) memcpy(curve, src, sizeof(TCurve)) base_date = (<TCurve*>src).fBaseDate - curve += sizeof(TCurve) + cursor = curve + sizeof(TCurve) src += sizeof(TCurve) for i in range(n): - (<TRatePt*>curve)[i] = TRatePt(base_date + (<short*>src)[i], (<double*>(src + n * sizeof(short)))[i]) - curve -= sizeof(TCurve) + (<TRatePt*>cursor)[i] = TRatePt(base_date + (<uint16_t*>src)[i], (<double*>(src + n * sizeof(uint16_t)))[i]) + decomp_size = state_size - (sizeof(TCurve) + n * sizeof(TRatePt)) + if decomp_size > 0: + cursor += n * sizeof(TRatePt) + src += n * (sizeof(double) + sizeof(uint16_t)) + memcpy(cursor, src, decomp_size) instance.buf_size = state_size instance.buf.reset(curve, char_free) return instance @@ -868,9 +859,9 @@ cdef class SpreadCurve(Curve): curve.fArray[i].fRate = JPMCDS_MAX_RATE @classmethod - def from_bytes(cls, object state, const bint compressed=False): + def from_bytes(cls, object state, const int fmt=0): cdef: - SpreadCurve instance = super().from_bytes(state, compressed) + SpreadCurve instance = super().from_bytes(state, fmt) char* cursor = instance.buf.get() TCurve* curve = <TCurve*>cursor int n = curve.fNumItems |
