This project aims to classify histological types of breast cancer from gene expression profiles using three machine learning models: Random Forest, Support Vector Machine (SVM), and a Convolutional Neural Network (CNN). It uses the BRCA Multi-Omics (TCGA) dataset, which contains gene expression data and associated clinical information.
- Project Overview
- Dataset
- Preprocessing
- Model Development
- Evaluation
- Visualizations
- Future Work
- Requirements
- How to Run the Project
- Name: BRCA Multi-Omics (TCGA)
- Source: Kaggle - BRCA Multi-Omics (TCGA)
- Files:
  - `brca_data_w_subtypes.csv`: Contains gene expression data and associated clinical labels, including histological types of cancer.
  - `data.csv`: General data used for supplementary analysis.
- Loading Data: The gene expression data is loaded from `brca_data_w_subtypes.csv`.
- Handling Missing Values: Missing values in numeric columns are imputed using the median, while categorical columns are imputed using the most frequent value.
- Feature Scaling: Features are standardized using `StandardScaler` to have zero mean and unit variance.
- Feature Selection: Features with low variance are removed using a variance threshold.
- Label Encoding: Categorical labels (histological types) are encoded into integers for compatibility with machine learning models.
- Data Splitting: The dataset is split into training (80%) and test (20%) sets.
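The preprocessing steps above can be sketched with scikit-learn as follows. This is a minimal illustration under stated assumptions, not the project's `preprocess_data.py`: the label column name `histological_type`, the variance threshold of 0.01, and the synthetic demo data are all hypothetical. Unlike the order listed above, the sketch splits the data first and fits the imputer, variance filter, and scaler on the training portion only, so no test-set statistics leak into training. It also applies the variance filter before scaling, because after standardization every non-constant feature has unit variance and a small threshold would no longer discriminate between features.

```python
import numpy as np
import pandas as pd
from sklearn.feature_selection import VarianceThreshold
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler

LABEL_COL = "histological_type"  # hypothetical name; check the real CSV header


def preprocess(df, label_col=LABEL_COL, test_size=0.2, seed=42):
    # Categorical imputation: fill missing labels with the most frequent class.
    labels = df[label_col].fillna(df[label_col].mode().iloc[0])
    # Label encoding: map class names to integers 0..k-1.
    encoder = LabelEncoder()
    y = encoder.fit_transform(labels)
    # Numeric (gene expression) columns become the feature matrix.
    X = df.drop(columns=[label_col]).select_dtypes(include="number")
    # 80-20 split, stratified so every class appears in both sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    features = Pipeline([
        ("impute", SimpleImputer(strategy="median")),     # median imputation
        ("variance", VarianceThreshold(threshold=0.01)),  # assumed threshold
        ("scale", StandardScaler()),                      # zero mean, unit variance
    ])
    X_train = features.fit_transform(X_train)  # fit on training data only
    X_test = features.transform(X_test)        # reuse the training statistics
    return X_train, X_test, y_train, y_test, encoder


# Synthetic stand-in for brca_data_w_subtypes.csv (placeholder values, not real data).
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(50, 6)), columns=[f"gene_{i}" for i in range(6)])
demo["gene_const"] = 1.0    # zero variance: dropped by VarianceThreshold
demo.iloc[::7, 0] = np.nan  # missing values to impute
demo[LABEL_COL] = np.resize(["ductal", "lobular", "other"], 50)  # placeholder classes

X_train, X_test, y_train, y_test, encoder = preprocess(demo)
print(X_train.shape, X_test.shape)  # (40, 6) (10, 6)
print(encoder.classes_.tolist())    # ['ductal', 'lobular', 'other']
```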
- Random Forest: An ensemble learning method that trains many decision trees and predicts the class that receives the most votes across the trees (the mode of their individual predictions).
  - Performance:
    - Accuracy: 87%
    - Precision: 0.86, Recall: 0.87, F1-Score: 0.85
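A minimal scikit-learn sketch of the Random Forest step, run on synthetic data because the preprocessed expression matrix is not reproduced here. The number of trees, the random seeds, and weighted metric averaging are assumptions, not the settings used in `random_forest.py`. The last few lines count the trees' votes by hand to illustrate the "mode of the classes" idea.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split

# Synthetic 3-class stand-in for the preprocessed expression matrix.
X, y = make_classification(n_samples=300, n_features=50, n_informative=10,
                           n_classes=3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y)

# n_estimators and random_state are illustrative choices, not the project's settings.
rf = RandomForestClassifier(n_estimators=200, random_state=0)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)

acc = accuracy_score(y_test, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(y_test, y_pred, average="weighted")
print(f"accuracy={acc:.2f} precision={prec:.2f} recall={rec:.2f} f1={f1:.2f}")

# Hard majority vote ("mode of the trees' classes") computed by hand.
# Each fitted tree predicts an index into rf.classes_.
votes = np.stack([tree.predict(X_test) for tree in rf.estimators_]).astype(int)
majority = rf.classes_[[np.bincount(col, minlength=len(rf.classes_)).argmax()
                        for col in votes.T]]
agreement = np.mean(majority == y_pred)
print(f"hard vote agrees with rf.predict on {agreement:.0%} of test samples")
```

scikit-learn's `RandomForestClassifier.predict` actually averages the trees' predicted class probabilities (soft voting) instead of counting hard votes, so the hand-computed majority usually, but not always, matches its output.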
- Support Vector Machine (SVM): A supervised learning algorithm that can be used for both classification and regression. For classification, it finds the hyperplane that separates the classes with the largest margin.
  - Performance:
    - Accuracy: 91%
    - Precision: 0.93, Recall: 0.91, F1-Score: 0.91
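To make the SVM's hyperplane concrete, here is a minimal scikit-learn sketch on synthetic two-class, two-dimensional data. The data and the linear kernel are assumptions chosen for illustration; the project's data are multiclass and high-dimensional, and its kernel and `C` are not stated. With a linear kernel, `coef_` and `intercept_` give the hyperplane w·x + b = 0, the sign of w·x + b decides the class, and the margin width is 2/||w||.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.svm import SVC

# Two well-separated synthetic clusters in 2-D (illustrative data only).
X, y = make_blobs(n_samples=100, centers=2, n_features=2, cluster_std=1.0, random_state=0)

svm = SVC(kernel="linear", C=1.0)  # C=1.0 is scikit-learn's default
svm.fit(X, y)

w = svm.coef_[0]       # normal vector of the separating hyperplane
b = svm.intercept_[0]  # offset of the hyperplane
scores = X @ w + b     # ||w|| times each point's signed distance to the hyperplane

# For a binary SVC, a positive score means classes_[1].
manual_pred = np.where(scores > 0, svm.classes_[1], svm.classes_[0])
margin_width = 2.0 / np.linalg.norm(w)
print("hyperplane:", w, b)
print("margin width:", margin_width)
print("support vectors per class:", svm.n_support_)
```

For more than two classes, `SVC` trains one-vs-one binary classifiers internally, and a non-linear kernel such as the default `"rbf"` places the separating hyperplane in an implicit feature space rather than in the original one.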
- Convolutional Neural Network (CNN): A deep learning model that is especially effective at capturing local, spatially structured patterns in its input. The CNN used here consists of multiple convolutional layers followed by fully connected layers.
  - Performance:
    - Accuracy: 88%
    - Precision: 0.88, Recall: 0.88, F1-Score: 0.88
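One way to apply the convolutional-then-fully-connected design of the CNN to gene expression is to treat each sample's expression vector as a one-channel 1-D sequence. The PyTorch sketch below is hypothetical: the layer counts, channel widths, kernel size, dropout rate, learning rate, and synthetic input are assumptions, not the architecture in `neural_network.py`. Gene order in an expression matrix is usually arbitrary, so "spatial" neighbors here are only neighbors in column order.

```python
import torch
from torch import nn


class GeneExpressionCNN(nn.Module):
    """Hypothetical 1-D CNN: conv blocks, then fully connected layers."""

    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=5, padding=2),  # padding=2 keeps the length
            nn.ReLU(),
            nn.MaxPool1d(2),                             # halves the length
            nn.Conv1d(16, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),                             # length is now n_features // 4
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * (n_features // 4), 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, n_classes),  # raw logits; CrossEntropyLoss applies softmax
        )

    def forward(self, x):
        # x: (batch, n_features) -> (batch, 1, n_features) for Conv1d
        return self.fc(self.conv(x.unsqueeze(1)))


def train(model, X, y, epochs=5, lr=1e-3):
    """Full-batch training loop that records the loss after each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
n_features, n_classes = 100, 4
X = torch.randn(32, n_features)             # synthetic expression vectors
y = torch.randint(0, n_classes, (32,))      # synthetic class labels
model = GeneExpressionCNN(n_features, n_classes)
losses = train(model, X, y)

model.eval()
with torch.no_grad():
    logits = model(X)
print(logits.shape)  # torch.Size([32, 4])
```

Full-batch updates keep the sketch short; a real run would typically iterate over `DataLoader` minibatches and also compute a validation loss each epoch, which is what a training/validation loss curve plots.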
Each model was evaluated on the test set using the following metrics:
- Accuracy: The fraction of test samples whose cancer type is predicted correctly.
- Precision, Recall, F1-Score: Precision is the fraction of predictions for a class that are correct, recall is the fraction of samples of a class that the model identifies, and the F1-score is the harmonic mean of precision and recall.
- Confusion Matrix: A matrix showing the true vs. predicted classifications.
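A small worked example of these metrics with scikit-learn. The labels are made up so every number can be checked by hand. Averaging multiclass precision, recall, and F1 by class support (`average="weighted"`) is an assumption about how the reported figures were computed, though it is consistent with each model's recall matching its accuracy, since weighted recall always equals accuracy.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support

# Toy labels for three classes (0, 1, 2); one class-0 sample is misclassified as 1.
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]

acc = accuracy_score(y_true, y_pred)  # 4 correct out of 5 -> 0.8

# Per-class scores: precision = TP / (TP + FP), recall = TP / (TP + FN).
p_c, r_c, f_c, support = precision_recall_fscore_support(y_true, y_pred, average=None)
# F1 is the harmonic mean of precision and recall for each class.
manual_f1 = 2 * p_c * r_c / (p_c + r_c)

# Support-weighted averages across classes.
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")

# Rows are true classes, columns are predicted classes:
# [[1, 1, 0],
#  [0, 2, 0],
#  [0, 0, 1]]
cm = confusion_matrix(y_true, y_pred)

print(f"accuracy={acc:.3f} precision={prec:.3f} recall={rec:.3f} f1={f1:.3f}")
# accuracy=0.800 precision=0.867 recall=0.800 f1=0.787
```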
All visualizations are saved in the `visualizations` folder and include:
- Confusion Matrix: Shows the performance of the classification models in predicting different classes.
- PCA Visualization: A 2D representation of the test data using Principal Component Analysis (PCA).
- Training Loss Curves: Plots of the training and validation loss over epochs for the CNN model.
- Model Tuning: Further tuning of hyperparameters to improve performance, especially for minority classes.
- Advanced Architectures: Experimentation with more advanced neural network architectures like ResNet or DenseNet.
- Class Imbalance Handling: Implement techniques like SMOTE or a weighted loss function to address class imbalance.
- Integration of Clinical Data: Combining gene expression data with clinical data for more comprehensive models.
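For the weighted-loss option under class imbalance handling, a common recipe is to weight each class inversely to its frequency. The sketch below uses scikit-learn's `compute_class_weight` with the `"balanced"` rule, n_samples / (n_classes * count_c), and passes the result to PyTorch's `CrossEntropyLoss`; the toy labels are made up. SMOTE, the other option named, is provided by the separate imbalanced-learn package, which is not in the requirements list.

```python
import numpy as np
import torch
from torch import nn
from sklearn.utils.class_weight import compute_class_weight

# Toy imbalanced training labels: 6 of class 0, 2 of class 1, 1 of class 2.
y_train = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2])
classes = np.unique(y_train)

# "balanced" weight for class c = n_samples / (n_classes * count_c)
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
# weights == [0.5, 1.5, 3.0]

# Rare classes now contribute more to the loss than common ones.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32))

logits = torch.zeros(3, len(classes))  # uniform predictions for 3 samples
targets = torch.tensor([0, 1, 2])
loss = loss_fn(logits, targets)        # weighted mean of per-sample losses
print(loss.item())
```

The scikit-learn models accept the same idea directly: `RandomForestClassifier` and `SVC` both take `class_weight="balanced"`.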
- Python 3.x
- PyTorch
- scikit-learn
- pandas
- numpy
- matplotlib
- seaborn
- Clone the Repository: `git clone https://github.com/yonas650/Cancer-Classification-Gene-Expression-using-svm-randomforest-and-cnn.git`, then `cd Cancer-Classification-Gene-Expression-using-svm-randomforest-and-cnn`
- Install Dependencies: `pip install -r requirements.txt`
- Run Preprocessing: `python preprocess_data.py`
- Train and Evaluate Models:
  - Random Forest: `python random_forest.py`
  - SVM: `python svm.py`
  - CNN: `python neural_network.py`
- Review Visualizations: All generated visualizations will be saved in the `visualizations` folder.