-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
2 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"cells":[{"cell_type":"markdown","source":["# FINETUNE SEGFORMER on rgb images (pytorch)\n","\n","---\n","<a target=\"_blank\" href=\"https://colab.research.google.com/drive/1FxU8SOoghUwyI-Eza_gPWllHSXyVok5k\">\n"," <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n","</a>\n","\n","https://huggingface.co/docs/transformers/model_doc/segformer"],"metadata":{"id":"FUIYr96QoKoQ"},"id":"FUIYr96QoKoQ"},{"cell_type":"markdown","source":["<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/segformer_architecture.png\" width=\"600\">\n"],"metadata":{"id":"LDZvoduQLNjI"},"id":"LDZvoduQLNjI"},{"cell_type":"markdown","source":["## Connect do GoogleDrive\n","\n","---\n","\n"],"metadata":{"id":"dC7m1xgHokgv"},"id":"dC7m1xgHokgv"},{"cell_type":"code","execution_count":null,"metadata":{"id":"VlgP4yyUqQWw"},"outputs":[],"source":["from google.colab import drive\n","drive.mount('/content/gdrive', force_remount=True)"],"id":"VlgP4yyUqQWw"},{"cell_type":"markdown","source":["## Install dependencies\n","\n","---\n","\n"],"metadata":{"id":"wdjvW4snowBB"},"id":"wdjvW4snowBB"},{"cell_type":"code","execution_count":null,"metadata":{"id":"nUnQMNXaCiel"},"outputs":[],"source":["! pip install split-folders\n","! pip install -U accelerate\n","! pip install -U transformers\n","! pip install evaluate\n","! pip install rasterio"],"id":"nUnQMNXaCiel"},{"cell_type":"code","execution_count":null,"metadata":{"id":"A2a23Q-xfPgY"},"outputs":[],"source":["import accelerate\n","import transformers\n","\n","print(transformers.__version__, accelerate.__version__)\n","\n","from torch.utils.data import Dataset, DataLoader\n","from transformers import AdamW\n","import torch\n","from torch import nn\n","from sklearn.metrics import accuracy_score\n","from tqdm.notebook import tqdm\n","import os\n","from PIL import Image\n","from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor\n","import pandas as pd\n","import cv2\n","import numpy as np\n","import albumentations as aug\n","import random\n","import rasterio\n","from pathlib import Path\n","import splitfolders\n","import shutil"],"id":"A2a23Q-xfPgY"},{"cell_type":"markdown","source":["## Check GPU Ressources\n","\n","---\n","\n"],"metadata":{"id":"Y6tWQpu3925Y"},"id":"Y6tWQpu3925Y"},{"cell_type":"code","execution_count":null,"metadata":{"id":"OAFPjim0kGYQ"},"outputs":[],"source":["gpu_info = !nvidia-smi\n","gpu_info = '\\n'.join(gpu_info)\n","if gpu_info.find('failed') >= 0:\n"," print('Not connected to a GPU')\n","else:\n"," print(gpu_info)"],"id":"OAFPjim0kGYQ"},{"cell_type":"markdown","source":["## Unzip training data\n","\n","---\n","\n"],"metadata":{"id":"XL6RZox_-PRX"},"id":"XL6RZox_-PRX"},{"cell_type":"code","source":["!unzip /content/gdrive/MyDrive/FLAIR2/flair_2_dataset/flair_labels_train.zip"],"metadata":{"id":"BGLH1OMcBNgK"},"id":"BGLH1OMcBNgK","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gW8ZmnThPD8n"},"outputs":[],"source":["!unzip /content/gdrive/MyDrive/FLAIR2/flair_2_dataset/flair_aerial_train.zip"],"id":"gW8ZmnThPD8n"},{"cell_type":"markdown","source":["## Split data for train/val\n","\n","---\n","\n"],"metadata":{"id":"csO0MuKbocXf"},"id":"csO0MuKbocXf"},{"cell_type":"code","execution_count":null,"metadata":{"id":"05ZgYvr1gIJS"},"outputs":[],"source":["! mkdir \"/content/temp\"\n","! mkdir \"/content/data\"\n","! mkdir \"/content/data/masks\"\n","! mkdir \"/content/data/images\"\n","\n","\n","# Chemin du dossier de destination\n","dst_folder = '/content/temp/'\n","\n","# Chemin du dossier source\n","src_folder = '/content/flair_labels_train'\n","for subdir, dirs, files in os.walk(src_folder):\n"," for file in files:\n"," src_file = os.path.join(subdir, file)\n"," dst_file = os.path.join(dst_folder, file)\n"," shutil.move(src_file, dst_folder)\n","\n","# Chemin du dossier source\n","src_folder = '/content/flair_aerial_train'\n","for subdir, dirs, files in os.walk(src_folder):\n"," for file in files:\n"," src_file = os.path.join(subdir, file)\n"," dst_file = os.path.join(dst_folder, file)\n"," shutil.move(src_file, dst_folder)"],"id":"05ZgYvr1gIJS"},{"cell_type":"code","execution_count":null,"metadata":{"id":"zkw1czotBqJB"},"outputs":[],"source":["cd temp"],"id":"zkw1czotBqJB"},{"cell_type":"code","execution_count":null,"metadata":{"id":"NaIKB7Z85Vs5"},"outputs":[],"source":["!mv MSK*.tif /content/data/masks/\n","!mv IMG*.tif /content/data/images/\n","\n","splitfolders.ratio(\"/content/data/\", seed=1337, ratio=(.998, .002), move=True) # default values"],"id":"NaIKB7Z85Vs5"},{"cell_type":"markdown","source":["## Define a class for the image segmentation dataset\n","\n","---\n","\n"],"metadata":{"id":"rnMuiTWco424"},"id":"rnMuiTWco424"},{"cell_type":"code","execution_count":null,"metadata":{"id":"FNMWyoENzw5i"},"outputs":[],"source":["def get_data_paths (path, filter):\n"," for path in Path(path).rglob(filter):\n"," yield path.resolve().as_posix()"],"id":"FNMWyoENzw5i"},{"cell_type":"code","source":["def rgb2gray(rgb):\n"," return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])"],"metadata":{"id":"dj5Uc6Dm4qMX"},"id":"dj5Uc6Dm4qMX","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Ie3wHlDxfdM4"},"outputs":[],"source":["class ImageSegmentationDataset(Dataset):\n"," \"\"\"Image segmentation dataset.\"\"\"\n","\n","\n"," def __init__(self, root_dir, feature_extractor, transforms=None):\n","\n"," self.root_dir = root_dir\n"," self.feature_extractor = feature_extractor\n","\n"," self.transforms = transforms\n"," self.images = sorted(list(get_data_paths(Path(self.root_dir), 'IMG*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))\n"," self.masks = sorted(list(get_data_paths(Path(self.root_dir), 'MSK*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))\n"," assert len(self.images) == len(self.masks) , \"There must be as many images as there are segmentation maps\"\n","\n"," def read_img(self, raster_file: str) -> np.ndarray:\n"," with rasterio.open(raster_file) as src_img:\n"," rgb = src_img.read([1,2,3]).swapaxes(0, 2).swapaxes(0, 1)\n"," rgb = rgb.astype(np.float32)\n"," return rgb\n","\n"," def read_msk(self, raster_file: str) -> np.ndarray:\n"," with rasterio.open(raster_file) as src_msk:\n"," array = src_msk.read()[0]\n"," array = np.squeeze(array)\n"," return array\n","\n","\n"," def __len__(self):\n"," return len(self.images)\n","\n","\n"," def __getitem__(self, idx):\n","\n"," image_file = self.images[idx]\n"," image = self.read_img(raster_file=image_file)\n"," mask_file = self.masks[idx]\n"," segmentation_map = self.read_msk(raster_file=mask_file)\n"," segmentation_map[segmentation_map > 12] = 0\n","\n","\n"," if self.transforms is not None:\n"," augmented = self.transforms(image=image, mask=segmentation_map)\n"," encoded_inputs = self.feature_extractor(augmented['image'], augmented['mask'], return_tensors=\"pt\")\n"," else:\n"," encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors=\"pt\")\n","\n"," for k,v in encoded_inputs.items():\n"," encoded_inputs[k].squeeze_() # remove batch dimension\n","\n"," return encoded_inputs"],"id":"Ie3wHlDxfdM4"},{"cell_type":"markdown","source":["## Data augmentation with albumentation\n","\n","---\n","\n"],"metadata":{"id":"PNAFPKsA_xmf"},"id":"PNAFPKsA_xmf"},{"cell_type":"code","execution_count":null,"metadata":{"id":"msH36qH_feF4"},"outputs":[],"source":["MEAN = np.array([0.44050665, 0.45704361, 0.42254708])\n","STD = np.array([0.20264351, 0.1782405 , 0.17575739])\n","\n","train_transform = aug.Compose([\n","\n"," aug.Normalize(mean=MEAN, std=STD),\n","\n","])\n","\n","test_transform = aug.Compose([\n"," aug.Normalize(mean=MEAN, std=STD),\n","])\n"],"id":"msH36qH_feF4"},{"cell_type":"code","execution_count":null,"metadata":{"id":"0WdG6bFOffmP"},"outputs":[],"source":["feature_extractor = SegformerFeatureExtractor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)\n","\n","root_train = '/content/temp/output/train'\n","root_val = '/content/temp/output/val'\n","\n","train_dataset = ImageSegmentationDataset(root_dir=root_train, feature_extractor=feature_extractor, transforms=train_transform)\n","valid_dataset = ImageSegmentationDataset(root_dir=root_val, feature_extractor=feature_extractor, transforms=test_transform)"],"id":"0WdG6bFOffmP"},{"cell_type":"code","execution_count":null,"metadata":{"id":"cCxKqT0CfgtQ"},"outputs":[],"source":["print(\"Number of training examples:\", len(train_dataset))\n","print(\"Number of validation examples:\", len(valid_dataset))"],"id":"cCxKqT0CfgtQ"},{"cell_type":"markdown","source":["## Classes metadata\n","\n","---\n","\n"],"metadata":{"id":"re0598WcpDS8"},"id":"re0598WcpDS8"},{"cell_type":"code","source":["def array_to_dict(array):\n"," dictionary = {}\n"," for i, item in enumerate(array):\n"," dictionary[i] = item\n"," return dictionary"],"metadata":{"id":"-j0WDHFfIo5F"},"id":"-j0WDHFfIo5F","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"A-MvayrKhGFC"},"outputs":[],"source":["classes = ['None','building','pervious surface','impervious surface','bare soil','water','coniferous','deciduous','brushwood','vineyard','herbaceous vegetation','agricultural land','plowed land']\n","id2label = array_to_dict(classes)\n","label2id = {v: k for k, v in id2label.items()}"],"id":"A-MvayrKhGFC"},{"cell_type":"code","execution_count":null,"metadata":{"id":"7YN5ArJ3OG-D"},"outputs":[],"source":["num_labels = len(id2label)\n","num_labels"],"id":"7YN5ArJ3OG-D"},{"cell_type":"markdown","metadata":{"id":"Plz_xtW1VXRP"},"source":["# Fine-tune a SegFormer model\n","\n","---\n","\n"],"id":"Plz_xtW1VXRP"},{"cell_type":"markdown","metadata":{"id":"3ci_NXUQV02W"},"source":["## Load the model to fine-tune"],"id":"3ci_NXUQV02W"},{"cell_type":"code","execution_count":null,"metadata":{"id":"U6rz8BE1L8Nb"},"outputs":[],"source":["from transformers import SegformerForSemanticSegmentation\n","\n","pretrained_model_name = \"nvidia/mit-b5\" #@param {type:\"string\"}\n","model = SegformerForSemanticSegmentation.from_pretrained(\n"," pretrained_model_name,\n"," id2label=id2label,\n"," label2id=label2id,\n"," reshape_last_stage=True,\n"," ignore_mismatched_sizes=True\n",")"],"id":"U6rz8BE1L8Nb"},{"cell_type":"markdown","metadata":{"id":"d7nqNiuZV7du"},"source":["## Set up the Trainer\n","\n","\n","---\n","\n"],"id":"d7nqNiuZV7du"},{"cell_type":"code","execution_count":null,"metadata":{"id":"fZJ2HJcyV8uQ"},"outputs":[],"source":["from transformers import TrainingArguments\n","import torch\n","epochs = 8 #@param {type:\"number\"}\n","lr = 6e-5 #@param {type:\"number\"}\n","batch_size = 4 #@param {type:\"number\"}\n","outputdir = \"/content/segformer_b5_igb\" #@param {type:\"string\"}\n","\n","\n","from transformers import TrainingArguments\n","training_args = TrainingArguments(\n"," outputdir,\n"," learning_rate=lr,\n"," num_train_epochs=epochs,\n"," per_device_train_batch_size=batch_size,\n"," per_device_eval_batch_size=2,\n"," evaluation_strategy=\"epoch\",\n"," save_strategy=\"epoch\",\n","\n",")"],"id":"fZJ2HJcyV8uQ"},{"cell_type":"markdown","source":["## Metrics for eval\n","\n","---\n","\n"],"metadata":{"id":"CgtnWDZ8qcBj"},"id":"CgtnWDZ8qcBj"},{"cell_type":"code","execution_count":null,"metadata":{"id":"DKOHOKaOL9Ze"},"outputs":[],"source":["import torch\n","from torch import nn\n","import evaluate\n","import multiprocessing\n","import os\n","\n","\n","metric = evaluate.load(\"mean_iou\")\n","\n","def compute_metrics(eval_pred):\n"," with torch.no_grad():\n"," logits, labels = eval_pred\n"," logits_tensor = torch.from_numpy(logits)\n"," # scale the logits to the size of the label\n"," logits_tensor = nn.functional.interpolate(\n"," logits_tensor,\n"," size=labels.shape[-2:],\n"," mode=\"bilinear\",\n"," align_corners=False,\n"," ).argmax(dim=1)\n","\n"," pred_labels = logits_tensor.detach().cpu().numpy()\n"," metrics = metric._compute(\n"," predictions=pred_labels,\n"," references=labels,\n"," num_labels=len(id2label),\n"," ignore_index=0,\n"," # reduce_labels=feature_extractor.reduce_labels,\n"," )\n","\n"," #add per category metrics as individual key-value pairs\n"," per_category_accuracy = metrics.pop(\"per_category_accuracy\").tolist()\n"," per_category_iou = metrics.pop(\"per_category_iou\").tolist()\n","\n"," metrics.update({f\"accuracy_{id2label[i]}\": v for i, v in enumerate(per_category_accuracy)})\n"," metrics.update({f\"iou_{id2label[i]}\": v for i, v in enumerate(per_category_iou)})\n","\n","\n"," return metrics"],"id":"DKOHOKaOL9Ze"},{"cell_type":"markdown","source":["## Training\n","\n","---\n","\n"],"metadata":{"id":"jGdMub0AqjLS"},"id":"jGdMub0AqjLS"},{"cell_type":"code","execution_count":null,"metadata":{"id":"NmyNBmg2Wacv"},"outputs":[],"source":["from transformers import Trainer\n","\n","trainer = Trainer(\n"," model=model,\n"," args=training_args,\n"," train_dataset=train_dataset,\n"," eval_dataset=valid_dataset,\n"," compute_metrics=compute_metrics,\n",")"],"id":"NmyNBmg2Wacv"},{"cell_type":"code","source":["import warnings\n","warnings.filterwarnings('ignore')"],"metadata":{"id":"QPDA98Nbtk4m"},"id":"QPDA98Nbtk4m","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"7Up9QNqOWtSD"},"outputs":[],"source":["trainer.train()"],"id":"7Up9QNqOWtSD"},{"cell_type":"code","source":["trainer.save_model(\"/content/gdrive/MyDrive/flair-one/models/segformer_b5_igb_8e\")"],"metadata":{"id":"nGZItL-c1aix"},"id":"nGZItL-c1aix","execution_count":null,"outputs":[]}],"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"V100","private_outputs":true},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5} |
Oops, something went wrong.