Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] simple demo for multi-node processing on a single node #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/deeperwin/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def optimize_wavefunction(

# Run burn-in of monte carlo chain
LOGGER.debug(f"Starting burn-in for optimization: {opt_config.mcmc.n_burn_in} steps")
- n_devices = jax.device_count()
+ n_devices = jax.local_device_count()
log_psi_squared_pmapped = jax.pmap(log_psi_squared)

mcmc = MetropolisHastingsMonteCarlo(opt_config.mcmc)
Expand Down Expand Up @@ -216,7 +216,7 @@ def log_psi_squared_func(params, r, R, Z, fixed_params):

# Init MCMC
logging.debug(f"Starting pretraining...")
- n_devices = jax.device_count()
+ n_devices = jax.local_device_count()
mcmc = MetropolisHastingsMonteCarlo(config.mcmc)
mcmc_state = MCMCState.resize_or_init(mcmc_state, config.mcmc.n_walkers, phys_config, n_devices)

Expand Down
5 changes: 5 additions & 0 deletions src/deeperwin/process_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def process_molecule(config_file):
if config.computation.force_device_count and config.computation.n_devices:
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={config.computation.n_devices}'

+ from jax import distributed
+ process_id = int(os.environ['CUDA_VISIBLE_DEVICES'])
+ distributed.initialize('0.0.0.0:8888', 2, process_id)

# These imports can only take place after we have set the jax_config options
from jax.config import config as jax_config
jax_config.update("jax_enable_x64", config.computation.float_precision == "float64")
Expand All @@ -38,6 +42,7 @@ def process_molecule(config_file):
used_hardware = xla_bridge.get_backend().platform
logger.debug(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
logger.debug(f"Used hardware: {used_hardware}; Device count: {jax.local_device_count()}")
+ logger.debug(f"Used hardware in total: {used_hardware}; Device count: {jax.device_count()}")
if not config.computation.n_devices:
config.computation.n_devices = jax.local_device_count()
else:
Expand Down