summarylogtreecommitdiffstats
path: root/sundials6.patch
blob: 74fa1aa843cde123d18d8127158ee62c33fc29e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index a394b94..ec7cbe4 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        sundials_version: [2.7.0, 3.2.0]
+        sundials_version: [2.7.0, 3.2.0, 4.1.0, 5.8.0, 6.7.0, 7.0.0]
     steps:
     - uses: actions/checkout@v3
     - name: Install system
diff --git a/src/lib/sundials_callbacks_ida_cvode.pxi b/src/lib/sundials_callbacks_ida_cvode.pxi
index 366c3e2..38b98d4 100644
--- a/src/lib/sundials_callbacks_ida_cvode.pxi
+++ b/src/lib/sundials_callbacks_ida_cvode.pxi
@@ -742,32 +742,64 @@ cdef int ida_jacv(realtype t, N_Vector yy, N_Vector yp, N_Vector rr, N_Vector vv
 
 # Error handling callback functions
 # =================================
+IF SUNDIALS_VERSION >= (7,0,0):
+    cdef extern from "sundials/sundials_context.h":
+        ctypedef _SUNContext * SUNContext
+        cdef struct _SUNContext:
+            pass
+        ctypedef int SUNErrCode
+    cdef void cv_err(int line, const char* func, const char* file, const char* msg, SUNErrCode error_code, void* problem_data, SUNContext sunctx):
+        """
+        This method overrides the default handling of error messages.
+        """
+        cdef ProblemData pData = <ProblemData>problem_data
+        
+        if error_code > 0 and pData.verbose > 0: #Warning
+            print '[CVode Warning]', msg
+        
+        if pData.verbose > 2: #Verbosity is greater than NORMAL, print warnings and errors
+            if error_code < 0: #Error
+                print '[CVode Error]', msg
+ELSE:
+    cdef void cv_err(int error_code, const char *module, const char *function, char *msg, void *problem_data):
+        """
+        This method overrides the default handling of error messages.
+        """
+        cdef ProblemData pData = <ProblemData>problem_data
+        
+        if error_code > 0 and pData.verbose > 0: #Warning
+            print '[CVode Warning]', msg
+        
+        if pData.verbose > 2: #Verbosity is greater than NORMAL, print warnings and errors
+            if error_code < 0: #Error
+                print '[CVode Error]', msg
 
-cdef void cv_err(int error_code, const char *module, const char *function, char *msg, void *problem_data):
-    """
-    This method overrides the default handling of error messages.
-    """
-    cdef ProblemData pData = <ProblemData>problem_data
-    
-    if error_code > 0 and pData.verbose > 0: #Warning
-        print '[CVode Warning]', msg
-    
-    if pData.verbose > 2: #Verbosity is greater than NORMAL, print warnings and errors
-        if error_code < 0: #Error
-            print '[CVode Error]', msg
-            
-cdef void ida_err(int error_code, const char *module, const char *function, char *msg, void *problem_data):
-    """
-    This method overrides the default handling of error messages.
-    """
-    cdef ProblemData pData = <ProblemData>problem_data
-    
-    if error_code > 0 and pData.verbose > 0: #Warning
-        print '[IDA Warning]', msg
-    
-    if pData.verbose > 2: #Verbosity is greater than NORMAL, print warnings and errors
-        if error_code < 0: #Error
-            print '[IDA Error]', msg
+IF SUNDIALS_VERSION >= (7,0,0):
+    cdef void ida_err(int line, const char* func, const char* file, const char* msg, SUNErrCode error_code, void* problem_data, SUNContext sunctx):
+        """
+        This method overrides the default handling of error messages.
+        """
+        cdef ProblemData pData = <ProblemData>problem_data
+        
+        if error_code > 0 and pData.verbose > 0: #Warning
+            print '[IDA Warning]', msg
+        
+        if pData.verbose > 2: #Verbosity is greater than NORMAL, print warnings and errors
+            if error_code < 0: #Error
+                print '[IDA Error]', msg
+ELSE:
+    cdef void ida_err(int error_code, const char *module, const char *function, char *msg, void *problem_data):
+        """
+        This method overrides the default handling of error messages.
+        """
+        cdef ProblemData pData = <ProblemData>problem_data
+        
+        if error_code > 0 and pData.verbose > 0: #Warning
+            print '[IDA Warning]', msg
+        
+        if pData.verbose > 2: #Verbosity is greater than NORMAL, print warnings and errors
+            if error_code < 0: #Error
+                print '[IDA Error]', msg
 
 
 cdef class ProblemData:
diff --git a/src/lib/sundials_callbacks_kinsol.pxi b/src/lib/sundials_callbacks_kinsol.pxi
index fc73035..513077e 100644
--- a/src/lib/sundials_callbacks_kinsol.pxi
+++ b/src/lib/sundials_callbacks_kinsol.pxi
@@ -65,8 +65,13 @@ ELSE:
             return KINDLS_SUCCESS
         except Exception:
             return KINDLS_JACFUNC_RECVR #Recoverable Error (See Sundials description)
-            
-cdef int kin_jacv(N_Vector vv, N_Vector Jv, N_Vector vx, int* new_u,
+
+IF SUNDIALS_VERSION >= (6,0,0):
+    ctypedef bint kin_jacv_bool
+ELSE:
+    ctypedef int kin_jacv_bool
+
+cdef int kin_jacv(N_Vector vv, N_Vector Jv, N_Vector vx, kin_jacv_bool* new_u,
             void *problem_data):
     cdef ProblemDataEquationSolver pData = <ProblemDataEquationSolver>problem_data
     cdef N.ndarray x  = nv2arr(vx)
@@ -217,21 +222,40 @@ ELSE:
             return KIN_SYSFUNC_FAIL
         
         return KIN_SUCCESS
-        
 
-cdef void kin_err(int err_code, const char *module, const char *function, char *msg, void *eh_data):
-    cdef ProblemDataEquationSolver pData = <ProblemDataEquationSolver>eh_data
-    
-    if err_code > 0: #Warning
-        category = 1
-    elif err_code < 0: #Error
-        category = -1
-    else:
-        category = 0
-    
-    print "Error occured in <function: %s>."%function
-    print "<message: %s>"%msg
-    #print "<functionNorm: %g, scaledStepLength: %g, tolerance: %g>"%(fnorm, snorm, pData.TOL)
+IF SUNDIALS_VERSION >= (7,0,0):
+    cdef extern from "sundials/sundials_context.h":
+        ctypedef _SUNContext * SUNContext
+        cdef struct _SUNContext:
+            pass
+        ctypedef int SUNErrCode
+    cdef void kin_err(int line, const char* function, const char* file, const char* msg, SUNErrCode err_code, void* eh_data, SUNContext sunctx):
+        cdef ProblemDataEquationSolver pData = <ProblemDataEquationSolver>eh_data
+        
+        if err_code > 0: #Warning
+            category = 1
+        elif err_code < 0: #Error
+            category = -1
+        else:
+            category = 0
+        
+        print "Error occured in <function: %s>."%function
+        print "<message: %s>"%msg
+        #print "<functionNorm: %g, scaledStepLength: %g, tolerance: %g>"%(fnorm, snorm, pData.TOL)
+ELSE:
+    cdef void kin_err(int err_code, const char *module, const char *function, char *msg, void *eh_data):
+        cdef ProblemDataEquationSolver pData = <ProblemDataEquationSolver>eh_data
+        
+        if err_code > 0: #Warning
+            category = 1
+        elif err_code < 0: #Error
+            category = -1
+        else:
+            category = 0
+        
+        print "Error occured in <function: %s>."%function
+        print "<message: %s>"%msg
+        #print "<functionNorm: %g, scaledStepLength: %g, tolerance: %g>"%(fnorm, snorm, pData.TOL)
 
 
 cdef void kin_info(const char *module, const char *function, char *msg, void *eh_data):
diff --git a/src/lib/sundials_includes.pxd b/src/lib/sundials_includes.pxd
index f1e43ac..ea4c7c1 100644
--- a/src/lib/sundials_includes.pxd
+++ b/src/lib/sundials_includes.pxd
@@ -40,9 +40,22 @@ IF SUNDIALS_VERSION >= (6,0,0):
             pass
         int SUNContext_Create(void* comm, SUNContext* ctx)
 
-cdef extern from "sundials/sundials_types.h":
+IF SUNDIALS_VERSION >= (7,0,0):
+    cdef extern from "sundials/sundials_context.h":
+        ctypedef int SUNErrCode
+        ctypedef void (*SUNErrHandlerFn)(int line, const char* func, const char* file, const char* msg, SUNErrCode err_code, void* err_user_data, SUNContext sunctx)
+        SUNErrCode SUNContext_PushErrHandler(SUNContext sunctx, SUNErrHandlerFn err_fn, void* err_user_data)
+
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "sundials/sundials_types.h":
+        ctypedef double sunrealtype
+        ctypedef bint sunbooleantype
     ctypedef double realtype
-    ctypedef bint booleantype # should be bool instead of bint, but there is a bug in Cython
+    ctypedef bint booleantype
+ELSE:
+    cdef extern from "sundials/sundials_types.h":
+        ctypedef double realtype
+        ctypedef bint booleantype # should be bool instead of bint, but there is a bug in Cython
 
 cdef extern from "sundials/sundials_nvector.h":
     ctypedef _generic_N_Vector* N_Vector
@@ -333,12 +346,7 @@ cdef extern from "cvodes/cvodes.h":
     
     #Functions for retrieving results
     int CVodeGetDky(void *cvode_mem, realtype t, int k, N_Vector dky)
-    
-    #Functions for error handling
-    ctypedef void (*CVErrHandlerFn)(int error_code, const char *module, const char *function, char *msg,
-                                   void *eh_data)
-    int CVodeSetErrHandlerFn(void *cvode_mem, CVErrHandlerFn ehfun, void* eh_data)
-    
+
     #Functions for discontinuity handling
     ctypedef int (*CVRootFn)(realtype tt, N_Vector yy, realtype *gout, void *user_data)
     int CVodeRootDirection(void *cvode_mem, int *rootdir)
@@ -408,11 +416,31 @@ cdef extern from "cvodes/cvodes.h":
     int CVodeGetStgrSensNumNonlinSolvIters(void *cvode_mem, long int *nSTGR1niters)
     int CVodeGetStgrSensNumNonlinSolvConvFails(void *cvode_mem, long int *nSTGR1ncfails)
 
-cdef extern from "cvodes/cvodes_spils.h":
-    ctypedef int (*CVSpilsJacTimesVecFn)(N_Vector v, N_Vector Jv, realtype t,
-            N_Vector y, N_Vector fy, void *user_data, N_Vector tmp)
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "cvodes/cvodes_ls.h":
+        ctypedef int (*CVLsJacFn)(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix Jac, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
+        ctypedef int (*CVLsPrecSolveFn)(sunrealtype t, N_Vector y, N_Vector fy, N_Vector r, N_Vector z, sunrealtype gamma, sunrealtype delta, int lr, void* user_data);
+        ctypedef int (*CVLsJacTimesSetupFn)(realtype t, N_Vector y, N_Vector fy, void *user_data)
+        ctypedef int (*CVLsJacTimesVecFn)(N_Vector v, N_Vector Jv, realtype t, N_Vector y, N_Vector fy, void *user_data, N_Vector tmp)
+ELSE:
+    cdef extern from "cvodes/cvodes_spils.h":
+        ctypedef int (*CVSpilsJacTimesVecFn)(N_Vector v, N_Vector Jv, realtype t, N_Vector y, N_Vector fy, void *user_data, N_Vector tmp)
 
-IF SUNDIALS_VERSION >= (3,0,0):
+
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "cvode/cvode_ls.h":
+       int CVodeSetJacTimes(void *cvode_mem, CVLsJacTimesSetupFn jtsetup, CVLsJacTimesVecFn jtimes)
+       int CVodeSetLinearSolver(void* cvode_mem, SUNLinearSolver LS, SUNMatrix A)
+       int CVodeSetJacFn(void* cvode_mem, CVLsJacFn jac)
+    IF SUNDIALS_WITH_SUPERLU:
+        cdef extern from "sunlinsol/sunlinsol_superlumt.h":
+            SUNLinearSolver SUNLinSol_SuperLUMT(N_Vector y, SUNMatrix A, int num_threads, SUNContext ctx)
+    ELSE:
+        cdef inline SUNLinearSolver SUNLinSol_SuperLUMT(N_Vector y, SUNMatrix A, int num_threads, SUNContext ctx): return NULL
+
+    cdef inline int cv_spils_jtsetup_dummy(realtype t, N_Vector y, N_Vector fy, void *user_data): return 0
+    cdef inline tuple version(): return (6,0,0)
+ELIF SUNDIALS_VERSION >= (3,0,0):
     cdef extern from "cvodes/cvodes_direct.h":
         ctypedef int (*CVDlsDenseJacFn)(realtype t, N_Vector y, N_Vector fy, 
                        SUNMatrix Jac, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
@@ -435,21 +463,13 @@ IF SUNDIALS_VERSION >= (3,0,0):
 				  N_Vector r, N_Vector z,
 				  realtype gamma, realtype delta, int lr, void *user_data)
 
-
     IF SUNDIALS_WITH_SUPERLU:
-        IF SUNDIALS_VERSION >= (6,0,0):
-            cdef extern from "sunlinsol/sunlinsol_superlumt.h":
-                SUNLinearSolver SUNLinSol_SuperLUMT(N_Vector y, SUNMatrix A, int num_threads, SUNContext ctx)
-        ELSE:
-            cdef extern from "sunlinsol/sunlinsol_superlumt.h":
-                SUNLinearSolver SUNSuperLUMT(N_Vector y, SUNMatrix A, int num_threads)
+        cdef extern from "sunlinsol/sunlinsol_superlumt.h":
+            SUNLinearSolver SUNSuperLUMT(N_Vector y, SUNMatrix A, int num_threads)
     ELSE:
-        IF SUNDIALS_VERSION >= (6,0,0):
-            cdef inline SUNLinearSolver SUNLinSol_SuperLUMT(N_Vector y, SUNMatrix A, int num_threads, SUNContext ctx): return NULL
-        ELSE:
-            cdef inline SUNLinearSolver SUNSuperLUMT(N_Vector y, SUNMatrix A, int num_threads): return NULL
+         cdef inline SUNLinearSolver SUNSuperLUMT(N_Vector y, SUNMatrix A, int num_threads): return NULL
     
-    cdef inline int cv_spils_jtsetup_dummy(realtype t, N_Vector y, N_Vector fy, void *user_data): return 0    
+    cdef inline int cv_spils_jtsetup_dummy(realtype t, N_Vector y, N_Vector fy, void *user_data): return 0
     cdef inline tuple version(): return (3,0,0)
 ELSE:
     cdef extern from "cvodes/cvodes_dense.h":
@@ -494,20 +514,29 @@ ELSE:
         cdef inline int CVSlsSetSparseJacFn(void *cvode_mem, CVSlsSparseJacFn jac): return -1
         cdef inline int CVSlsGetNumJacEvals(void *cvode_mem, long int *njevals): return -1
         cdef inline tuple version(): return (2,5,0)
-    
-cdef extern from "cvodes/cvodes_spils.h":
-    IF SUNDIALS_VERSION >= (4,0,0):
-        int CVodeSetPreconditioner(void *cvode_mem, CVSpilsPrecSetupFn psetup, CVSpilsPrecSolveFn psolve)
+
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "cvode/cvode_ls.h":
+        ctypedef int (*CVLsPrecSetupFn)(sunrealtype t, N_Vector y, N_Vector fy, sunbooleantype jok, sunbooleantype* jcurPtr, sunrealtype gamma, void* user_data)
+        int CVodeSetPreconditioner(void* cvode_mem, CVLsPrecSetupFn pset, CVLsPrecSolveFn psolve)
         int CVodeGetNumJtimesEvals(void *cvode_mem, long int *njvevals) #Number of jac*vector evals
         int CVodeGetNumRhsEvals(void *cvode_mem, long int *nfevalsLS) #Number of res evals due to jacÄvector evals
         int CVodeGetNumPrecEvals(void *cvode_mem, long int *npevals)
         int CVodeGetNumPrecSolves(void *cvode_mem, long int *npsolves)
-    ELSE:
-        int CVSpilsSetPreconditioner(void *cvode_mem, CVSpilsPrecSetupFn psetup, CVSpilsPrecSolveFn psolve)
-        int CVSpilsGetNumJtimesEvals(void *cvode_mem, long int *njvevals) #Number of jac*vector evals
-        int CVSpilsGetNumRhsEvals(void *cvode_mem, long int *nfevalsLS) #Number of res evals due to jacÄvector evals
-        int CVSpilsGetNumPrecEvals(void *cvode_mem, long int *npevals)
-        int CVSpilsGetNumPrecSolves(void *cvode_mem, long int *npsolves)
+ELSE:
+    cdef extern from "cvodes/cvodes_spils.h":
+        IF SUNDIALS_VERSION >= (4,0,0):
+            int CVodeSetPreconditioner(void *cvode_mem, CVSpilsPrecSetupFn psetup, CVSpilsPrecSolveFn psolve)
+            int CVodeGetNumJtimesEvals(void *cvode_mem, long int *njvevals) #Number of jac*vector evals
+            int CVodeGetNumRhsEvals(void *cvode_mem, long int *nfevalsLS) #Number of res evals due to jacÄvector evals
+            int CVodeGetNumPrecEvals(void *cvode_mem, long int *npevals)
+            int CVodeGetNumPrecSolves(void *cvode_mem, long int *npsolves)
+        ELSE:
+            int CVSpilsSetPreconditioner(void *cvode_mem, CVSpilsPrecSetupFn psetup, CVSpilsPrecSolveFn psolve)
+            int CVSpilsGetNumJtimesEvals(void *cvode_mem, long int *njvevals) #Number of jac*vector evals
+            int CVSpilsGetNumRhsEvals(void *cvode_mem, long int *nfevalsLS) #Number of res evals due to jacÄvector evals
+            int CVSpilsGetNumPrecEvals(void *cvode_mem, long int *npevals)
+            int CVSpilsGetNumPrecSolves(void *cvode_mem, long int *npsolves)
 
 cdef extern from "idas/idas.h":
     ctypedef int (*IDAResFn)(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, void *user_data)
@@ -536,13 +565,7 @@ cdef extern from "idas/idas.h":
     
     #Functions for retrieving results
     int IDAGetDky(void *ida_mem, realtype t, int k, N_Vector dky)
-    
-    #Functions for error handling
-    ctypedef void (*IDAErrHandlerFn)(int error_code, const char *module, const char *function, char *msg,
-                                    void *eh_data)
-    int IDASetErrHandlerFn(void *ida_mem,IDAErrHandlerFn ehfun, void* eh_data)
-    
-    
+
     #Functions for discontinuity handling
     ctypedef int (*IDARootFn)(realtype tt, N_Vector yy, N_Vector yp, realtype *gout, void *user_data)
     int IDASetRootDirection(void *ida_mem, int *rootdir)
@@ -615,12 +638,25 @@ cdef extern from "idas/idas.h":
     
     #End Sensitivities
     #=================
-    
-cdef extern from "idas/idas_spils.h":
-    ctypedef int (*IDASpilsJacTimesVecFn)(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, 
-            N_Vector v, N_Vector Jv, realtype cj, void *user_data,N_Vector tmp1, N_Vector tmp2)
-    
-IF SUNDIALS_VERSION >= (3,0,0):
+
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "ida/ida_ls.h":
+        ctypedef int (*IDALsJacFn)(sunrealtype t, sunrealtype c_j, N_Vector y, N_Vector yp, N_Vector r, SUNMatrix Jac, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
+        ctypedef int (*IDALsJacTimesSetupFn)(sunrealtype tt, N_Vector yy, N_Vector yp, N_Vector rr, sunrealtype c_j, void* user_data)
+        ctypedef int (*IDALsJacTimesVecFn)(sunrealtype tt, N_Vector yy, N_Vector yp, N_Vector rr, N_Vector v, N_Vector Jv, sunrealtype c_j, void* user_data, N_Vector tmp1, N_Vector tmp2)
+ELSE:
+    cdef extern from "idas/idas_spils.h":
+        ctypedef int (*IDASpilsJacTimesVecFn)(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, 
+                N_Vector v, N_Vector Jv, realtype cj, void *user_data,N_Vector tmp1, N_Vector tmp2)
+
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "ida/ida_ls.h":
+        int IDASetJacFn(void* ida_mem, IDALsJacFn jac)
+        int IDASetLinearSolver(void* ida_mem, SUNLinearSolver LS, SUNMatrix A)
+        int IDASetJacTimes(void* ida_mem, IDALsJacTimesSetupFn jtsetup, IDALsJacTimesVecFn jtimes)
+
+    cdef inline int ida_spils_jtsetup_dummy(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, realtype c_j, void *user_data): return 0
+ELIF SUNDIALS_VERSION >= (3,0,0):
     cdef extern from "idas/idas_direct.h":
         ctypedef int (*IDADlsDenseJacFn)(realtype tt, realtype cj, N_Vector yy, 
                        N_Vector yp, N_Vector rr, SUNMatrix Jac, void *user_data, 
@@ -660,13 +696,18 @@ ELSE:
     cdef extern from "idas/idas_spils.h":
         int IDASpilsSetJacTimesVecFn(void *ida_mem, IDASpilsJacTimesVecFn ida_jacv)
 
-cdef extern from "idas/idas_spils.h":
-    IF SUNDIALS_VERSION >= (4,0,0):
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "ida/ida_ls.h": 
         int IDAGetNumJtimesEvals(void *ida_mem, long int *njvevals) #Number of jac*vector
         int IDAGetNumResEvals(void *ida_mem, long int *nfevalsLS) #Number of rhs due to jac*vector
-    ELSE:
-        int IDASpilsGetNumJtimesEvals(void *ida_mem, long int *njvevals) #Number of jac*vector
-        int IDASpilsGetNumResEvals(void *ida_mem, long int *nfevalsLS) #Number of rhs due to jac*vector
+ELSE:
+    cdef extern from "idas/idas_spils.h":
+        IF SUNDIALS_VERSION >= (4,0,0):
+            int IDAGetNumJtimesEvals(void *ida_mem, long int *njvevals) #Number of jac*vector
+            int IDAGetNumResEvals(void *ida_mem, long int *nfevalsLS) #Number of rhs due to jac*vector
+        ELSE:
+            int IDASpilsGetNumJtimesEvals(void *ida_mem, long int *njvevals) #Number of jac*vector
+            int IDASpilsGetNumResEvals(void *ida_mem, long int *nfevalsLS) #Number of rhs due to jac*vector
 
 
 ####################
@@ -678,7 +719,6 @@ cdef extern from "idas/idas_spils.h":
 cdef extern from "kinsol/kinsol.h":
     # user defined functions
     ctypedef int (*KINSysFn)(N_Vector uu, N_Vector fval, void *user_data )
-    ctypedef void (*KINErrHandlerFn)(int error_code, char *module, char *function, char *msg, void *user_data)
     ctypedef void (*KINInfoHandlerFn)(const char *module, const char *function, char *msg, void *user_data)
     # initialization routines
     IF SUNDIALS_VERSION >= (6,0,0):
@@ -689,7 +729,6 @@ cdef extern from "kinsol/kinsol.h":
 
     # optional input spec. functions,
     # for specificationsdocumentation cf. kinsol.h line 218-449
-    int KINSetErrHandlerFn(void *kinmem, KINErrHandlerFn ehfun, void *eh_data)
     int KINSetInfoHandlerFn(void *kinmem, KINInfoHandlerFn ihfun, void *ih_data)
     int KINSetUserData(void *kinmem, void *user_data)
     int KINSetPrintLevel(void *kinmemm, int printfl)
@@ -729,8 +768,24 @@ cdef extern from "kinsol/kinsol.h":
     # fuction used to deallocate memory used by KINSOL
     void KINFree(void **kinmem)
 
+# Functions for error handling
+IF SUNDIALS_VERSION < (7,0,0):
+    cdef extern from "kinsol/kinsol.h":
+        ctypedef void (*KINErrHandlerFn)(int error_code, char *module, char *function, char *msg, void *user_data)
+        int KINSetErrHandlerFn(void *kinmem, KINErrHandlerFn ehfun, void *eh_data)
+    cdef extern from "cvodes/cvodes.h":
+        ctypedef void (*CVErrHandlerFn)(int error_code, const char *module, const char *function, char *msg, void *eh_data)
+        int CVodeSetErrHandlerFn(void *cvode_mem, CVErrHandlerFn ehfun, void* eh_data)
+    cdef extern from "idas/idas.h":
+        ctypedef void (*IDAErrHandlerFn)(int error_code, const char *module, const char *function, char *msg, void *eh_data) 
+        int IDASetErrHandlerFn(void *ida_mem,IDAErrHandlerFn ehfun, void* eh_data)
 
-IF SUNDIALS_VERSION >= (3,0,0):
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "kinsol/kinsol_ls.h":
+        ctypedef int (*KINLsJacFn)(N_Vector u, N_Vector fu, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2)
+        int KINSetLinearSolver(void* kinmem, SUNLinearSolver LS, SUNMatrix A)
+        int KINSetJacFn(void* kinmem, KINLsJacFn jac)
+ELIF SUNDIALS_VERSION >= (3,0,0):
     cdef extern from "kinsol/kinsol_direct.h":
         ctypedef int (*KINDlsDenseJacFn)(N_Vector u, N_Vector fu, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2)
         IF SUNDIALS_VERSION < (4,0,0):
@@ -767,40 +822,57 @@ ELSE:
         ctypedef int (*KINSpilsPrecSetupFn)(N_Vector u, N_Vector uscale,
                     N_Vector fval, N_Vector fscale, void *problem_data, N_Vector tmp1, N_Vector tmp2)
 
-cdef extern from "kinsol/kinsol_direct.h":
-    # optional output fcts for linear direct solver
-    int KINDlsGetWorkSpace(void *kinmem, long int *lenrwB, long int *leniwB)
-    IF SUNDIALS_VERSION >= (4,0,0):
+IF SUNDIALS_VERSION >= (6,0,0):
+    cdef extern from "kinsol/kinsol_ls.h":
+        ctypedef int (*KINLsPrecSetupFn)(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, void* user_data);
+        ctypedef int (*KINLsPrecSolveFn)(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, N_Vector vv, void* user_data)
+        ctypedef int (*KINLsJacTimesVecFn)(N_Vector v, N_Vector Jv, N_Vector uu, sunbooleantype* new_uu, void* J_data)
         int KINGetLastLinFlag(void *kinmem, long int *flag)
         int KINGetNumJacEvals(void *kinmem, long int *njevalsB)
         int KINGetNumFuncEvals(void *kinmem, long int *nfevalsB)
-    ELSE:
-        int KINDlsGetLastFlag(void *kinmem, long int *flag)
-        int KINDlsGetNumJacEvals(void *kinmem, long int *njevalsB)
-        int KINDlsGetNumFuncEvals(void *kinmem, long int *nfevalsB)
-    char *KINDlsGetReturnFlagName(int flag)
-
-cdef extern from "kinsol/kinsol_spils.h":
-    ctypedef int (*KINSpilsJacTimesVecFn)(N_Vector vv, N_Vector Jv, N_Vector vx, int* new_u,
-                void *problem_data)
-    IF SUNDIALS_VERSION >= (4,0,0):
-        int KINSetJacTimesVecFn(void *kinmem, KINSpilsJacTimesVecFn jacv)
-        int KINSetPreconditioner(void *kinmem, KINSpilsPrecSetupFn psetup, KINSpilsPrecSolveFn psolve)
+        int KINSetJacTimesVecFn(void* kinmem, KINLsJacTimesVecFn jtv)
+        int KINSetPreconditioner(void* kinmem, KINLsPrecSetupFn psetup, KINLsPrecSolveFn psolve)
         int KINGetNumLinIters(void *kinmem, long int *nliters)
         int KINGetNumLinConvFails(void *kinmem, long int *nlcfails)
         int KINGetNumPrecEvals(void *kinmem, long int *npevals)
         int KINGetNumPrecSolves(void *kinmem, long int *npsolves)
         int KINGetNumJtimesEvals(void *kinmem, long int *njevals)
         int KINGetNumFuncEvals(void *kinmem, long int *nfevalsLS)
-    ELSE:
-        int KINSpilsSetJacTimesVecFn(void *kinmem, KINSpilsJacTimesVecFn jacv)
-        int KINSpilsSetPreconditioner(void *kinmem, KINSpilsPrecSetupFn psetup, KINSpilsPrecSolveFn psolve)
-        int KINSpilsGetNumLinIters(void *kinmem, long int *nliters)
-        int KINSpilsGetNumConvFails(void *kinmem, long int *nlcfails)
-        int KINSpilsGetNumPrecEvals(void *kinmem, long int *npevals)
-        int KINSpilsGetNumPrecSolves(void *kinmem, long int *npsolves)
-        int KINSpilsGetNumJtimesEvals(void *kinmem, long int *njevals)
-        int KINSpilsGetNumFuncEvals(void *kinmem, long int *nfevalsLS)
+ELSE:
+    cdef extern from "kinsol/kinsol_direct.h":
+        # optional output fcts for linear direct solver
+        int KINDlsGetWorkSpace(void *kinmem, long int *lenrwB, long int *leniwB)
+        IF SUNDIALS_VERSION >= (4,0,0):
+            int KINGetLastLinFlag(void *kinmem, long int *flag)
+            int KINGetNumJacEvals(void *kinmem, long int *njevalsB)
+            int KINGetNumFuncEvals(void *kinmem, long int *nfevalsB)
+        ELSE:
+            int KINDlsGetLastFlag(void *kinmem, long int *flag)
+            int KINDlsGetNumJacEvals(void *kinmem, long int *njevalsB)
+            int KINDlsGetNumFuncEvals(void *kinmem, long int *nfevalsB)
+        char *KINDlsGetReturnFlagName(int flag)
+
+    cdef extern from "kinsol/kinsol_spils.h":
+        ctypedef int (*KINSpilsJacTimesVecFn)(N_Vector vv, N_Vector Jv, N_Vector vx, int* new_u,
+                    void *problem_data)
+        IF SUNDIALS_VERSION >= (4,0,0):
+            int KINSetJacTimesVecFn(void *kinmem, KINSpilsJacTimesVecFn jacv)
+            int KINSetPreconditioner(void *kinmem, KINSpilsPrecSetupFn psetup, KINSpilsPrecSolveFn psolve)
+            int KINGetNumLinIters(void *kinmem, long int *nliters)
+            int KINGetNumLinConvFails(void *kinmem, long int *nlcfails)
+            int KINGetNumPrecEvals(void *kinmem, long int *npevals)
+            int KINGetNumPrecSolves(void *kinmem, long int *npsolves)
+            int KINGetNumJtimesEvals(void *kinmem, long int *njevals)
+            int KINGetNumFuncEvals(void *kinmem, long int *nfevalsLS)
+        ELSE:
+            int KINSpilsSetJacTimesVecFn(void *kinmem, KINSpilsJacTimesVecFn jacv)
+            int KINSpilsSetPreconditioner(void *kinmem, KINSpilsPrecSetupFn psetup, KINSpilsPrecSolveFn psolve)
+            int KINSpilsGetNumLinIters(void *kinmem, long int *nliters)
+            int KINSpilsGetNumConvFails(void *kinmem, long int *nlcfails)
+            int KINSpilsGetNumPrecEvals(void *kinmem, long int *npevals)
+            int KINSpilsGetNumPrecSolves(void *kinmem, long int *npsolves)
+            int KINSpilsGetNumJtimesEvals(void *kinmem, long int *njevals)
+            int KINSpilsGetNumFuncEvals(void *kinmem, long int *nfevalsLS)
 
 #=========================
 # END SUNDIALS DEFINITIONS
diff --git a/src/solvers/kinsol.pyx b/src/solvers/kinsol.pyx
index c0dbaac..32ed0f0 100644
--- a/src/solvers/kinsol.pyx
+++ b/src/solvers/kinsol.pyx
@@ -210,7 +210,10 @@ cdef class KINSOL(Algebraic):
                 raise KINSOLError(flag)
             
             #Specify the error handling
-            flag = SUNDIALS.KINSetErrHandlerFn(self.kinsol_mem, kin_err, <void*>self.pData)
+            IF SUNDIALS_VERSION >= (7,0,0):
+                flag = SUNDIALS.SUNContext_PushErrHandler(ctx, kin_err, <void*>self.pData)
+            ELSE:
+                flag = SUNDIALS.KINSetErrHandlerFn(self.kinsol_mem, kin_err, <void*>self.pData)
             if flag < 0:
                 raise KINSOLError(flag) 
                 
diff --git a/src/solvers/sundials.pyx b/src/solvers/sundials.pyx
index 85aa459..20b96be 100644
--- a/src/solvers/sundials.pyx
+++ b/src/solvers/sundials.pyx
@@ -342,7 +342,10 @@ cdef class IDA(Implicit_ODE):
                     raise IDAError(flag,self.t)
             
             #Specify the error handling
-            flag = SUNDIALS.IDASetErrHandlerFn(self.ida_mem, ida_err, <void*>self.pData)
+            IF SUNDIALS_VERSION >= (7,0,0):
+                flag = SUNDIALS.SUNContext_PushErrHandler(ctx, ida_err, <void*>self.pData)
+            ELSE:
+                flag = SUNDIALS.IDASetErrHandlerFn(self.ida_mem, ida_err, <void*>self.pData)
             if flag < 0:
                 raise IDAError(flag, self.t)
                 
@@ -1843,7 +1846,10 @@ cdef class CVode(Explicit_ODE):
                     raise CVodeError(flag, self.t)
                     
             #Specify the error handling
-            flag = SUNDIALS.CVodeSetErrHandlerFn(self.cvode_mem, cv_err, <void*>self.pData)
+            IF SUNDIALS_VERSION >= (7,0,0):
+                flag = SUNDIALS.SUNContext_PushErrHandler(ctx, cv_err, <void*>self.pData)
+            ELSE:
+                flag = SUNDIALS.CVodeSetErrHandlerFn(self.cvode_mem, cv_err, <void*>self.pData)
             if flag < 0:
                 raise CVodeError(flag, self.t)