summarylogtreecommitdiffstats
path: root/numpy1.20.patch
blob: 5198d2d7997a857f97e70cd15eec5383f95016b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
From 75ea0b31477d6ba9e990e296bbbd8ca4e7eebadf Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg@google.com>
Date: Fri, 26 Jun 2020 05:08:10 -0700
Subject: [PATCH] Provide overload to cope with const-ness change of NumPy's
 PyUFuncGenericFunction.

See https://github.com/tensorflow/tensorflow/issues/40688, https://github.com/tensorflow/tensorflow/pull/40654.

PiperOrigin-RevId: 318452381
Change-Id: Icc5152f2b020ef19882a49e3c86ac80bbe048d64
---
 tensorflow/python/lib/core/bfloat16.cc | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index feb01f11a1af2..bb6b720febe59 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -517,7 +517,7 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
 }
 
 template <typename InType, typename OutType, typename Functor>
-void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
+void BinaryUFunc(char** args, const npy_intp* dimensions, const npy_intp* steps,
                  void* data) {
   const char* i0 = args[0];
   const char* i1 = args[1];
@@ -532,11 +532,17 @@ void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
   }
 }
 
+// Numpy changed const-ness of PyUFuncGenericFunction, provide overload.
 template <typename Functor>
 void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
                   void* data) {
   BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
 }
+template <typename Functor>
+void CompareUFunc(char** args, const npy_intp* dimensions,
+                  const npy_intp* steps, void* data) {
+  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
+}
 
 struct Bfloat16EqFunctor {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }