From cd84929ec48cd21d1e9a9bd901bbd956608b87fe Mon Sep 17 00:00:00 2001 From: renweiluo Date: Thu, 4 Aug 2022 09:41:59 +0800 Subject: [PATCH] simple demo for multi-node processing on a single node --- src/deeperwin/optimization.py | 4 ++-- src/deeperwin/process_molecule.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/deeperwin/optimization.py b/src/deeperwin/optimization.py index b31765a..b5ac27b 100644 --- a/src/deeperwin/optimization.py +++ b/src/deeperwin/optimization.py @@ -122,7 +122,7 @@ def optimize_wavefunction( # Run burn-in of monte carlo chain LOGGER.debug(f"Starting burn-in for optimization: {opt_config.mcmc.n_burn_in} steps") - n_devices = jax.device_count() + n_devices = jax.local_device_count() log_psi_squared_pmapped = jax.pmap(log_psi_squared) mcmc = MetropolisHastingsMonteCarlo(opt_config.mcmc) @@ -216,7 +216,7 @@ def log_psi_squared_func(params, r, R, Z, fixed_params): # Init MCMC logging.debug(f"Starting pretraining...") - n_devices = jax.device_count() + n_devices = jax.local_device_count() mcmc = MetropolisHastingsMonteCarlo(config.mcmc) mcmc_state = MCMCState.resize_or_init(mcmc_state, config.mcmc.n_walkers, phys_config, n_devices) diff --git a/src/deeperwin/process_molecule.py b/src/deeperwin/process_molecule.py index 2cc8214..95735f2 100644 --- a/src/deeperwin/process_molecule.py +++ b/src/deeperwin/process_molecule.py @@ -18,6 +18,10 @@ def process_molecule(config_file): if config.computation.force_device_count and config.computation.n_devices: os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={config.computation.n_devices}' + from jax import distributed + process_id = int(os.environ['CUDA_VISIBLE_DEVICES']) + distributed.initialize('0.0.0.0:8888', 2, process_id) + # These imports can only take place after we have set the jax_config options from jax.config import config as jax_config jax_config.update("jax_enable_x64", config.computation.float_precision == "float64") @@ -38,6 +42,7 @@ def process_molecule(config_file): used_hardware = xla_bridge.get_backend().platform logger.debug(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}") logger.debug(f"Used hardware: {used_hardware}; Device count: {jax.local_device_count()}") + logger.debug(f"Used hardware in total: {used_hardware}; Device count: {jax.device_count()}") if not config.computation.n_devices: config.computation.n_devices = jax.local_device_count() else: