1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
diff --git a/kt-sft/setup.py b/kt-sft/setup.py
index 632a718..6df384e 100644
--- a/kt-sft/setup.py
+++ b/kt-sft/setup.py
@@ -249,7 +249,7 @@ class VersionInfo:
def get_flash_version(self,):
version_file = os.path.join(
- Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
+ Path(VersionInfo.THIS_DIR).parent, "version.py")
with open(version_file, "r", encoding="utf-8") as f:
version_match = re.search(
r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
@@ -633,6 +633,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
'csrc/ktransformers_ext/cuda/binding.cpp',
'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
],
+ define_macros=[("GLOG_USE_GLOG_EXPORT", None)],
extra_compile_args={
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
'nvcc': [
@@ -681,6 +682,7 @@ if not torch.xpu.is_available():
'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
],
+ define_macros=[("GLOG_USE_GLOG_EXPORT", None)],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
|