Skip to content

Commit

Permalink
fix build on mac for real now?
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Nov 12, 2024
1 parent d04d689 commit a3783b1
Showing 1 changed file with 54 additions and 2 deletions.
56 changes: 54 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from glob import glob

from setuptools import Extension, setup
from setuptools._distutils._log import log
from setuptools._distutils._modified import newer_group
from setuptools.command.build_ext import build_ext

CWD = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -55,14 +57,64 @@
sources=["klujax.cpp", *suitesparse_sources],
include_dirs=include_dirs,
library_dirs=site.getsitepackages(),
extra_compile_args=[], # clang defaults to c++17 and setting -std=c++17 prevents combined build with suitesparse c source.
extra_compile_args=["-std=c++17"],
extra_link_args=[],
language="c++",
)
else:
raise RuntimError(f"Platform {sys.platform} not supported.")


# Custom BuildExt to enable combined build of C and C++ files on MacOs (clang)
# However, this class also removes some warnings when used on linux (gcc) and
# Windows (cl) so we use it everywhere.
class BuildExt(build_ext):
def build_extension(self, ext):
sources = ext.sources
c_sources = sorted([s for s in sources if s.endswith("c")])
cpp_sources = sorted([s for s in sources if s not in c_sources])
ext_path = self.get_ext_fullpath(ext.name)
macros = ext.define_macros[:]
for undef in ext.undef_macros:
macros.append((undef,))
c_objects = self.compiler.compile(
c_sources,
output_dir=self.build_temp,
macros=macros,
include_dirs=ext.include_dirs,
debug=self.debug,
extra_postargs=[
f
for f in ext.extra_compile_args
if not f in ["-std=c++17", "/std:c++17"] # THIS IS OUR HACK
],
depends=ext.depends,
)
cpp_objects = self.compiler.compile(
cpp_sources,
output_dir=self.build_temp,
macros=macros,
include_dirs=ext.include_dirs,
debug=self.debug,
extra_postargs=ext.extra_compile_args,
depends=ext.depends,
)
objects = c_objects + cpp_objects
extra_args = ext.extra_link_args or []
self.compiler.link_shared_object(
objects,
ext_path,
libraries=self.get_libraries(ext),
library_dirs=ext.library_dirs,
runtime_library_dirs=ext.runtime_library_dirs,
extra_postargs=extra_args,
export_symbols=self.get_export_symbols(ext),
debug=self.debug,
build_temp=self.build_temp,
target_lang=ext.language,
)


setup(
name="klujax",
version="0.2.10",
Expand All @@ -74,7 +126,7 @@
url="https://github.com/flaport/klujax",
py_modules=["klujax"],
ext_modules=[extension],
cmdclass={"build_ext": build_ext},
cmdclass={"build_ext": BuildExt},
install_requires=["jax>=0.4.35", "jaxlib>=0.4.35"],
python_requires=">=3.8",
classifiers=[
Expand Down

0 comments on commit a3783b1

Please sign in to comment.