Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
alanent authored Sep 29, 2023
1 parent fe594ae commit 605959c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions notebooks/train_segformers_rgb_pytorch_aerial_norm.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells":[{"cell_type":"markdown","source":["# FINETUNE SEGFORMER on rgb images (pytorch)\n","\n","---\n","<a target=\"_blank\" href=\"https://colab.research.google.com/drive/1FxU8SOoghUwyI-Eza_gPWllHSXyVok5k\">\n"," <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n","</a>\n","\n","https://huggingface.co/docs/transformers/model_doc/segformer"],"metadata":{"id":"FUIYr96QoKoQ"},"id":"FUIYr96QoKoQ"},{"cell_type":"markdown","source":["<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/segformer_architecture.png\" width=\"600\">\n"],"metadata":{"id":"LDZvoduQLNjI"},"id":"LDZvoduQLNjI"},{"cell_type":"markdown","source":["## Connect do GoogleDrive\n","\n","---\n","\n"],"metadata":{"id":"dC7m1xgHokgv"},"id":"dC7m1xgHokgv"},{"cell_type":"code","execution_count":null,"metadata":{"id":"VlgP4yyUqQWw"},"outputs":[],"source":["from google.colab import drive\n","drive.mount('/content/gdrive', force_remount=True)"],"id":"VlgP4yyUqQWw"},{"cell_type":"markdown","source":["## Install dependencies\n","\n","---\n","\n"],"metadata":{"id":"wdjvW4snowBB"},"id":"wdjvW4snowBB"},{"cell_type":"code","execution_count":null,"metadata":{"id":"nUnQMNXaCiel"},"outputs":[],"source":["! pip install split-folders\n","! pip install -U accelerate\n","! pip install -U transformers\n","! pip install evaluate\n","! pip install rasterio"],"id":"nUnQMNXaCiel"},{"cell_type":"code","execution_count":null,"metadata":{"id":"A2a23Q-xfPgY"},"outputs":[],"source":["import accelerate\n","import transformers\n","\n","print(transformers.__version__, accelerate.__version__)\n","\n","from torch.utils.data import Dataset, DataLoader\n","from transformers import AdamW\n","import torch\n","from torch import nn\n","from sklearn.metrics import accuracy_score\n","from tqdm.notebook import tqdm\n","import os\n","from PIL import Image\n","from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor\n","import pandas as pd\n","import cv2\n","import numpy as np\n","import albumentations as aug\n","import random\n","import rasterio\n","from pathlib import Path\n","import splitfolders\n","import shutil"],"id":"A2a23Q-xfPgY"},{"cell_type":"markdown","source":["## Check GPU Ressources\n","\n","---\n","\n"],"metadata":{"id":"Y6tWQpu3925Y"},"id":"Y6tWQpu3925Y"},{"cell_type":"code","execution_count":null,"metadata":{"id":"OAFPjim0kGYQ"},"outputs":[],"source":["gpu_info = !nvidia-smi\n","gpu_info = '\\n'.join(gpu_info)\n","if gpu_info.find('failed') >= 0:\n"," print('Not connected to a GPU')\n","else:\n"," print(gpu_info)"],"id":"OAFPjim0kGYQ"},{"cell_type":"markdown","source":["## Unzip training data\n","\n","---\n","\n"],"metadata":{"id":"XL6RZox_-PRX"},"id":"XL6RZox_-PRX"},{"cell_type":"code","source":["!unzip /content/gdrive/MyDrive/FLAIR2/flair_2_dataset/flair_labels_train.zip"],"metadata":{"id":"BGLH1OMcBNgK"},"id":"BGLH1OMcBNgK","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gW8ZmnThPD8n"},"outputs":[],"source":["!unzip /content/gdrive/MyDrive/FLAIR2/flair_2_dataset/flair_aerial_train.zip"],"id":"gW8ZmnThPD8n"},{"cell_type":"markdown","source":["## Split data for train/val\n","\n","---\n","\n"],"metadata":{"id":"csO0MuKbocXf"},"id":"csO0MuKbocXf"},{"cell_type":"code","execution_count":null,"metadata":{"id":"05ZgYvr1gIJS"},"outputs":[],"source":["! mkdir \"/content/temp\"\n","! mkdir \"/content/data\"\n","! mkdir \"/content/data/masks\"\n","! mkdir \"/content/data/images\"\n","\n","\n","# Chemin du dossier de destination\n","dst_folder = '/content/temp/'\n","\n","# Chemin du dossier source\n","src_folder = '/content/flair_labels_train'\n","for subdir, dirs, files in os.walk(src_folder):\n"," for file in files:\n"," src_file = os.path.join(subdir, file)\n"," dst_file = os.path.join(dst_folder, file)\n"," shutil.move(src_file, dst_folder)\n","\n","# Chemin du dossier source\n","src_folder = '/content/flair_aerial_train'\n","for subdir, dirs, files in os.walk(src_folder):\n"," for file in files:\n"," src_file = os.path.join(subdir, file)\n"," dst_file = os.path.join(dst_folder, file)\n"," shutil.move(src_file, dst_folder)"],"id":"05ZgYvr1gIJS"},{"cell_type":"code","execution_count":null,"metadata":{"id":"zkw1czotBqJB"},"outputs":[],"source":["cd temp"],"id":"zkw1czotBqJB"},{"cell_type":"code","execution_count":null,"metadata":{"id":"NaIKB7Z85Vs5"},"outputs":[],"source":["!mv MSK*.tif /content/data/masks/\n","!mv IMG*.tif /content/data/images/\n","\n","splitfolders.ratio(\"/content/data/\", seed=1337, ratio=(.998, .002), move=True) # default values"],"id":"NaIKB7Z85Vs5"},{"cell_type":"markdown","source":["## Define a class for the image segmentation dataset\n","\n","---\n","\n"],"metadata":{"id":"rnMuiTWco424"},"id":"rnMuiTWco424"},{"cell_type":"code","execution_count":null,"metadata":{"id":"FNMWyoENzw5i"},"outputs":[],"source":["def get_data_paths (path, filter):\n"," for path in Path(path).rglob(filter):\n"," yield path.resolve().as_posix()"],"id":"FNMWyoENzw5i"},{"cell_type":"code","source":["def rgb2gray(rgb):\n"," return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])"],"metadata":{"id":"dj5Uc6Dm4qMX"},"id":"dj5Uc6Dm4qMX","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Ie3wHlDxfdM4"},"outputs":[],"source":["class ImageSegmentationDataset(Dataset):\n"," \"\"\"Image segmentation dataset.\"\"\"\n","\n","\n"," def __init__(self, root_dir, feature_extractor, transforms=None):\n","\n"," self.root_dir = root_dir\n"," self.feature_extractor = feature_extractor\n","\n"," self.transforms = transforms\n"," self.images = sorted(list(get_data_paths(Path(self.root_dir), 'IMG*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))\n"," self.masks = sorted(list(get_data_paths(Path(self.root_dir), 'MSK*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))\n"," assert len(self.images) == len(self.masks) , \"There must be as many images as there are segmentation maps\"\n","\n"," def read_img(self, raster_file: str) -> np.ndarray:\n"," with rasterio.open(raster_file) as src_img:\n"," rgb = src_img.read([1,2,3]).swapaxes(0, 2).swapaxes(0, 1)\n"," rgb = rgb.astype(np.float32)\n"," return rgb\n","\n"," def read_msk(self, raster_file: str) -> np.ndarray:\n"," with rasterio.open(raster_file) as src_msk:\n"," array = src_msk.read()[0]\n"," array = np.squeeze(array)\n"," return array\n","\n","\n"," def __len__(self):\n"," return len(self.images)\n","\n","\n"," def __getitem__(self, idx):\n","\n"," image_file = self.images[idx]\n"," image = self.read_img(raster_file=image_file)\n"," mask_file = self.masks[idx]\n"," segmentation_map = self.read_msk(raster_file=mask_file)\n"," segmentation_map[segmentation_map > 12] = 0\n","\n","\n"," if self.transforms is not None:\n"," augmented = self.transforms(image=image, mask=segmentation_map)\n"," encoded_inputs = self.feature_extractor(augmented['image'], augmented['mask'], return_tensors=\"pt\")\n"," else:\n"," encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors=\"pt\")\n","\n"," for k,v in encoded_inputs.items():\n"," encoded_inputs[k].squeeze_() # remove batch dimension\n","\n"," return encoded_inputs"],"id":"Ie3wHlDxfdM4"},{"cell_type":"markdown","source":["## Data augmentation with albumentation\n","\n","---\n","\n"],"metadata":{"id":"PNAFPKsA_xmf"},"id":"PNAFPKsA_xmf"},{"cell_type":"code","execution_count":null,"metadata":{"id":"msH36qH_feF4"},"outputs":[],"source":["MEAN = np.array([0.44050665, 0.45704361, 0.42254708])\n","STD = np.array([0.20264351, 0.1782405 , 0.17575739])\n","\n","train_transform = aug.Compose([\n","\n"," aug.Normalize(mean=MEAN, std=STD),\n","\n","])\n","\n","test_transform = aug.Compose([\n"," aug.Normalize(mean=MEAN, std=STD),\n","])\n"],"id":"msH36qH_feF4"},{"cell_type":"code","execution_count":null,"metadata":{"id":"0WdG6bFOffmP"},"outputs":[],"source":["feature_extractor = SegformerFeatureExtractor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)\n","\n","root_train = '/content/temp/output/train'\n","root_val = '/content/temp/output/val'\n","\n","train_dataset = ImageSegmentationDataset(root_dir=root_train, feature_extractor=feature_extractor, transforms=train_transform)\n","valid_dataset = ImageSegmentationDataset(root_dir=root_val, feature_extractor=feature_extractor, transforms=test_transform)"],"id":"0WdG6bFOffmP"},{"cell_type":"code","execution_count":null,"metadata":{"id":"cCxKqT0CfgtQ"},"outputs":[],"source":["print(\"Number of training examples:\", len(train_dataset))\n","print(\"Number of validation examples:\", len(valid_dataset))"],"id":"cCxKqT0CfgtQ"},{"cell_type":"markdown","source":["## Classes metadata\n","\n","---\n","\n"],"metadata":{"id":"re0598WcpDS8"},"id":"re0598WcpDS8"},{"cell_type":"code","source":["def array_to_dict(array):\n"," dictionary = {}\n"," for i, item in enumerate(array):\n"," dictionary[i] = item\n"," return dictionary"],"metadata":{"id":"-j0WDHFfIo5F"},"id":"-j0WDHFfIo5F","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"A-MvayrKhGFC"},"outputs":[],"source":["classes = ['None','building','pervious surface','impervious surface','bare soil','water','coniferous','deciduous','brushwood','vineyard','herbaceous vegetation','agricultural land','plowed land']\n","id2label = array_to_dict(classes)\n","label2id = {v: k for k, v in id2label.items()}"],"id":"A-MvayrKhGFC"},{"cell_type":"code","execution_count":null,"metadata":{"id":"7YN5ArJ3OG-D"},"outputs":[],"source":["num_labels = len(id2label)\n","num_labels"],"id":"7YN5ArJ3OG-D"},{"cell_type":"markdown","metadata":{"id":"Plz_xtW1VXRP"},"source":["# Fine-tune a SegFormer model\n","\n","---\n","\n"],"id":"Plz_xtW1VXRP"},{"cell_type":"markdown","metadata":{"id":"3ci_NXUQV02W"},"source":["## Load the model to fine-tune"],"id":"3ci_NXUQV02W"},{"cell_type":"code","execution_count":null,"metadata":{"id":"U6rz8BE1L8Nb"},"outputs":[],"source":["from transformers import SegformerForSemanticSegmentation\n","\n","pretrained_model_name = \"nvidia/mit-b5\" #@param {type:\"string\"}\n","model = SegformerForSemanticSegmentation.from_pretrained(\n"," pretrained_model_name,\n"," id2label=id2label,\n"," label2id=label2id,\n"," reshape_last_stage=True,\n"," ignore_mismatched_sizes=True\n",")"],"id":"U6rz8BE1L8Nb"},{"cell_type":"markdown","metadata":{"id":"d7nqNiuZV7du"},"source":["## Set up the Trainer\n","\n","\n","---\n","\n"],"id":"d7nqNiuZV7du"},{"cell_type":"code","execution_count":null,"metadata":{"id":"fZJ2HJcyV8uQ"},"outputs":[],"source":["from transformers import TrainingArguments\n","import torch\n","epochs = 8 #@param {type:\"number\"}\n","lr = 6e-5 #@param {type:\"number\"}\n","batch_size = 4 #@param {type:\"number\"}\n","outputdir = \"/content/segformer_b5_igb\" #@param {type:\"string\"}\n","\n","\n","from transformers import TrainingArguments\n","training_args = TrainingArguments(\n"," outputdir,\n"," learning_rate=lr,\n"," num_train_epochs=epochs,\n"," per_device_train_batch_size=batch_size,\n"," per_device_eval_batch_size=2,\n"," evaluation_strategy=\"epoch\",\n"," save_strategy=\"epoch\",\n","\n",")"],"id":"fZJ2HJcyV8uQ"},{"cell_type":"markdown","source":["## Metrics for eval\n","\n","---\n","\n"],"metadata":{"id":"CgtnWDZ8qcBj"},"id":"CgtnWDZ8qcBj"},{"cell_type":"code","execution_count":null,"metadata":{"id":"DKOHOKaOL9Ze"},"outputs":[],"source":["import torch\n","from torch import nn\n","import evaluate\n","import multiprocessing\n","import os\n","\n","\n","metric = evaluate.load(\"mean_iou\")\n","\n","def compute_metrics(eval_pred):\n"," with torch.no_grad():\n"," logits, labels = eval_pred\n"," logits_tensor = torch.from_numpy(logits)\n"," # scale the logits to the size of the label\n"," logits_tensor = nn.functional.interpolate(\n"," logits_tensor,\n"," size=labels.shape[-2:],\n"," mode=\"bilinear\",\n"," align_corners=False,\n"," ).argmax(dim=1)\n","\n"," pred_labels = logits_tensor.detach().cpu().numpy()\n"," metrics = metric._compute(\n"," predictions=pred_labels,\n"," references=labels,\n"," num_labels=len(id2label),\n"," ignore_index=0,\n"," # reduce_labels=feature_extractor.reduce_labels,\n"," )\n","\n"," #add per category metrics as individual key-value pairs\n"," per_category_accuracy = metrics.pop(\"per_category_accuracy\").tolist()\n"," per_category_iou = metrics.pop(\"per_category_iou\").tolist()\n","\n"," metrics.update({f\"accuracy_{id2label[i]}\": v for i, v in enumerate(per_category_accuracy)})\n"," metrics.update({f\"iou_{id2label[i]}\": v for i, v in enumerate(per_category_iou)})\n","\n","\n"," return metrics"],"id":"DKOHOKaOL9Ze"},{"cell_type":"markdown","source":["## Training\n","\n","---\n","\n"],"metadata":{"id":"jGdMub0AqjLS"},"id":"jGdMub0AqjLS"},{"cell_type":"code","execution_count":null,"metadata":{"id":"NmyNBmg2Wacv"},"outputs":[],"source":["from transformers import Trainer\n","\n","trainer = Trainer(\n"," model=model,\n"," args=training_args,\n"," train_dataset=train_dataset,\n"," eval_dataset=valid_dataset,\n"," compute_metrics=compute_metrics,\n",")"],"id":"NmyNBmg2Wacv"},{"cell_type":"code","source":["import warnings\n","warnings.filterwarnings('ignore')"],"metadata":{"id":"QPDA98Nbtk4m"},"id":"QPDA98Nbtk4m","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"7Up9QNqOWtSD"},"outputs":[],"source":["trainer.train()"],"id":"7Up9QNqOWtSD"},{"cell_type":"code","source":["trainer.save_model(\"/content/gdrive/MyDrive/flair-one/models/segformer_b5_igb_8e\")"],"metadata":{"id":"nGZItL-c1aix"},"id":"nGZItL-c1aix","execution_count":null,"outputs":[]}],"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"V100","private_outputs":true},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5}
Loading

0 comments on commit 605959c

Please sign in to comment.