Compatible w/ PyTorch 2 &co. but incompatible with Mac M2 #3

Open
matteoferla opened this issue Jun 3, 2023 · 0 comments

Aim: get CoPriNet to work on a Mac M2 processor. This is not plain CPU execution, for which CoPriNet with regular PyTorch v. 1.1 ought to work with the minor addition of the argument map_location=lambda storage, loc: storage to the method PricePredictorModule.load_from_checkpoint. Instead, an Apple Silicon M2 has a special accelerator, Metal, whose PyTorch accelerator name is mps.
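This cpu / cuda / mps split can be captured in a small selection helper. This is a minimal sketch of the decision logic only; choose_accelerator is an invented name, and in real code the two flags would come from torch.cuda.is_available() and torch.backends.mps.is_available():

```python
def choose_accelerator(cuda_available: bool, mps_available: bool) -> str:
    """Pick a PyTorch Lightning accelerator string.

    The flags are plain booleans so the logic runs anywhere; in practice
    they would come from torch.cuda.is_available() and
    torch.backends.mps.is_available().
    """
    if cuda_available:
        return "cuda"  # NVIDIA GPU, the path CoPriNet was written for
    if mps_available:
        return "mps"   # Apple Silicon Metal accelerator
    return "cpu"       # slow but universal fallback


print(choose_accelerator(cuda_available=False, mps_available=True))  # mps
```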

Running PL2

On a CUDA-compatible Linux machine it works with the latest PyTorch with minor tweaks.
I thought I had best share my hacky fix, even if I am out of my depth here and probably stating the obvious!

from pricePrediction.predict.predict import *


class ModGraphPricePredictor(GraphPricePredictor):

    def __init__(self,
                 model_path=DEFAULT_MODEL,
                 devices=1,
                 accelerator="auto",
                 map_location=None,
                 n_cpus=NUM_WORKERS_PER_GPU,
                 batch_size: int = BATCH_SIZE,
                 strict=True,
                 **kwargs):
        # added jankily some extras for my troubleshooting
        self.model_path = model_path
        self.devices = int(devices)
        self.n_gpus = self.devices
        self.accelerator = accelerator
        self.map_location = map_location
        self.batch_size = batch_size
        self.strict = strict
        self.n_cpus = n_cpus if sys.platform == 'linux' else 0  # controls multiprocessing
        self.trainer = pl.Trainer(accelerator=self.accelerator, devices=self.devices, logger=False)
        self.model = PricePredictorModule.load_from_checkpoint(self.model_path,
                                                               map_location=self.map_location,
                                                               batch_size=self.batch_size,
                                                               strict=self.strict)
        if USE_FEATURES_NET:
            from pricePrediction.preprocessData.smilesToDescriptors import smiles_to_graph
        else:
            from pricePrediction.preprocessData.smilesToGraph import smiles_to_graph
        self.smiles_to_graph = smiles_to_graph

    def yieldPredictions(self, smiles_generator, buffer_n_batches=BUFFER_N_BATCHES_FOR_PRED):
        # No issues with this method; it is here because it is the workhorse callable of the class.
        buffer_size = buffer_n_batches * self.batch_size
        preds_iter = map(self.predictListOfSmiles, tqdm(chunked(smiles_generator, buffer_size)))
        for preds_batch in preds_iter:
            for pred in preds_batch:
                yield pred

    def predictListOfSmiles(self, smiles_list):
        print(smiles_list)
        graphs_list = list(filter(None.__ne__, map(self.prepare_smi, enumerate(smiles_list))))

        def graphs_fn():
            # MOD: lambdas don't pickle on Windows and Mac
            return graphs_list

        dataset = MyIterableDataset(graphs_fn, self.n_cpus)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, collate_fn=Batch.from_data_list,
                                num_workers=self.n_cpus)
        preds = self.trainer.predict(self.model, dataloader)
        n_smiles = len(smiles_list)
        all_preds = np.nan * np.ones(n_smiles)
        for i, batch in enumerate(dataloader):
            batch_preds = preds[i].to("cpu").numpy()
            idxs = batch.input_idx.to("cpu").numpy().astype(np.int64).tolist()
            all_preds[idxs] = batch_preds
        return all_preds
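The final loop above scatters each batch's predictions back to the original SMILES order via input_idx, leaving NaN for molecules that were filtered out (e.g. unparseable SMILES). The same bookkeeping, stripped of torch and numpy, looks like this (scatter_predictions is an invented name for illustration):

```python
import math


def scatter_predictions(n_items, batches):
    """batches: iterable of (indices, predictions) pairs, possibly out of order.

    Returns a flat list aligned to the original input order; positions that
    never appear in any batch stay NaN, marking filtered-out inputs.
    """
    all_preds = [math.nan] * n_items
    for idxs, preds in batches:
        for i, p in zip(idxs, preds):
            all_preds[i] = p
    return all_preds


# two batches, out of order, with input index 1 filtered out upstream
out = scatter_predictions(4, [([2, 0], [20.0, 0.5]), ([3], [30.0])])
print(out)  # [0.5, nan, 20.0, 30.0]
```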

# ---------------------------------------
from pricePrediction.predict import predict

smiles_list = ['c1cc2c(cc1C(=O)O)C3(c4ccc(cc4Oc5c3ccc(c5)O)O)OC2=O', 'OC(=O)c1cc2cc(F)c(O)c(F)c2oc1=O']
predictor = ModGraphPricePredictor(model_path=predict.DEFAULT_MODEL,
                                   devices=1,
                                   accelerator='cuda',
                                   strict=False,
                                   batch_size=predict.BATCH_SIZE
                                   )
preds = predictor.yieldPredictions(smiles_list)
for smi, pred in zip(smiles_list, preds):
    print("%s\t%.4f" % (smi, pred))
c1cc2c(cc1C(=O)O)C3(c4ccc(cc4Oc5c3ccc(c5)O)O)OC2=O	102.5534
OC(=O)c1cc2cc(F)c(O)c(F)c2oc1=O	80.1875

The molecules are fluorescein and Pacific Blue, probably not a clever choice as they are made at scale, but those are numbers!

pl.Trainer

The Trainer class instantiation differs in version 2 of PyTorch Lightning, as it accepts accelerator and devices instead (of the former gpus argument).
I think the accelerator options are the same as those spat out by the error message of torch.device('this-is-not-a-valid-device-name').

Checkpoint & possible multimodel checkpoint collisions

Checkpoints saved on a CUDA device need special care when opened on other devices, namely map_location set during loading to map_location=lambda storage, loc: storage. Otherwise, the default of None automatically assigns a torch.device('cpu') or torch.device('cuda').
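The effect of that lambda can be shown without torch: the loader calls map_location(storage, location_tag) for each stored tensor, and returning storage unchanged keeps the data where the current process can hold it instead of chasing the device recorded in the checkpoint. A crude imitation of this behaviour, in which StoredTensor and mock_load are invented stand-ins for the real serialisation machinery:

```python
from dataclasses import dataclass


@dataclass
class StoredTensor:  # hypothetical stand-in for a torch storage
    data: list
    device: str      # device tag recorded in the checkpoint


def mock_load(checkpoint, map_location=None, cuda_available=False):
    """Crude imitation of torch.load's map_location handling."""
    out = {}
    for name, stored in checkpoint.items():
        if map_location is not None:
            # user-supplied callable decides where each storage lands
            out[name] = map_location(stored, stored.device)
        elif stored.device.startswith("cuda") and not cuda_available:
            # imitates the real failure on a CUDA-less machine
            raise RuntimeError("Attempting to deserialize object on a CUDA "
                               f"device, but CUDA is unavailable (saved on {stored.device})")
        else:
            out[name] = stored
    return out


ckpt = {"weight": StoredTensor([1.0, 2.0], device="cuda:0")}
# identity lambda: keep the storage as-is, do not chase the saved device
loaded = mock_load(ckpt, map_location=lambda storage, loc: storage)
print(loaded["weight"].data)  # [1.0, 2.0]
```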

The strict=False argument is needed for the checkpoint loading (PricePredictorModule.load_from_checkpoint), as otherwise the key mismatch behind the following warning is fatal:

/vols/opig/apps/conda310/envs/cuda39/lib/python3.9/site-packages/pytorch_lightning/core/saving.py:158: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['net.convs.0.aggr_module.avg_deg_lin', 'net.convs.0.aggr_module.avg_deg_log', 'net.convs.1.aggr_module.avg_deg_lin', 'net.convs.1.aggr_module.avg_deg_log', 'net.convs.2.aggr_module.avg_deg_lin', 'net.convs.2.aggr_module.avg_deg_log', 'net.convs.3.aggr_module.avg_deg_lin', 'net.convs.3.aggr_module.avg_deg_log', 'net.convs.4.aggr_module.avg_deg_lin', 'net.convs.4.aggr_module.avg_deg_log', 'net.convs.5.aggr_module.avg_deg_lin', 'net.convs.5.aggr_module.avg_deg_log', 'net.convs.6.aggr_module.avg_deg_lin', 'net.convs.6.aggr_module.avg_deg_log', 'net.convs.7.aggr_module.avg_deg_lin', 'net.convs.7.aggr_module.avg_deg_log', 'net.convs.8.aggr_module.avg_deg_lin', 'net.convs.8.aggr_module.avg_deg_log', 'net.convs.9.aggr_module.avg_deg_lin', 'net.convs.9.aggr_module.avg_deg_log']

This happens even after doing

python -m pytorch_lightning.utilities.upgrade_checkpoint --file data/models/trained_inStock/lightning_logs/version_0/checkpoints/epoch=238-step=1295857.ckpt
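What strict=False tolerates, in effect, is a mismatch between the model's state-dict keys and the checkpoint's keys. The bookkeeping can be sketched with plain dicts (load_state_dict_ish is an invented name, not the real PyTorch method):

```python
def load_state_dict_ish(model_keys, checkpoint, strict=True):
    """Return (loaded, missing): checkpoint entries the model expects,
    plus the model keys absent from the checkpoint.

    With strict=True any mismatch is fatal, mirroring how the
    avg_deg_lin / avg_deg_log discrepancy above becomes a hard error.
    """
    missing = [k for k in model_keys if k not in checkpoint]
    unexpected = [k for k in checkpoint if k not in model_keys]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing: {missing}, unexpected: {unexpected}")
    loaded = {k: v for k, v in checkpoint.items() if k in model_keys}
    return loaded, missing


model_keys = ["net.lin.weight", "net.convs.0.aggr_module.avg_deg_lin"]
ckpt = {"net.lin.weight": "w"}  # older checkpoint lacks the newer buffer
loaded, missing = load_state_dict_ish(model_keys, ckpt, strict=False)
print(missing)  # ['net.convs.0.aggr_module.avg_deg_lin']
```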

Mac M2

The classic annoyance of pickling lambdas on a Mac strikes: pickling is needed by the multiprocessing module, and lambdas cannot be pickled. So graphs_fn in predictListOfSmiles needs to be made not a lambda.
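The lambda problem is easy to reproduce with the stdlib pickle module, which is what multiprocessing relies on under the spawn start method (the default on macOS and Windows). A named, importable function pickles fine; a lambda does not. Here the builtin len stands in for a module-level function like graphs_fn:

```python
import pickle

# A named, importable function pickles fine (multiprocessing can ship it):
restored = pickle.loads(pickle.dumps(len))
print(restored("abc"))  # 3

# A lambda cannot be pickled, which is what breaks under multiprocessing's
# 'spawn' start method on macOS and Windows:
try:
    pickle.dumps(lambda: ["graph1", "graph2"])
    print("lambda pickled")  # not reached
except Exception as err:
    print("lambda failed to pickle:", type(err).__name__)
```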

On a Mac M2 processor (torch.device("mps")), the code ultimately stumbles because aten::scatter_reduce.two_out in the latest PyTorch is not implemented for MPS, as listed in pytorch/pytorch#77764.
