summarylogtreecommitdiffstats
path: root/setup-with-glog.patch
blob: 41e6c3927c3e806ad211c2af45a351c7d1e75d4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
diff --git a/kt-sft/setup.py b/kt-sft/setup.py
index 632a718..6df384e 100644
--- a/kt-sft/setup.py
+++ b/kt-sft/setup.py
@@ -249,7 +249,7 @@ class VersionInfo:
 
     def get_flash_version(self,):
         version_file = os.path.join(
-            Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
+            Path(VersionInfo.THIS_DIR).parent, "version.py")
         with open(version_file, "r", encoding="utf-8") as f:
             version_match = re.search(
                 r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
@@ -633,6 +633,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
         'csrc/ktransformers_ext/cuda/binding.cpp',
         'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
     ],
+    define_macros=[("GLOG_USE_GLOG_EXPORT", None)],
     extra_compile_args={
             'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
             'nvcc': [
@@ -681,6 +682,7 @@ if not torch.xpu.is_available():
                 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
                 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
             ],
+            define_macros=[("GLOG_USE_GLOG_EXPORT", None)],
             extra_compile_args={
                 'cxx': ['-O3'],
                 'nvcc': ['-O3', '-Xcompiler', '-fPIC'],