@@ -13,6 +13,12 @@ use std::path::PathBuf;
13
13
use glob:: glob;
14
14
use which:: which;
15
15
16
+ const PYTHON_PRINT_DIRS : & str = r"
17
+ import sysconfig
18
+ print('PYTHON_INCLUDE_DIR:', sysconfig.get_config_var('INCLUDEDIR'))
19
+ print('PYTHON_LIB_DIR:', sysconfig.get_config_var('LIBDIR'))
20
+ " ;
21
+
16
22
// Translated from torch/utils/cpp_extension.py
17
23
fn find_cuda_home ( ) -> Option < String > {
18
24
// Guess #1
@@ -52,34 +58,50 @@ fn find_cuda_home() -> Option<String> {
52
58
cuda_home
53
59
}
54
60
55
- fn main ( ) {
56
- let cuda_home = find_cuda_home ( ) . expect ( "Could not find CUDA installation" ) ;
57
-
58
- // Tell cargo to look for shared libraries in the CUDA directory
59
- println ! ( "cargo:rustc-link-search={}/lib64" , cuda_home) ;
60
- println ! ( "cargo:rustc-link-search={}/lib" , cuda_home) ;
61
+ fn emit_cuda_link_directives ( cuda_home : & str ) {
62
+ let stubs_path = format ! ( "{}/lib64/stubs" , cuda_home) ;
63
+ if Path :: new ( & stubs_path) . exists ( ) {
64
+ println ! ( "cargo:rustc-link-search=native={}" , stubs_path) ;
65
+ } else {
66
+ let lib64_path = format ! ( "{}/lib64" , cuda_home) ;
67
+ if Path :: new ( & lib64_path) . exists ( ) {
68
+ println ! ( "cargo:rustc-link-search=native={}" , lib64_path) ;
69
+ }
70
+ }
61
71
62
- // Link against the CUDA libraries
63
72
println ! ( "cargo:rustc-link-lib=cuda" ) ;
64
73
println ! ( "cargo:rustc-link-lib=cudart" ) ;
74
+ }
65
75
66
- // Tell cargo to invalidate the built crate whenever the wrapper changes
67
- println ! ( "cargo:rerun-if-changed=src/wrapper.h" ) ;
76
+ fn python_env_dirs ( ) -> ( Option < String > , Option < String > ) {
77
+ let output = std:: process:: Command :: new ( PathBuf :: from ( "python" ) )
78
+ . arg ( "-c" )
79
+ . arg ( PYTHON_PRINT_DIRS )
80
+ . output ( )
81
+ . unwrap_or_else ( |_| panic ! ( "error running python" ) ) ;
68
82
69
- // Add cargo metadata
70
- println ! ( "cargo:rustc-cfg=cargo" ) ;
71
- println ! ( "cargo:rustc-check-cfg=cfg(cargo)" ) ;
83
+ let mut include_dir = None ;
84
+ let mut lib_dir = None ;
85
+ for line in String :: from_utf8_lossy ( & output. stdout ) . lines ( ) {
86
+ if let Some ( path) = line. strip_prefix ( "PYTHON_INCLUDE_DIR: " ) {
87
+ include_dir = Some ( path. to_string ( ) ) ;
88
+ }
89
+ if let Some ( path) = line. strip_prefix ( "PYTHON_LIB_DIR: " ) {
90
+ lib_dir = Some ( path. to_string ( ) ) ;
91
+ }
92
+ }
93
+ ( include_dir, lib_dir)
94
+ }
72
95
73
- // The bindgen::Builder is the main entry point to bindgen
74
- let bindings = bindgen:: Builder :: default ( )
96
+ fn main ( ) {
97
+ let mut builder = bindgen:: Builder :: default ( )
75
98
// The input header we would like to generate bindings for
76
99
. header ( "src/wrapper.h" )
77
- // Add the CUDA include directory
78
- . clang_arg ( format ! ( "-I{}/include" , cuda_home) )
79
- // Parse as C++
80
100
. clang_arg ( "-x" )
81
101
. clang_arg ( "c++" )
82
102
. clang_arg ( "-std=gnu++20" )
103
+ . clang_arg ( format ! ( "-I{}/include" , find_cuda_home( ) . unwrap( ) ) )
104
+ . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks :: new ( ) ) )
83
105
// Allow the specified functions and types
84
106
. allowlist_function ( "cu.*" )
85
107
. allowlist_function ( "CU.*" )
@@ -89,16 +111,33 @@ fn main() {
89
111
. default_enum_style ( bindgen:: EnumVariation :: NewType {
90
112
is_bitfield : false ,
91
113
is_global : false ,
92
- } )
93
- // Finish the builder and generate the bindings
94
- . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks :: new ( ) ) )
95
- . generate ( )
96
- // Unwrap the Result and panic on failure
97
- . expect ( "Unable to generate bindings" ) ;
114
+ } ) ;
115
+
116
+ // Include headers and libs from the active environment.
117
+ let ( include_dir, lib_dir) = python_env_dirs ( ) ;
118
+ if let Some ( include_dir) = include_dir {
119
+ builder = builder. clang_arg ( format ! ( "-I{}" , include_dir) ) ;
120
+ }
121
+ if let Some ( lib_dir) = lib_dir {
122
+ println ! ( "cargo::rustc-link-search=native={}" , lib_dir) ;
123
+ // Set cargo metadata to inform dependent binaries about how to set their
124
+ // RPATH (see controller/build.rs for an example).
125
+ println ! ( "cargo::metadata=LIB_PATH={}" , lib_dir) ;
126
+ }
127
+ if let Some ( cuda_home) = find_cuda_home ( ) {
128
+ emit_cuda_link_directives ( & cuda_home) ;
129
+ }
98
130
99
131
// Write the bindings to the $OUT_DIR/bindings.rs file
100
132
let out_path = PathBuf :: from ( env:: var ( "OUT_DIR" ) . unwrap ( ) ) ;
101
- bindings
133
+ builder
134
+ . generate ( )
135
+ . expect ( "Unable to generate bindings" )
102
136
. write_to_file ( out_path. join ( "bindings.rs" ) )
103
137
. expect ( "Couldn't write bindings!" ) ;
138
+
139
+ println ! ( "cargo:rustc-link-lib=cuda" ) ;
140
+ println ! ( "cargo:rustc-link-lib=cudart" ) ;
141
+ println ! ( "cargo::rustc-cfg=cargo" ) ;
142
+ println ! ( "cargo::rustc-check-cfg=cfg(cargo)" ) ;
104
143
}
0 commit comments