diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index a479501..f86cea3 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -19,11 +19,11 @@ jobs: matrix: cfg: - conda-env: env - python-version: 3.8 + python-version: "3.11" os: ubuntu-latest - conda-env: env - python-version: 3.8 + python-version: "3.11" os: macOS-latest env: @@ -42,18 +42,10 @@ jobs: ulimit -a - name: Create Environment - uses: conda-incubator/setup-miniconda@v2.1.1 + uses: mamba-org/setup-micromamba@v1 with: - activate-environment: test environment-file: devtools/conda-envs/${{ matrix.cfg.conda-env }}.yaml - python-version: ${{ matrix.cfg.python-version }} - auto-update-conda: true - auto-activate-base: false - show-channel-urls: true - mamba-version: "*" - miniforge-version: latest - miniforge-variant: Mambaforge - use-mamba: true + create-args: python=${{ matrix.cfg.python-version }} - name: Environment Information shell: bash -l {0} @@ -64,12 +56,12 @@ jobs: - name: Install NAGL-MBIS shell: bash -l {0} run: | - python setup.py develop --no-deps + pip install -e . 
--no-build-isolation - name: PyTest shell: bash -l {0} run: | - pytest -v --cov=naglmbis --cov-config=setup.cfg naglmbis/tests/ --cov-report=xml + pytest -v --cov=naglmbis --cov-config=pyproject.toml naglmbis/tests/ --cov-report=xml --color=yes - name: Codecov uses: codecov/codecov-action@v1 diff --git a/.github/workflows/Lint.yaml b/.github/workflows/Lint.yaml index a599508..b47e9f4 100644 --- a/.github/workflows/Lint.yaml +++ b/.github/workflows/Lint.yaml @@ -7,28 +7,32 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7] + python-version: [3.11] steps: - uses: actions/checkout@v2 - name: Python Setup - uses: actions/setup-python@v1 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - - name: Create Environment + - name: Install lint shell: bash run: | - python -m pip install --upgrade pip - python setup.py develop --no-deps + pip install black isort Flake8-pyproject - - name: Install black + - name: Run black shell: bash run: | - pip install black + black naglmbis --check - - name: Run black + - name: Run isort + shell: bash + run: | + isort --check-only naglmbis + + - name: Run flake8 shell: bash run: | - black naglmbis --check \ No newline at end of file + flake8 naglmbis \ No newline at end of file diff --git a/README.md b/README.md index 88d0219..800ee9a 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,10 @@ package by SimonBoothroyd. 
## Installation -The required dependencies to run these models can be installed using ``conda``: +The required dependencies to run these models can be installed using ``mamba`` and the provided environment file: ```bash -conda install -c conda-forge -c dglteam nagl "dgl >=0.7" openff-toolkit pytorch-lightning qubekit +mamba env create -f devtools/conda-envs/env.yaml ``` You will then need to install this package from source, first clone the repository from github: @@ -25,22 +25,23 @@ cd nagl-mbis With the nagl environment activate install the models via: ```bash -python setup.py install +pip install -e . --no-build-isolation ``` ## Quick start -NAGL-MBIS offers some pre-trained models to compute conformation independent MBIS charges and volumes, these can be loaded +NAGL-MBIS offers some pre-trained models to compute conformation independent MBIS charges, these can be loaded using the following code in a script ```python -from naglmbis.models import load_volume_model, load_charge_model +from naglmbis.models import load_charge_model -# load two pre-trained models -charge_model = load_charge_model(charge_model=1) -volume_model = load_volume_model(volume_model=1) +# load two pre-trained charge models +charge_model = load_charge_model(charge_model="nagl-v1-mbis") +# load a model trained to scf dipole and mbis charges +charge_model_2 = load_charge_model(charge_model="nagl-v1-mbis-dipole") ``` -we can then use these models to predict the corresponding properties for a given [openff-toolkit](https://github.com/openforcefield/openff-toolkit) [Molecule object](https://docs.openforcefield.org/projects/toolkit/en/stable/users/molecule_cookbook.html#cookbook-every-way-to-make-a-molecule). 
+we can then use these models to predict the corresponding properties for a given [openff-toolkit](https://github.com/openforcefield/openff-toolkit) [Molecule object](https://docs.openforcefield.org/projects/toolkit/en/stable/users/molecule_cookbook.html#cookbook-every-way-to-make-a-molecule) or rdkit `Chem.Mol`. ```python from openff.toolkit.topology import Molecule @@ -48,10 +49,11 @@ from openff.toolkit.topology import Molecule # create ethanol ethanol = Molecule.from_smiles("CCO") # predict the charges (in e) and atomic volumes in (bohr ^3) -charges = charge_model.compute_properties(ethanol)["mbis-charges"] -volumes = volume_model.compute_properties(ethanol)["mbis-volumes"] +charges = charge_model.compute_properties(ethanol.to_rdkit())["mbis-charges"] +volumes = charge_model_2.compute_properties(ethanol.to_rdkit())["mbis-volumes"] ``` +**Note:** the plugin described below is currently broken, due to plugins changing in the openff stack! Alternatively we provide an openff-toolkit parameter handler plugin which allows you to create an openmm system using the normal python pathway with a modified force field which requests that the ``NAGMBIS`` model be used to predict charges and LJ parameters. We provide a function which can modify any offxml to add the custom handler @@ -70,21 +72,21 @@ openmm_system = nagl_sage.create_openmm_system(topology=methanol.to_topology()) # Models -## MBISGraphModelV1 +## MBISGraphModel This model uses a minimal set of basic atomic features including - one hot encoded element - the number of bonds -- ring membership of size 3-6 +- ring membership of size 3-8 - n_gcn_layers 5 - n_gcn_hidden_features 128 - n_mbis_layers 2 - n_mbis_hidden_features 64 - learning_rate 0.001 -- n_epochs 100 +- n_epochs 1000 These models were trained on the [OpenFF ESP Fragment Conformers v1.0](https://github.com/openforcefield/qca-dataset-submission/tree/master/submissions/2022-01-16-OpenFF-ESP-Fragment-Conformers-v1.0) dataset -which is on QCArchive. 
The dataset was computed using HF/6-31G* with PSI4. +which is on QCArchive. The dataset was computed using HF/6-31G* with PSI4 and was split 80:10:10 using the deepchem maxmin splitter. diff --git a/devtools/conda-envs/env.yaml b/devtools/conda-envs/env.yaml index cf2bfb0..370b17d 100644 --- a/devtools/conda-envs/env.yaml +++ b/devtools/conda-envs/env.yaml @@ -1,7 +1,6 @@ -name: test +name: naglmbis channels: - conda-forge - - dglteam dependencies: - python @@ -12,16 +11,23 @@ dependencies: - pytest-cov # core deps - - nagl - - openff-toolkit-base <0.11.0, >=0.10.6 +# - nagl >=0.0.9 + - pyarrow + - dgl >=1 + - pytorch + - pytorch-lightning + - rich + - click + - click-option-group + - openff-toolkit-base - qcengine >=0.18.0 - jinja2 - chemper - torsiondrive # - qubekit - - pytorch-lightning - - dgl <1.0.0, >=0.7 - - openff-utilities <=0.1.3 + - openff-utilities + - pydantic <2 - pip: + - git+https://github.com/SimonBoothroyd/nagl.git@main - espaloma_charge - - git+https://github.com/qubekit/QUBEKit.git + - git+https://github.com/qubekit/QUBEKit.git@main diff --git a/naglmbis/__init__.py b/naglmbis/__init__.py index 4fcbbd4..060f7cb 100644 --- a/naglmbis/__init__.py +++ b/naglmbis/__init__.py @@ -2,10 +2,10 @@ naglmbis Models built with NAGL to predict MBIS properties. """ +from . import _version -from ._version import get_versions +__version__ = _version.get_versions()["version"] +# make sure all custom features are loaded; __all__ entries must be strings +from naglmbis import features -versions = get_versions() -__version__ = versions["version"] -__git_revision__ = versions["full-revisionid"] -del get_versions, versions +__all__ = ["features"] diff --git a/naglmbis/_version.py b/naglmbis/_version.py index f6eeace..e6b0818 100644 --- a/naglmbis/_version.py +++ b/naglmbis/_version.py @@ -4,19 +4,22 @@ # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. 
Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" import errno +import functools import os import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -32,17 +35,24 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool -def get_config(): + +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py cfg = VersioneerConfig() cfg.VCS = "git" cfg.style = "pep440" - cfg.tag_prefix = "" + cfg.tag_prefix = "naglmbis-" cfg.parentdir_prefix = "None" - cfg.versionfile_source = "qcsubmit/_version.py" + cfg.versionfile_source = "naglmbis/_version.py" cfg.verbose = False return cfg @@ -51,14 +61,14 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in 
HANDLERS: HANDLERS[vcs] = {} @@ -68,24 +78,39 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, + process = subprocess.Popen( + [command] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, ) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -96,18 +121,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + 
root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -116,7 +143,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return { @@ -126,9 +153,8 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): "error": None, "date": None, } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print( @@ -139,41 +165,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. 
However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -186,11 +219,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -199,7 +232,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -208,6 +241,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue if verbose: print("picking %s" % r) return { @@ -230,7 +268,9 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -241,7 +281,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -249,7 +296,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( + describe_out, rc = runner( GITS, [ "describe", @@ -258,7 +305,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): "--always", "--long", "--match", - "%s*" % tag_prefix, + f"{tag_prefix}[[:digit:]]*", ], cwd=root, ) @@ -266,16 +313,48 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. 
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -292,7 +371,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces @@ -318,26 +397,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -361,23 +441,70 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -404,12 +531,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -426,7 +582,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -446,7 +602,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -466,7 +622,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return { @@ -482,10 +638,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -504,7 +664,7 @@ def render(pieces, style): } -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -524,7 +684,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for i in cfg.versionfile_source.split("/"): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: return { diff --git a/naglmbis/data/models/charge/mbis_charges_v1.ckpt b/naglmbis/data/models/charge/mbis_charges_v1.ckpt deleted file mode 100644 index fc0a143..0000000 Binary files a/naglmbis/data/models/charge/mbis_charges_v1.ckpt and /dev/null differ diff --git a/naglmbis/data/models/charge/nagl-v1-mbis-dipole.ckpt b/naglmbis/data/models/charge/nagl-v1-mbis-dipole.ckpt new file mode 100644 index 0000000..a5a5d41 Binary files /dev/null and b/naglmbis/data/models/charge/nagl-v1-mbis-dipole.ckpt differ diff --git a/naglmbis/data/models/charge/nagl-v1-mbis.ckpt b/naglmbis/data/models/charge/nagl-v1-mbis.ckpt new file mode 100644 index 0000000..8853bef Binary files /dev/null and b/naglmbis/data/models/charge/nagl-v1-mbis.ckpt differ diff --git a/naglmbis/data/models/volume/mbis_volumes_v1.ckpt b/naglmbis/data/models/volume/mbis_volumes_v1.ckpt deleted file mode 100644 index 74466d4..0000000 Binary files a/naglmbis/data/models/volume/mbis_volumes_v1.ckpt and /dev/null differ diff --git a/naglmbis/features/__init__.py b/naglmbis/features/__init__.py index e69de29..b1347e3 100644 --- a/naglmbis/features/__init__.py +++ b/naglmbis/features/__init__.py @@ -0,0 +1,31 @@ +from naglmbis.features.atom import ( + AtomicMass, + AtomicPolarisability, + AtomInRingOfSize, + ExplicitValence, + Hybridization, + HydrogenAtoms, + LipinskiAcceptor, + LipinskiDonor, + PaulingElectronegativity, + SandersonElectronegativity, + TotalDegree, + TotalValence, + VDWRadius, +) + +__all__ = [ + "AtomicMass", + "AtomicPolarisability", + "AtomInRingOfSize", + "ExplicitValence", + "Hybridization", + "HydrogenAtoms", + "LipinskiAcceptor", + "LipinskiDonor", + "PaulingElectronegativity", + "SandersonElectronegativity", + "TotalDegree", + "TotalValence", + "VDWRadius", +] diff --git a/naglmbis/features/atom.py b/naglmbis/features/atom.py index 9b5a152..140842d 100644 --- 
a/naglmbis/features/atom.py +++ b/naglmbis/features/atom.py @@ -1,73 +1,29 @@ -from typing import List, Optional +from typing import Literal import torch -from nagl.features import AtomFeature, one_hot_encode -from openff.toolkit.topology import Molecule - -# from nagl.resonance import enumerate_resonance_forms -# from nagl.utilities.toolkits import normalize_molecule +from nagl.features import AtomFeature, one_hot_encode, register_atom_feature +from pydantic import Extra, Field, dataclasses from rdkit import Chem -# class AtomAverageFormalCharge(AtomFeature): -# def __call__(self, molecule: "Molecule") -> torch.Tensor: -# try: -# molecule = normalize_molecule(molecule) -# except AssertionError: -# pass -# -# resonance_forms = enumerate_resonance_forms( -# molecule=molecule, -# lowest_energy_only=True, -# as_dicts=True, -# include_all_transfer_pathways=True, -# ) -# formal_charges = [ -# [ -# atom["formal_charge"] -# for resonance_form in resonance_forms -# if i in resonance_forms["atoms"] -# for atom in resonance_form["atoms"][i] -# ] -# for i in range(molecule.n_atoms) -# ] -# feature_tensor = torch.tensor( -# [ -# [ -# sum(formal_charges[i]) / len(formal_charges[i]) -# if len(formal_charges[i]) > 0 -# else 0.0 -# ] -# for i in range(molecule.n_atoms) -# ] -# ) -# return feature_tensor -# -# def __len__(self): -# return 1 - +@dataclasses.dataclass(config={"extra": Extra.forbid}) class HydrogenAtoms(AtomFeature): """One hot encode the number of bonded hydrogen atoms""" - _HYDROGENS = [0, 1, 2, 3, 4] - - def __init__(self, hydrogens: Optional[List[int]] = None) -> None: - self.hydrogens = hydrogens if hydrogens is not None else [*self._HYDROGENS] + type: Literal["hydrogenatoms"] = "hydrogenatoms" + hydrogens: list[int] = Field( + [0, 1, 2, 3, 4], + description="The options for the number of bonded hydrogens to one hot encode.", + ) - def __call__(self, molecule: "Molecule") -> torch.Tensor: + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return 
torch.vstack( [ one_hot_encode( - sum( - [ - n.atomic_number - for n in atom.bonded_atoms - if n.atomic_number == 1 - ] - ), + atom.GetTotalNumHs(includeNeighbors=True), self.hydrogens, ) - for atom in molecule.atoms + for atom in molecule.GetAtoms() ] ) @@ -75,126 +31,113 @@ def __len__(self): return len(self.hydrogens) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class AtomInRingOfSize(AtomFeature): - def __init__(self, ring_size: int) -> None: - assert ring_size >= 3 - self.ring_size = ring_size - - def __call__(self, molecule: "Molecule") -> torch.Tensor: - rd_molecule: Chem.Mol = molecule.to_rdkit() - ring_info: Chem.RingInfo = rd_molecule.GetRingInfo() - - return torch.tensor( - [ - int(ring_info.IsAtomInRingOfSize(atom.GetIdx(), self.ring_size)) - for atom in rd_molecule.GetAtoms() - ] - ).reshape(-1, 1) + type: Literal["ringofsize"] = "ringofsize" + ring_sizes: list[int] = Field( + [3, 4, 5, 6, 7, 8], + description="The ring of size we want to check membership of.", + ) - def __len__(self): - return 1 - - -class BondInRingOfSize(AtomFeature): - def __init__(self, ring_size: int): - assert ring_size >= 3 - self.ring_size = ring_size - - def __call__(self, molecule: "Molecule") -> torch.Tensor: - rd_molecule: Chem.Mol = molecule.to_rdkit() - ring_info: Chem.RingInfo = rd_molecule.GetRingInfo() + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: + ring_info: Chem.RingInfo = molecule.GetRingInfo() - rd_bond_by_index = { - tuple(sorted((rd_bond.GetBgnIdx(), rd_bond.GetEndIdx()))): rd_bond - for rd_bond in rd_molecule.GetBonds() - } - - rd_bonds = [ - rd_bond_by_index[tuple(sorted((bond.atom1_index, bond.atom2_index)))] - for bond in molecule.bonds - ] - - return torch.tensor( + return torch.vstack( [ - int(ring_info.IsBondInRingOfSize(rd_bond, self.ring_size)) - for rd_bond in rd_bonds + torch.Tensor( + [ + int(ring_info.IsAtomInRingOfSize(atom.GetIdx(), ring_size)) + for ring_size in self.ring_sizes + ] + ) + for atom in molecule.GetAtoms() ] - 
).reshape(-1, 1) + ) def __len__(self): - return 1 + return len(self.ring_sizes) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class LipinskiDonor(AtomFeature): """ Return if the atom is a Lipinski h-bond donor. """ + type: Literal["lipinskidonor"] = "lipinskidonor" + def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs) -> torch.Tensor: + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: from rdkit.Chem import Lipinski - rd_molecule: Chem.Mol = molecule.to_rdkit() - donors = Lipinski._HDonors(rd_molecule) + donors = Lipinski._HDonors(molecule) # squash the lists donors = [d for donor in donors for d in donor] return torch.tensor( - [int(atom.GetIdx() in donors) for atom in rd_molecule.GetAtoms()] + [int(atom.GetIdx() in donors) for atom in molecule.GetAtoms()] ).reshape(-1, 1) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class LipinskiAcceptor(AtomFeature): """ Return if the atom is a Lipinski h-bond acceptor. """ + type: Literal["lipinskiacceptor"] = "lipinskiacceptor" + def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: from rdkit.Chem import Lipinski - rd_molecule: Chem.Mol = molecule.to_rdkit() - acceptors = Lipinski._HAcceptors(rd_molecule) + acceptors = Lipinski._HAcceptors(molecule) # squash the lists acceptors = [a for acceptor in acceptors for a in acceptor] return torch.tensor( - [int(atom.GetIdx() in acceptors) for atom in rd_molecule.GetAtoms()] + [int(atom.GetIdx() in acceptors) for atom in molecule.GetAtoms()] ).reshape(-1, 1) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class PaulingElectronegativity(AtomFeature): """ Return the pauling electronegativity of each of the atoms. 
""" + type: Literal["paulingelectronegativity"] = "paulingelectronegativity" # values taken from - _negativities = { - 1: 2.2, - 5: 2.04, - 6: 2.55, - 7: 3.04, - 8: 3.44, - 9: 3.98, - 14: 1.9, - 15: 2.19, - 16: 2.58, - 17: 3.16, - 35: 2.96, - 53: 2.66, - } + negativities: dict[int, float] = Field( + { + 1: 2.2, + 5: 2.04, + 6: 2.55, + 7: 3.04, + 8: 3.44, + 9: 3.98, + 14: 1.9, + 15: 2.19, + 16: 2.58, + 17: 3.16, + 35: 2.96, + 53: 2.66, + }, + description="The reference negativities for each element.", + ) def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return torch.tensor( - [self._negativities[atom.atomic_number] for atom in molecule.atoms] + [self.negativities[atom.GetAtomicNum()] for atom in molecule.GetAtoms()] ).reshape(-1, 1) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class SandersonElectronegativity(AtomFeature): """ Return the Sanderson electronegativity of each of the atoms. 
@@ -202,156 +145,192 @@ class SandersonElectronegativity(AtomFeature): Values taken from """ - _negativities = { - 1: 2.59, - 5: 2.28, - 6: 2.75, - 7: 3.19, - 8: 3.65, - 9: 4.0, - 14: 2.14, - 15: 2.52, - 16: 2.96, - 17: 3.48, - 35: 3.22, - 53: 2.78, - } + type: Literal["sandersonelectronegativity"] = "sandersonelectronegativity" + negativities: dict[int, float] = Field( + { + 1: 2.59, + 5: 2.28, + 6: 2.75, + 7: 3.19, + 8: 3.65, + 9: 4.0, + 14: 2.14, + 15: 2.52, + 16: 2.96, + 17: 3.48, + 35: 3.22, + 53: 2.78, + }, + description="The reference negativities for each element.", + ) def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return torch.Tensor( - [self._negativities[atom.atomic_number] for atom in molecule.atoms] + [self.negativities[atom.GetAtomicNum()] for atom in molecule.GetAtoms()] ).reshape(-1, 1) -class vdWRadius(AtomFeature): +@dataclasses.dataclass(config={"extra": Extra.forbid}) +class VDWRadius(AtomFeature): """ Return the vdW radius of the atom. 
Values taken from """ - _radii = { - 1: 1.17, - 5: 1.62, - 6: 1.75, - 7: 1.55, - 8: 1.4, - 9: 1.3, - 14: 1.97, - 15: 1.85, - 16: 1.8, - 17: 1.75, - 35: 1.95, - 53: 2.1, - } + type: Literal["vdwradius"] = "vdwradius" + radii: dict[int, float] = Field( + { + 1: 1.17, + 5: 1.62, + 6: 1.75, + 7: 1.55, + 8: 1.4, + 9: 1.3, + 14: 1.97, + 15: 1.85, + 16: 1.8, + 17: 1.75, + 35: 1.95, + 53: 2.1, + }, + description="The reference vdW radii in angstroms for each element.", + ) def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return torch.Tensor( - [self._radii[atom.atomic_number] for atom in molecule.atoms] + [self.radii[atom.GetAtomicNum()] for atom in molecule.GetAtoms()] ).reshape(-1, 1) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class AtomicPolarisability(AtomFeature): """Assign the atomic polarisability for each atom. values from """ - _polarisability = { - 1: 0.67, - 5: 3.03, - 6: 1.76, - 7: 1.1, - 8: 1.1, - 9: 0.56, - 14: 5.38, - 15: 3.63, - 16: 2.9, - 17: 2.18, - 35: 3.05, - 53: 5.35, - } + type: Literal["atomicpolarisability"] = "atomicpolarisability" + polarisability: dict[int, float] = Field( + { + 1: 0.67, + 5: 3.03, + 6: 1.76, + 7: 1.1, + 8: 1.1, + 9: 0.56, + 14: 5.38, + 15: 3.63, + 16: 2.9, + 17: 2.18, + 35: 3.05, + 53: 5.35, + }, + description="The atomic polarisability in atomic units for each element.", + ) def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return torch.Tensor( - [self._polarisability[atom.atomic_number] for atom in molecule.atoms] + [self.polarisability[atom.GetAtomicNum()] for atom in molecule.GetAtoms()] ).reshape(-1, 1) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class Hybridization(AtomFeature): """ one hot encode the rdkit hybridization of the atom. 
""" - _HYBRIDIZATION = [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2, - Chem.rdchem.HybridizationType.S, - ] - - def __init__(self, hybridization: Optional[List[int]] = None) -> None: - self.hybridization = ( - hybridization if hybridization is not None else [*self._HYBRIDIZATION] - ) + type: Literal["hybridization"] = "hybridization" + hybridization: list[Chem.rdchem.HybridizationType] = Field( + [ + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + Chem.rdchem.HybridizationType.S, + ], + description="The list of hybridization types which we can one hot encode", + ) def __len__(self): - return len(self._HYBRIDIZATION) + return len(self.hybridization) - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return torch.vstack( [ one_hot_encode(atom.GetHybridization(), self.hybridization) - for atom in molecule.to_rdkit().GetAtoms() + for atom in molecule.GetAtoms() ] ) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class TotalValence(AtomFeature): + type: Literal["totalvalence"] = "totalvalence" + def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): - return torch.Tensor( - [[atom.GetTotalValence()] for atom in molecule.to_rdkit().GetAtoms()] - ) + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: + return torch.Tensor([[atom.GetTotalValence()] for atom in molecule.GetAtoms()]) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class ExplicitValence(AtomFeature): + type: Literal["explicitvalence"] = "explicitvalence" + def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: return 
torch.Tensor( - [[atom.GetExplicitValence()] for atom in molecule.to_rdkit().GetAtoms()] + [[atom.GetExplicitValence()] for atom in molecule.GetAtoms()] ) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class AtomicMass(AtomFeature): + type: Literal["atomicmass"] = "atomicmass" + def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): - return torch.Tensor( - [[atom.GetMass()] for atom in molecule.to_rdkit().GetAtoms()] - ) + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: + return torch.Tensor([[atom.GetMass()] for atom in molecule.GetAtoms()]) +@dataclasses.dataclass(config={"extra": Extra.forbid}) class TotalDegree(AtomFeature): + type: Literal["totaldegree"] = "totaldegree" + def __len__(self): return 1 - def __call__(self, molecule: Molecule, *args, **kwargs): - return torch.Tensor( - [[atom.GetTotalDegree()] for atom in molecule.to_rdkit().GetAtoms()] - ) + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: + return torch.Tensor([[atom.GetTotalDegree()] for atom in molecule.GetAtoms()]) + + +# Register all new features +register_atom_feature(HydrogenAtoms) +register_atom_feature(AtomInRingOfSize) +register_atom_feature(LipinskiDonor) +register_atom_feature(LipinskiAcceptor) +register_atom_feature(PaulingElectronegativity) +register_atom_feature(SandersonElectronegativity) +register_atom_feature(VDWRadius) +register_atom_feature(AtomicPolarisability) +register_atom_feature(Hybridization) +register_atom_feature(TotalValence) +register_atom_feature(ExplicitValence) +register_atom_feature(AtomicMass) +register_atom_feature(TotalDegree) diff --git a/naglmbis/models/__init__.py b/naglmbis/models/__init__.py index da5c28a..a782354 100644 --- a/naglmbis/models/__init__.py +++ b/naglmbis/models/__init__.py @@ -1,6 +1,4 @@ from naglmbis.models.base_model import MBISGraphModel -from naglmbis.models.models import ( - MBISGraphModelV1, - load_charge_model, - load_volume_model, -) +from naglmbis.models.models import 
CHARGE_MODELS, load_charge_model + +__all__ = [MBISGraphModel, CHARGE_MODELS, load_charge_model] diff --git a/naglmbis/models/base_model.py b/naglmbis/models/base_model.py index e2b9d54..4985f41 100644 --- a/naglmbis/models/base_model.py +++ b/naglmbis/models/base_model.py @@ -1,83 +1,17 @@ # models for the nagl run -import abc -from typing import Dict, List, Literal, Optional - import torch -from nagl.lightning import DGLMoleculeLightningModel from nagl.molecules import DGLMolecule -from nagl.nn import SequentialLayers -from nagl.nn.modules import ConvolutionModule, ReadoutModule -from nagl.nn.pooling import PoolAtomFeatures -from nagl.nn.postprocess import ComputePartialCharges -from openff.toolkit.topology import Molecule +from nagl.training import DGLMoleculeLightningModel +from rdkit import Chem class MBISGraphModel(DGLMoleculeLightningModel): "A wrapper to make it easy to load and evaluate models" - @abc.abstractmethod - def features(self): - ... - - def __init__( - self, - n_gcn_hidden_features: int, - n_gcn_layers: int, - n_mbis_hidden_features: int, - n_mbis_layers: int, - readout_modules: List[Literal["charge", "volume"]], - learning_rate: float, - gcn_activation: Optional[ - str - ] = None, # in the case of SAGEConv no activation is used - readout_activation: str = "ReLU", # defaults for backwards compatibility - ): - self.n_gcn_hidden_features = n_gcn_hidden_features - self.n_gcn_layers = n_gcn_layers - self.n_mbis_hidden_features = n_mbis_hidden_features - self.n_mbis_layers = n_mbis_layers - self.gcn_activation = gcn_activation - self.readout_activation = readout_activation - n_atom_features = sum(len(feature) for feature in self.features()[0]) - readout = {} - if "charge" in readout_modules: - readout["mbis-charges"] = ReadoutModule( - pooling_layer=PoolAtomFeatures(), - readout_layers=SequentialLayers( - in_feats=n_gcn_hidden_features, - hidden_feats=[n_mbis_hidden_features] * n_mbis_layers + [2], - activation=[readout_activation] * n_mbis_layers + 
["Identity"], - ), - postprocess_layer=ComputePartialCharges(), - ) - if "volume" in readout_modules: - readout["mbis-volumes"] = ReadoutModule( - pooling_layer=PoolAtomFeatures(), - readout_layers=SequentialLayers( - in_feats=n_gcn_hidden_features, - hidden_feats=[n_mbis_hidden_features] * n_mbis_layers + [1], - activation=[readout_activation] * n_mbis_layers + ["Identity"], - ), - ) - - super().__init__( - convolution_module=ConvolutionModule( - architecture="SAGEConv", - in_feats=n_atom_features, - hidden_feats=[n_gcn_hidden_features] * n_gcn_layers, - activation=[gcn_activation] * n_gcn_layers - if gcn_activation is not None - else None, - ), - readout_modules=readout, - learning_rate=learning_rate, + def compute_properties(self, molecule: Chem.Mol) -> dict[str, torch.Tensor]: + dgl_molecule = DGLMolecule.from_rdkit( + molecule, self.config.model.atom_features, self.config.model.bond_features ) - self.save_hyperparameters() - - def compute_properties(self, molecule: Molecule) -> Dict[str, torch.Tensor]: - atom_features, bond_features = self.features() - dgl_molecule = DGLMolecule.from_openff(molecule, atom_features, bond_features) - return self.forward(dgl_molecule) diff --git a/naglmbis/models/models.py b/naglmbis/models/models.py index 0b54afc..deb9c67 100644 --- a/naglmbis/models/models.py +++ b/naglmbis/models/models.py @@ -1,74 +1,19 @@ from typing import Literal -from nagl.features import ( - AtomConnectivity, - AtomFormalCharge, - AtomicElement, - AtomIsAromatic, -) +import torch -from naglmbis.features.atom import ( - AtomicMass, - AtomInRingOfSize, - ExplicitValence, - Hybridization, - TotalDegree, - TotalValence, -) from naglmbis.models.base_model import MBISGraphModel from naglmbis.utils import get_model_weights - -class MBISGraphModelV1(MBISGraphModel): - """ - The first version of the model with the basic set of features. 
- """ - - def features(self): - atom_features = [ - AtomicElement(["H", "C", "N", "O", "F", "Cl", "Br", "S", "P"]), - AtomConnectivity(), - AtomInRingOfSize(3), - AtomInRingOfSize(4), - AtomInRingOfSize(5), - AtomInRingOfSize(6), - ] - bond_features = [] - return atom_features, bond_features - - -class EspalomaModel(MBISGraphModel): - """Try and recreate the espaloma model""" - - def features(self): - atom_features = [ - AtomicElement(["H", "C", "N", "O", "F", "Cl", "Br", "S", "P"]), - TotalDegree(), - TotalValence(), - ExplicitValence(), - AtomFormalCharge(), - AtomIsAromatic(), - AtomicMass(), - AtomInRingOfSize(3), - AtomInRingOfSize(4), - AtomInRingOfSize(5), - AtomInRingOfSize(6), - AtomInRingOfSize(7), - AtomInRingOfSize(8), - Hybridization(), - ] - bond_features = [] - return atom_features, bond_features - - charge_weights = { - "nagl-v1": {"path": "mbis_charges_v1.ckpt", "model": MBISGraphModelV1} -} -volume_weights = { - "nagl-v1": {"path": "mbis_volumes_v1.ckpt", "model": MBISGraphModelV1} + "nagl-v1-mbis": {"checkpoint_path": "nagl-v1-mbis.ckpt"}, + "nagl-v1-mbis-dipole": {"checkpoint_path": "nagl-v1-mbis-dipole.ckpt"}, } -CHARGE_MODELS = Literal["nagl-v1"] -VOLUME_MODELS = Literal["nagl-v1"] +# volume_weights = { +# "nagl-v1": {"path": "mbis_volumes_v1.ckpt", "model": MBISGraphModel} +# } +CHARGE_MODELS = Literal["nagl-v1-mbis-dipole", "nagl-v1-mbis"] +# VOLUME_MODELS = Literal["nagl-v1"] def load_charge_model(charge_model: CHARGE_MODELS) -> MBISGraphModel: @@ -76,16 +21,20 @@ def load_charge_model(charge_model: CHARGE_MODELS) -> MBISGraphModel: Load up one of the predefined charge models, this will load the weights and parameter settings. 
""" weight_path = get_model_weights( - model_type="charge", model_name=charge_weights[charge_model]["path"] - ) - return charge_weights[charge_model]["model"].load_from_checkpoint(weight_path) - - -def load_volume_model(volume_model: VOLUME_MODELS) -> MBISGraphModel: - """ - Load one of the predefined volume models, this will load the weights and parameter settings. - """ - weight_path = get_model_weights( - model_type="volume", model_name=volume_weights[volume_model]["path"] + model_type="charge", model_name=charge_weights[charge_model]["checkpoint_path"] ) - return volume_weights[volume_model]["model"].load_from_checkpoint(weight_path) + model_data = torch.load(weight_path) + model = MBISGraphModel(**model_data["hyper_parameters"]) + model.load_state_dict(model_data["state_dict"]) + model.eval() + return model + + +# def load_volume_model(volume_model: VOLUME_MODELS) -> MBISGraphModel: +# """ +# Load one of the predefined volume models, this will load the weights and parameter settings. 
+# """ +# weight_path = get_model_weights( +# model_type="volume", model_name=volume_weights[volume_model]["path"] +# ) +# return volume_weights[volume_model]["model"].load_from_checkpoint(weight_path) diff --git a/naglmbis/plugins/__init__.py b/naglmbis/plugins/__init__.py index b72d698..c8a5dd8 100644 --- a/naglmbis/plugins/__init__.py +++ b/naglmbis/plugins/__init__.py @@ -1,2 +1,4 @@ from naglmbis.plugins.plugins import NAGLMBISHandler from naglmbis.plugins.utils import modify_force_field + +__all__ = [NAGLMBISHandler, modify_force_field] diff --git a/naglmbis/plugins/bccs.py b/naglmbis/plugins/bccs.py index 4bdaa79..fd401b7 100644 --- a/naglmbis/plugins/bccs.py +++ b/naglmbis/plugins/bccs.py @@ -1,6 +1,6 @@ # a file to track bcc models -from typing_extensions import Literal from openff.toolkit.typing.engines.smirnoff import ForceField +from typing_extensions import Literal # Model fit with nagl-v1 charges and nagl-v1 volumes with no polar h Rfree # list of smirks and charge corrections diff --git a/naglmbis/plugins/plugins.py b/naglmbis/plugins/plugins.py index fec8110..42dac1b 100644 --- a/naglmbis/plugins/plugins.py +++ b/naglmbis/plugins/plugins.py @@ -9,13 +9,13 @@ _allow_only, _NonbondedHandler, ) +from openmm import unit from qubekit.charges import MBISCharges from qubekit.molecules import Ligand from naglmbis.models import load_charge_model, load_volume_model +from naglmbis.plugins.bccs import bcc_force_fields, load_bcc_model from naglmbis.plugins.trained_models import trained_models -from naglmbis.plugins.bccs import load_bcc_model, bcc_force_fields -from openmm import unit class NAGLMBISHandler(_NonbondedHandler): diff --git a/naglmbis/tests/conftest.py b/naglmbis/tests/conftest.py index 748ea82..3d895bc 100644 --- a/naglmbis/tests/conftest.py +++ b/naglmbis/tests/conftest.py @@ -9,7 +9,7 @@ def methanol(): """ methanol = Molecule.from_mapped_smiles("[H:3][C:1]([H:4])([H:5])[O:2][H:6]") methanol.generate_conformers(n_conformers=1) - return methanol + 
return methanol.to_rdkit() @pytest.fixture() @@ -17,7 +17,7 @@ def water(): """Make an OpenFF molecule of water""" water = Molecule.from_mapped_smiles("[H:2][O:1][H:3]") water.generate_conformers(n_conformers=1) - return water + return water.to_rdkit() @pytest.fixture() @@ -25,10 +25,10 @@ def iodobezene(): """Make an OpenFF molecule of iodobenzene""" i_ben = Molecule.from_smiles("c1ccc(cc1)I") i_ben.generate_conformers(n_conformers=1) - return i_ben + return i_ben.to_rdkit() @pytest.fixture() def methane_no_conf(): """Make an OpenFF molecule of methane with no conformer""" - return Molecule.from_smiles("C") + return Molecule.from_smiles("C").to_rdkit() diff --git a/naglmbis/tests/test_atom_features.py b/naglmbis/tests/test_atom_features.py index ffd5e2f..7fa1aa7 100644 --- a/naglmbis/tests/test_atom_features.py +++ b/naglmbis/tests/test_atom_features.py @@ -1,7 +1,6 @@ import numpy as np -from nagl.features import AtomConnectivity -from naglmbis.features.atom import ( +from naglmbis.features import ( AtomicMass, AtomicPolarisability, ExplicitValence, @@ -13,7 +12,7 @@ SandersonElectronegativity, TotalDegree, TotalValence, - vdWRadius, + VDWRadius, ) @@ -74,7 +73,7 @@ def test_sanderson(methanol): def test_vdw_radii(methanol): """Make sure the vdw radii is correctly assigned""" - radii = vdWRadius() + radii = VDWRadius() assert len(radii) == 1 feats = radii(methanol).numpy() assert feats.shape == (6, 1) diff --git a/naglmbis/tests/test_models.py b/naglmbis/tests/test_models.py index cf0668f..9b7c448 100644 --- a/naglmbis/tests/test_models.py +++ b/naglmbis/tests/test_models.py @@ -1,48 +1,25 @@ import torch -from naglmbis.models import MBISGraphModelV1, load_charge_model, load_volume_model +from naglmbis.models import load_charge_model -def test_charge_model_v1(methanol): +def test_charge_model_v1_dipoles(methanol): """ - Test loading the charge model and computing the MBIS charges. 
+ Test loading the charge model and computing the MBIS charges with a model co-trained to dipoles. """ - charge_model = load_charge_model(charge_model="nagl-v1") + charge_model = load_charge_model(charge_model="nagl-v1-mbis-dipole") charges = charge_model.compute_properties(molecule=methanol)[ "mbis-charges" ].detach() - ref = torch.Tensor([[0.0847], [-0.6714], [0.0494], [0.0494], [0.0494], [0.4384]]) + ref = torch.Tensor([[0.0618], [-0.6490], [0.0509], [0.0509], [0.0509], [0.4347]]) assert torch.allclose(charges, ref, atol=1e-4) -def test_volume_model_v1(methanol): - """ - Test loading the volume model and computing the MBIS volumes - """ - volume_model = load_volume_model(volume_model="nagl-v1") - volumes = volume_model.compute_properties(molecule=methanol)[ - "mbis-volumes" +def test_charge_model_v1_mbis(methanol): + """Test computing the charges with the model trained to only mbis charges.""" + charge_model = load_charge_model(charge_model="nagl-v1-mbis") + charges = charge_model.compute_properties(molecule=methanol)[ + "mbis-charges" ].detach() - ref = torch.Tensor([[29.6985], [25.3335], [3.0224], [3.0224], [3.0224], [1.0341]]) - assert torch.allclose(volumes, ref, atol=1e-4) - - -def test_activation(): - """ - Test building a model with a new activation function - """ - model = MBISGraphModelV1( - n_gcn_hidden_features=128, - n_gcn_layers=3, - n_mbis_hidden_features=128, - n_mbis_layers=2, - readout_modules="charge", - learning_rate=1e-4, - gcn_activation=None, - readout_activation="Sigmoid", - ) - - assert model.readout_activation == "Sigmoid" - assert isinstance( - model.readout_modules["mbis-charges"].readout_layers[1], torch.nn.Sigmoid - ) + ref = torch.Tensor([[0.0835], [-0.6821], [0.0491], [0.0491], [0.0491], [0.4515]]) + assert torch.allclose(charges, ref, atol=1e-4) diff --git a/naglmbis/tests/test_plugins.py b/naglmbis/tests/test_plugins.py index 0cb49b9..ed4ca63 100644 --- a/naglmbis/tests/test_plugins.py +++ b/naglmbis/tests/test_plugins.py @@ 
-1,154 +1,154 @@ -import pytest -from openmm import unit - -from naglmbis.plugins import modify_force_field - - -def test_modify_force_field(): - """Make sure we can correctly modify a force field with a NAGLMBIS tag, this ensures the plugin is picked - up by the toolkit. - """ - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - handlers = nagl_sage.registered_parameter_handlers - assert "NAGLMBIS" in handlers - assert "ToolkitAM1BCC" not in handlers - - -def test_plugin_water(water): - """Make sure that the default TIP3P parameters are applied to water when using a NAGLMBIS force field.""" - - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - water_system = nagl_sage.create_openmm_system(topology=water.to_topology()) - water_forces = { - water_system.getForce(index).__class__.__name__: water_system.getForce(index) - for index in range(water_system.getNumForces()) - } - # get the oxygen parameters, these should match the library charge - charge, sigma, epsilon = water_forces["NonbondedForce"].getParticleParameters(0) - assert charge / unit.elementary_charge == -0.834 - assert sigma / unit.nanometer == 0.31507 - assert epsilon / unit.kilojoule_per_mole == 0.6363864000000001 - # now check the hydrogen - for i in [1, 2]: - charge, sigma, epsilon = water_forces["NonbondedForce"].getParticleParameters(i) - assert charge / unit.elementary_charge == 0.417 - assert sigma / unit.nanometer == 0.1 - assert epsilon / unit.kilojoule_per_mole == 0.0 - - -def test_plugin_methanol(methanol): - """Make sure the correct parameters are asigned to methanol when using a NAGLMBIS force field.""" - - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - methanol_system = nagl_sage.create_openmm_system(topology=methanol.to_topology()) - methanol_forces = { - methanol_system.getForce(index).__class__.__name__: methanol_system.getForce( - index - ) - for index in range(methanol_system.getNumForces()) 
- } - # check the system parameters - ref_parameters = { - # index: [charge, sigma, epsilon] - 0: [0.08475, 0.3506905398376649, 0.29246740950730743], - 1: [-0.67143, 0.30824716826324094, 0.42874612945325397], - 2: [0.049413, 0.23126852234757847, 0.07259909774155628], - 3: [0.049413, 0.23126852234757847, 0.07259909774155628], - 4: [0.049413, 0.23126852234757847, 0.07259909774155628], - 5: [0.438441, 0.11098246898497655, 0.41631660852994784], - } - for particle_index, refs in ref_parameters.items(): - charge, sigma, epsilon = methanol_forces[ - "NonbondedForce" - ].getParticleParameters(particle_index) - assert charge / unit.elementary_charge == pytest.approx(refs[0], abs=1e-5) - assert sigma / unit.nanometers == pytest.approx(refs[1]) - assert epsilon / unit.kilojoule_per_mole == pytest.approx(refs[2]) - - -def test_plugin_missing_element(iodobezene): - """Make sure an error is raised when we try to parameterize a molecule with an element not covered by model 1.""" - from qubekit.utils.exceptions import MissingRfreeError - - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - with pytest.raises(MissingRfreeError): - _ = nagl_sage.create_openmm_system(topology=iodobezene.to_topology()) - - -def test_plugin_no_conformer(methane_no_conf): - """Make sure the system can still be made if the refernce molecule has no conformer""" - - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - _ = nagl_sage.create_openmm_system(topology=methane_no_conf.to_topology()) - - -def test_espaloma_charge(methanol, tmpdir): - """ - Test using the esploma charge module to compute the charges for a model. 
- """ - - with tmpdir.as_cwd(): - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - nagl_handler = nagl_sage.get_parameter_handler("NAGLMBIS") - nagl_handler.charge_model = "espaloma-v1" - methanol_system = nagl_sage.create_openmm_system( - topology=methanol.to_topology() - ) - methanol_forces = { - methanol_system.getForce( - index - ).__class__.__name__: methanol_system.getForce(index) - for index in range(methanol_system.getNumForces()) - } - # check the system parameters - ref_parameters = { - # index: [charge, sigma, epsilon] - 0: [0.114912, 0.3506905398376649, 0.29246740950730743], - 1: [-0.604643, 0.30824716826324094, 0.42874612945325397], - 2: [0.030345, 0.23126852234757847, 0.07259909774155628], - 3: [0.030345, 0.23126852234757847, 0.07259909774155628], - 4: [0.030345, 0.23126852234757847, 0.07259909774155628], - 5: [0.398696, 0.11098246898497655, 0.41631660852994784], - } - for particle_index, refs in ref_parameters.items(): - charge, sigma, epsilon = methanol_forces[ - "NonbondedForce" - ].getParticleParameters(particle_index) - assert charge / unit.elementary_charge == pytest.approx(refs[0], abs=1e-5) - assert sigma / unit.nanometers == pytest.approx(refs[1]) - assert epsilon / unit.kilojoule_per_mole == pytest.approx(refs[2]) - - -def test_bcc_charges(methanol): - """ - Make sure bccs are correctly apply with NAGL charges when requested. 
- - This is the same as the test_plugin_methanol test but the carbon and oxygen charges are slightly different - """ - nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") - nagl_handler = nagl_sage.get_parameter_handler("NAGLMBIS") - nagl_handler.bcc_model = "nagl-v1" - methanol_system = nagl_sage.create_openmm_system(topology=methanol.to_topology()) - methanol_forces = { - methanol_system.getForce(index).__class__.__name__: methanol_system.getForce( - index - ) - for index in range(methanol_system.getNumForces()) - } - # check the system parameters - ref_parameters = { - # index: [charge, sigma, epsilon] - 0: [0.041475, 0.3506905398376649, 0.29246740950730743], - 1: [-0.628155, 0.30824716826324094, 0.42874612945325397], - 2: [0.049413, 0.23126852234757847, 0.07259909774155628], - 3: [0.049413, 0.23126852234757847, 0.07259909774155628], - 4: [0.049413, 0.23126852234757847, 0.07259909774155628], - 5: [0.438441, 0.11098246898497655, 0.41631660852994784], - } - for particle_index, refs in ref_parameters.items(): - charge, sigma, epsilon = methanol_forces[ - "NonbondedForce" - ].getParticleParameters(particle_index) - assert charge / unit.elementary_charge == pytest.approx(refs[0], abs=1e-5) - assert sigma / unit.nanometers == pytest.approx(refs[1]) - assert epsilon / unit.kilojoule_per_mole == pytest.approx(refs[2]) +# import pytest +# from openmm import unit +# +# from naglmbis.plugins import modify_force_field + + +# def test_modify_force_field(): +# """Make sure we can correctly modify a force field with a NAGLMBIS tag, this ensures the plugin is picked +# up by the toolkit. 
+# """ +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# handlers = nagl_sage.registered_parameter_handlers +# assert "NAGLMBIS" in handlers +# assert "ToolkitAM1BCC" not in handlers + + +# def test_plugin_water(water): +# """Make sure that the default TIP3P parameters are applied to water when using a NAGLMBIS force field.""" +# +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# water_system = nagl_sage.create_openmm_system(topology=water.to_topology()) +# water_forces = { +# water_system.getForce(index).__class__.__name__: water_system.getForce(index) +# for index in range(water_system.getNumForces()) +# } +# # get the oxygen parameters, these should match the library charge +# charge, sigma, epsilon = water_forces["NonbondedForce"].getParticleParameters(0) +# assert charge / unit.elementary_charge == -0.834 +# assert sigma / unit.nanometer == 0.31507 +# assert epsilon / unit.kilojoule_per_mole == 0.6363864000000001 +# # now check the hydrogen +# for i in [1, 2]: +# charge, sigma, epsilon = water_forces["NonbondedForce"].getParticleParameters(i) +# assert charge / unit.elementary_charge == 0.417 +# assert sigma / unit.nanometer == 0.1 +# assert epsilon / unit.kilojoule_per_mole == 0.0 + + +# def test_plugin_methanol(methanol): +# """Make sure the correct parameters are asigned to methanol when using a NAGLMBIS force field.""" +# +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# methanol_system = nagl_sage.create_openmm_system(topology=methanol.to_topology()) +# methanol_forces = { +# methanol_system.getForce(index).__class__.__name__: methanol_system.getForce( +# index +# ) +# for index in range(methanol_system.getNumForces()) +# } +# # check the system parameters +# ref_parameters = { +# # index: [charge, sigma, epsilon] +# 0: [0.08475, 0.3506905398376649, 0.29246740950730743], +# 1: [-0.67143, 0.30824716826324094, 0.42874612945325397], +# 2: [0.049413, 
0.23126852234757847, 0.07259909774155628], +# 3: [0.049413, 0.23126852234757847, 0.07259909774155628], +# 4: [0.049413, 0.23126852234757847, 0.07259909774155628], +# 5: [0.438441, 0.11098246898497655, 0.41631660852994784], +# } +# for particle_index, refs in ref_parameters.items(): +# charge, sigma, epsilon = methanol_forces[ +# "NonbondedForce" +# ].getParticleParameters(particle_index) +# assert charge / unit.elementary_charge == pytest.approx(refs[0], abs=1e-5) +# assert sigma / unit.nanometers == pytest.approx(refs[1]) +# assert epsilon / unit.kilojoule_per_mole == pytest.approx(refs[2]) + + +# def test_plugin_missing_element(iodobezene): +# """Make sure an error is raised when we try to parameterize a molecule with an element not covered by model 1.""" +# from qubekit.utils.exceptions import MissingRfreeError +# +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# with pytest.raises(MissingRfreeError): +# _ = nagl_sage.create_openmm_system(topology=iodobezene.to_topology()) + + +# def test_plugin_no_conformer(methane_no_conf): +# """Make sure the system can still be made if the refernce molecule has no conformer""" +# +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# _ = nagl_sage.create_openmm_system(topology=methane_no_conf.to_topology()) + + +# def test_espaloma_charge(methanol, tmpdir): +# """ +# Test using the esploma charge module to compute the charges for a model. 
+# """ +# +# with tmpdir.as_cwd(): +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# nagl_handler = nagl_sage.get_parameter_handler("NAGLMBIS") +# nagl_handler.charge_model = "espaloma-v1" +# methanol_system = nagl_sage.create_openmm_system( +# topology=methanol.to_topology() +# ) +# methanol_forces = { +# methanol_system.getForce( +# index +# ).__class__.__name__: methanol_system.getForce(index) +# for index in range(methanol_system.getNumForces()) +# } +# # check the system parameters +# ref_parameters = { +# # index: [charge, sigma, epsilon] +# 0: [0.114912, 0.3506905398376649, 0.29246740950730743], +# 1: [-0.604643, 0.30824716826324094, 0.42874612945325397], +# 2: [0.030345, 0.23126852234757847, 0.07259909774155628], +# 3: [0.030345, 0.23126852234757847, 0.07259909774155628], +# 4: [0.030345, 0.23126852234757847, 0.07259909774155628], +# 5: [0.398696, 0.11098246898497655, 0.41631660852994784], +# } +# for particle_index, refs in ref_parameters.items(): +# charge, sigma, epsilon = methanol_forces[ +# "NonbondedForce" +# ].getParticleParameters(particle_index) +# assert charge / unit.elementary_charge == pytest.approx(refs[0], abs=1e-5) +# assert sigma / unit.nanometers == pytest.approx(refs[1]) +# assert epsilon / unit.kilojoule_per_mole == pytest.approx(refs[2]) + + +# def test_bcc_charges(methanol): +# """ +# Make sure bccs are correctly apply with NAGL charges when requested. 
+# +# This is the same as the test_plugin_methanol test but the carbon and oxygen charges are slightly different +# """ +# nagl_sage = modify_force_field(force_field="openff_unconstrained-2.0.0.offxml") +# nagl_handler = nagl_sage.get_parameter_handler("NAGLMBIS") +# nagl_handler.bcc_model = "nagl-v1" +# methanol_system = nagl_sage.create_openmm_system(topology=methanol.to_topology()) +# methanol_forces = { +# methanol_system.getForce(index).__class__.__name__: methanol_system.getForce( +# index +# ) +# for index in range(methanol_system.getNumForces()) +# } +# # check the system parameters +# ref_parameters = { +# # index: [charge, sigma, epsilon] +# 0: [0.041475, 0.3506905398376649, 0.29246740950730743], +# 1: [-0.628155, 0.30824716826324094, 0.42874612945325397], +# 2: [0.049413, 0.23126852234757847, 0.07259909774155628], +# 3: [0.049413, 0.23126852234757847, 0.07259909774155628], +# 4: [0.049413, 0.23126852234757847, 0.07259909774155628], +# 5: [0.438441, 0.11098246898497655, 0.41631660852994784], +# } +# for particle_index, refs in ref_parameters.items(): +# charge, sigma, epsilon = methanol_forces[ +# "NonbondedForce" +# ].getParticleParameters(particle_index) +# assert charge / unit.elementary_charge == pytest.approx(refs[0], abs=1e-5) +# assert sigma / unit.nanometers == pytest.approx(refs[1]) +# assert epsilon / unit.kilojoule_per_mole == pytest.approx(refs[2]) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..612d349 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel", "versioneer"] +build-backend = "setuptools.build_meta" + +[project] +name = "naglmbis" +description = "A collection of models to predict conformation independent MBIS charges and volumes of molecules, built on the [NAGL](https://github.com/SimonBoothroyd/nagl) package by SimonBoothroyd." 
+authors = [ {name = "Joshua Horton"} ] +license = { text = "MIT" } +dynamic = ["version"] +readme = "README.md" +requires-python = ">=3.10" +classifiers = ["Programming Language :: Python :: 3"] + +#[project.entry-points."openff.toolkit.plugins.handlers"] +#NAGLMBIS = "naglmbis.plugins:NAGLMBISHandler" + +[tool.setuptools] +zip-safe = false +include-package-data = true + +[tool.setuptools.dynamic] +version = {attr = "naglmbis.__version__"} + +[tool.setuptools.packages.find] +namespaces = true +where = ["."] + +[tool.versioneer] +# Automatic version numbering scheme +VCS = "git" +style = "pep440" +versionfile_source = "naglmbis/_version.py" +versionfile_build = "naglmbis/_version.py" +tag_prefix = 'naglmbis-' + +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" + +[tool.flake8] +max-line-length = 88 +ignore = ["E203", "E266", "E501", "W503"] +select = ["B","C","E","F","W","T4","B9"] + +[tool.coverage.run] +omit = ["**/tests/*", "**/_version.py"] + +[tool.coverage.report] +exclude_lines = [ + "@overload", + "pragma: no cover", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "if typing.TYPE_CHECKING:", +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 26fa2ac..0000000 --- a/setup.cfg +++ /dev/null @@ -1,38 +0,0 @@ -# Helper file to handle all configs - -[coverage:run] -# .coveragerc to control coverage.py and pytest-cov -omit = - # Omit the tests - */tests/* - # Omit generated versioneer - naglmbis/_version.py - -[flake8] -# Flake8, PyFlakes, etc -max-line-length = 88 -ignore = E203, E266, E501, W503 -select = B,C,E,F,W,T4,B9 - -[isort] -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -line_length=88 - -[versioneer] -# Automatic version numbering scheme -VCS = git -style = pep440 -versionfile_source = naglmbis/_version.py -versionfile_build = naglmbis/_version.py -tag_prefix = '' - -[aliases] -test = pytest - -[tool:pytest] 
-filterwarnings = - ignore::DeprecationWarning - ignore::PendingDeprecationWarning \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index e992abb..0000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -naglmbis -Models made with NAGL to predict mbis properties. -""" -import sys -from setuptools import setup, find_packages -import versioneer - -short_description = __doc__.split("\n") - -# from https://github.com/pytest-dev/pytest-runner#conditional-requirement -needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) -pytest_runner = ['pytest-runner'] if needs_pytest else [] - -try: - with open("README.md", "r") as handle: - long_description = handle.read() -except IOError: - long_description = "\n".join(short_description[2:]) - - -setup( - # Self-descriptive entries which should always be present - name='naglmbis', - author='Joshua Horton', - author_email='Josh.Horton@newcastle.ac.uk', - description=short_description[0], - long_description=long_description, - long_description_content_type="text/markdown", - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - license='MIT', - - # Which Python importable modules should be included when your package is installed - # Handled automatically by setuptools. 
Use 'exclude' to prevent some specific - # subpackage(s) from being added, if needed - packages=find_packages(), - - # Optional include package data to ship with your package - # Customize MANIFEST.in if the general case does not suit your needs - # Comment out this line to prevent the files from being packaged with your software - include_package_data=True, - - # Allows `setup.py test` to work correctly with pytest - setup_requires=[] + pytest_runner, - entry_points={ - "openff.toolkit.plugins.handlers": [ - "NAGLMBIS = naglmbis.plugins:NAGLMBISHandler" - ] - }, - # Set up the main CLI entry points -) diff --git a/versioneer.py b/versioneer.py index 88c91fb..1e3753e 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,5 +1,5 @@ -# Version: 0.18 +# Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -7,18 +7,14 @@ ============== * like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer +* https://github.com/python-versioneer/python-versioneer * Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based +* License: Public Domain (Unlicense) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in setuptools-based python projects. The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. 
Making a new release should be as easy as recording a new tag in your version-control @@ -27,9 +23,38 @@ ## Quick Install -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results +Versioneer provides two installation modes. The "classic" vendored mode installs +a copy of versioneer into your repository. The experimental build-time dependency mode +is intended to allow you to skip this step and simplify the process of upgrading. + +### Vendored mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) + * Note that you will need to add `tomli; python_version < "3.11"` to your + build-time dependencies if you use `pyproject.toml` +* run `versioneer install --vendor` in your source tree, commit the results +* verify version information with `python setup.py version` + +### Build-time dependency mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) +* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) + to the `requires` key of the `build-system` table in `pyproject.toml`: + ```toml + [build-system] + requires = ["setuptools", "versioneer[toml]"] + build-backend = "setuptools.build_meta" + ``` +* run `versioneer install --no-vendor` in your source tree, commit the results +* verify version information with `python setup.py 
version` ## Version Identifiers @@ -61,7 +86,7 @@ for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. +uncommitted changes). The version identifier is used for multiple purposes: @@ -166,7 +191,7 @@ Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). +[issues page](https://github.com/python-versioneer/python-versioneer/issues). ### Subprojects @@ -180,7 +205,7 @@ `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. + provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs @@ -194,9 +219,9 @@ Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. 
[pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve @@ -224,31 +249,20 @@ cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace +* edit `setup.cfg` and `pyproject.toml`, if necessary, + to include any new configuration settings indicated by the release notes. + See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files @@ -265,35 +279,70 @@ direction and include code from all supported VCS systems, reducing the number of intermediate scripts. 
+## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . +Specifically, both are released under the "Unlicense", as described in +https://unlicense.org/. + +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer """ +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments -from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser +import configparser import errno import json import os import re import subprocess import sys +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import NoReturn +import functools + +have_tomllib = True +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False class VersioneerConfig: 
"""Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + versionfile_source: str + versionfile_build: Optional[str] + parentdir_prefix: Optional[str] + verbose: Optional[bool] + -def get_root(): +def get_root() -> str: """Get the project root directory. We require that all commands are run from the project root, i.e. the @@ -301,13 +350,23 @@ def get_root(): """ root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): err = ("Versioneer was unable to run the project root directory. " "Versioneer requires setup.py to be executed from " "its immediate directory (like 'python setup.py COMMAND'), " @@ -321,43 +380,62 @@ def get_root(): # module-import table will cache the first one. So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. 
- me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root -def get_config_from_root(root): +def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or + # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . 
- setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None + root_pth = Path(root) + pyproject_toml = root_pth / "pyproject.toml" + setup_cfg = root_pth / "setup.cfg" + section: Union[Dict[str, Any], configparser.SectionProxy, None] = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, 'rb') as fobj: + pp = tomllib.load(fobj) + section = pp['tool']['versioneer'] + except (tomllib.TOMLDecodeError, KeyError) as e: + print(f"Failed to load config from {pyproject_toml}: {e}") + print("Try to load it from setup.cfg") + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + # `cast`` really shouldn't be used, but its simplest for the + # common VersioneerConfig users at the moment. 
We verify against + # `None` values elsewhere where it matters + cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): + cfg.VCS = section['VCS'] + cfg.style = section.get("style", "") + cfg.versionfile_source = cast(str, section.get("versionfile_source")) + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = cast(str, section.get("tag_prefix")) + if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") + cfg.parentdir_prefix = section.get("parentdir_prefix") + if isinstance(section, configparser.SectionProxy): + # Make sure configparser translates to bool + cfg.verbose = section.getboolean("verbose") + else: + cfg.verbose = section.get("verbose") + return cfg @@ -366,37 +444,48 @@ class NotThisMethod(Exception): # these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f + HANDLERS.setdefault(vcs, {})[method] = f return f return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: 
Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -407,26 +496,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). 
Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -435,9 +523,11 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -453,8 +543,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -472,13 +569,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in 
HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -487,22 +584,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -513,18 +623,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, 
process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -533,15 +645,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% @@ -550,41 +661,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. 
However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -597,11 +715,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d @@ -610,7 +728,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: @@ -619,6 +737,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %%s" %% r) return {"version": r, @@ -634,7 +757,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -645,8 +773,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) @@ -654,24 +789,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. 
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -688,7 +856,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces @@ -713,26 +881,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -757,23 +926,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] + rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -800,12 +1017,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -822,7 +1068,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -842,7 +1088,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -862,7 +1108,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -876,10 +1122,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -894,7 +1144,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -915,7 +1165,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, @@ -942,41 +1192,48 @@ def get_versions(): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. 
However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -989,11 +1246,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1002,7 +1259,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1011,6 +1268,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) return {"version": r, @@ -1026,7 +1288,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -1037,8 +1304,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1046,24 +1320,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. 
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -1080,7 +1387,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%s'" % describe_out) return pieces @@ -1105,19 +1412,20 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def do_vcs_install(manifest_in, versionfile_source, ipy): +def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py @@ -1126,36 +1434,40 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] + files = [versionfile_source] if ipy: files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) present = False try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: pass if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + 
parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -1164,15 +1476,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %s but none started with prefix %s" % @@ -1181,7 +1492,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from +# This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. 
@@ -1198,12 +1509,12 @@ def get_versions(): """ -def versions_from_file(filename): +def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) @@ -1215,9 +1526,8 @@ def versions_from_file(filename): return json.loads(mo.group(1)) -def write_to_version_file(filename, versions): +def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - os.unlink(filename) contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: @@ -1226,14 +1536,14 @@ def write_to_version_file(filename, versions): print("set %s to '%s'" % (filename, versions["version"])) -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -1258,23 +1568,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 
0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. 
Note that .dev0 sorts backwards @@ -1301,12 +1659,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -1323,7 +1710,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -1343,7 +1730,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -1363,7 +1750,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -1377,10 +1764,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -1399,7 +1790,7 @@ class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" -def get_versions(verbose=False): +def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. 
@@ -1414,7 +1805,7 @@ def get_versions(verbose=False): assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose + verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` assert cfg.versionfile_source is not None, \ "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" @@ -1475,13 +1866,17 @@ def get_versions(verbose=False): "date": None} -def get_version(): +def get_version() -> str: """Get the short version string for this project.""" return get_versions()["version"] -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" +def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): + """Get the custom setuptools subclasses used by Versioneer. + + If the package uses a different cmdclass (e.g. one from numpy), it + should be provide as an argument. + """ if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and @@ -1495,25 +1890,25 @@ def get_cmdclass(): # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. 
- # Also see https://github.com/warner/python-versioneer/issues/52 + # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - cmds = {} + cmds = {} if cmdclass is None else cmdclass.copy() - # we add "version" to both distutils and setuptools - from distutils.core import Command + # we add "version" to setuptools + from setuptools import Command class cmd_version(Command): description = "report generated version string" - user_options = [] - boolean_options = [] + user_options: List[Tuple[str, str, str]] = [] + boolean_options: List[str] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) @@ -1523,7 +1918,7 @@ def run(self): print(" error: %s" % vers["error"]) cmds["version"] = cmd_version - # we override "build_py" in both distutils and setuptools + # we override "build_py" in setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py @@ -1538,18 +1933,25 @@ def run(self): # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. 
+ # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py + if 'build_py' in cmds: + _build_py: Any = cmds['build_py'] else: - from distutils.command.build_py import build_py as _build_py + from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: @@ -1559,8 +1961,40 @@ def run(self): write_to_version_file(target_versionfile, versions) cmds["build_py"] = cmd_build_py + if 'build_ext' in cmds: + _build_ext: Any = cmds['build_ext'] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self) -> None: + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. + # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if not cfg.versionfile_build: + return + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + if not os.path.exists(target_versionfile): + print(f"Warning: {target_versionfile} does not exist, skipping " + "version update. 
This can happen if you are running build_ext " + "without first running build_py.") + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext + if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe + from cx_Freeze.dist import build_exe as _build_exe # type: ignore # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1569,7 +2003,7 @@ def run(self): # ... class cmd_build_exe(_build_exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1593,12 +2027,12 @@ def run(self): if 'py2exe' in sys.modules: # py2exe enabled? try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + from py2exe.setuptools_buildexe import py2exe as _py2exe # type: ignore except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 + from py2exe.distutils_buildexe import py2exe as _py2exe # type: ignore class cmd_py2exe(_py2exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1619,14 +2053,51 @@ def run(self): }) cmds["py2exe"] = cmd_py2exe + # sdist farms its file list building out to egg_info + if 'egg_info' in cmds: + _egg_info: Any = cmds['egg_info'] + else: + from setuptools.command.egg_info import egg_info as _egg_info + + class cmd_egg_info(_egg_info): + def find_sources(self) -> None: + # egg_info.find_sources builds the manifest list and writes it + # in one shot + super().find_sources() + + # Modify the filelist and normalize it + root = get_root() + cfg = get_config_from_root(root) + self.filelist.append('versioneer.py') + if cfg.versionfile_source: + # There are rare cases where versionfile_source might not be + # included by default, so we must be explicit + 
self.filelist.append(cfg.versionfile_source) + self.filelist.sort() + self.filelist.remove_duplicates() + + # The write method is hidden in the manifest_maker instance that + # generated the filelist and was thrown away + # We will instead replicate their final normalization (to unicode, + # and POSIX-style paths) + from setuptools import unicode_utils + normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') + for f in self.filelist.files] + + manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') + with open(manifest_filename, 'w') as fobj: + fobj.write('\n'.join(normalized)) + + cmds['egg_info'] = cmd_egg_info + # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist + if 'sdist' in cmds: + _sdist: Any = cmds['sdist'] else: - from distutils.command.sdist import sdist as _sdist + from setuptools.command.sdist import sdist as _sdist class cmd_sdist(_sdist): - def run(self): + def run(self) -> None: versions = get_versions() self._versioneer_generated_versions = versions # unless we update this, the command will keep using the old @@ -1634,7 +2105,7 @@ def run(self): self.distribution.metadata.version = versions["version"] return _sdist.run(self) - def make_release_tree(self, base_dir, files): + def make_release_tree(self, base_dir: str, files: List[str]) -> None: root = get_root() cfg = get_config_from_root(root) _sdist.make_release_tree(self, base_dir, files) @@ -1687,21 +2158,26 @@ def make_release_tree(self, base_dir, files): """ -INIT_PY_SNIPPET = """ +OLD_SNIPPET = """ from ._version import get_versions __version__ = get_versions()['version'] del get_versions """ +INIT_PY_SNIPPET = """ +from . 
import {0} +__version__ = {0}.get_versions()['version'] +""" -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" + +def do_setup() -> int: + """Do main VCS-independent setup function for installing Versioneer.""" root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, + except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): + if isinstance(e, (OSError, configparser.NoSectionError)): print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: @@ -1721,62 +2197,37 @@ def do_setup(): ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") + maybe_ipy: Optional[str] = ipy if os.path.exists(ipy): try: with open(ipy, "r") as f: old = f.read() - except EnvironmentError: + except OSError: old = "" - if INIT_PY_SNIPPET not in old: + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: print(" appending to %s" % ipy) with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) + f.write(snippet) else: print(" %s unmodified" % ipy) else: print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. 
- manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") + maybe_ipy = None # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + do_vcs_install(cfg.versionfile_source, maybe_ipy) return 0 -def scan_setup_py(): +def scan_setup_py() -> int: """Validate the contents of setup.py against Versioneer's expectations.""" found = set() setters = False @@ -1813,10 +2264,14 @@ def scan_setup_py(): return errors +def setup_command() -> NoReturn: + """Set up Versioneer and exit with appropriate error code.""" + errors = do_setup() + errors += scan_setup_py() + sys.exit(1 if errors else 0) + + if __name__ == "__main__": cmd = sys.argv[1] if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) \ No newline at end of file + setup_command()