-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cleaned up dir, added documentation, added
prediction files
- Loading branch information
Showing
18 changed files
with
1,785 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,59 @@ | ||
# RDP-Net | ||
Implemention of RDP-Net using Python's TensorFlow | ||
|
||
This code was used apart of a study that is yet to be published. The code may be used according to it's MIT license. The data used in the study may be provided upon request. If you are interested, please contact the author via email: [email protected] | ||
## *Setup* | ||
|
||
1. Clone the directory to your local machine using the following command: | ||
``` | ||
$ git clone https://github.com/dalmouiee/RDP-Net-Ret.git | ||
``` | ||
|
||
2. Run the create_py_venv.sh script to setup the appropriate python virtual environment: </br> | ||
``` | ||
$ sh scripts\create_py_venv.sh | ||
``` | ||
3. Once this script has been executed, the 'py_venv' directory should appear, containing the virtual environment. Type in the following command to activate and switch into it: | ||
|
||
For Windows: | ||
``` | ||
$ py_venv\Scripts\activate.bat | ||
``` | ||
For Mac/Linux: | ||
``` | ||
$ source py_venv/bin/activate | ||
``` | ||
You should see (py_venv) added to the beginning of the command line like so: | ||
``` | ||
(py_venv) $ | ||
``` | ||
If you wish to exit this virtual environemnt, type: | ||
``` | ||
(py_venv) $ deactivate | ||
``` | ||
4. Next, run the get_py_libs.sh script to install the necessary python libraries, needed to run the application locally: | ||
``` | ||
(py_venv) $ sh scripts\get_py_libs.sh | ||
``` | ||
This may take a few minutes to complete. | ||
### *Training* | ||
To run training, navigate to the architecture's source code directory and run the training script (refer to the train.py for more information on running the script): | ||
``` | ||
(py_venv) $ cd src\prototype\djd | ||
(py_venv) $ python train.py PATH_TO_TRAINING_DATA | ||
``` | ||
### *Inference/Prediction* | ||
To run predict new/test images using a trained model, navigate to the architecture's source code directory and run the predict script (refer to the train.py for more information on running the script): | ||
``` | ||
(py_venv) $ cd src\prototype\djd | ||
(py_venv) $ python predict_model_images.py PATH_TO_TESTING_SET PATH_TO_METAFILE_WITH_ARCHITECTURE_NAME NAME_OF_CHECKPOINT_FILE | ||
``` | ||
This code was used apart of a study that has yet to be published. The code may be used according to it's MIT license. The data used in the study may be provided upon request. If you are interested, please contact the author via email: [email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
''' | ||
Author: Daniel Al Mouiee | ||
Date: 09/06/2021 | ||
Script to predict new/test images using a trained CNN architecture | ||
Usage: | ||
python predict.py PATH_TO_TESTING_SET PATH_TO_METAFILE_WITH_ARCHITECTURE_NAME NAME_OF_CHECKPOINT_FILE | ||
ie. python code/arch_01/predict_model_images.py data/testing/ code/arch_01/ model3 | ||
''' | ||
|
||
from scipy import interp | ||
from sklearn.preprocessing import label_binarize | ||
from sklearn.metrics import roc_auc_score | ||
from sklearn.metrics import roc_curve, auc | ||
from sklearn.metrics import classification_report | ||
from sklearn.metrics import confusion_matrix | ||
import tensorflow as tf | ||
import numpy as np | ||
import os | ||
import glob | ||
import cv2 | ||
import sys | ||
import argparse | ||
import matplotlib.pyplot as plt | ||
import tensorflow.compat.v1 as tf | ||
tf.disable_v2_behavior() | ||
|
||
|
||
def plot_confusion_matrix(y_true, y_pred, titleName, classes, normalize=True, cmap=plt.cm.Blues): | ||
""" | ||
This function prints and plots the confusion matrix. | ||
Normalization can be applied by setting `normalize=True`. | ||
""" | ||
title = titleName | ||
|
||
# Compute confusion matrix | ||
cm = confusion_matrix(y_true, y_pred) | ||
# Only use the labels that appear in the data | ||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] | ||
|
||
fig, ax = plt.subplots() | ||
im = ax.imshow(cm, interpolation='nearest', cmap=cmap) | ||
ax.figure.colorbar(im, ax=ax) | ||
# We want to show all ticks... | ||
ax.set(xticks=np.arange(cm.shape[1]), | ||
yticks=np.arange(cm.shape[0]), | ||
# ... and label them with the respective list entries | ||
xticklabels=classes, yticklabels=classes, | ||
title=title, | ||
ylabel='True label', | ||
xlabel='Michael\'s predicted label') | ||
|
||
# Rotate the tick labels and set their alignment. | ||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", | ||
rotation_mode="anchor") | ||
|
||
# Loop over data dimensions and create text annotations. | ||
fmt = '.2f' if normalize else 'd' | ||
thresh = cm.max() / 2. | ||
for i in range(cm.shape[0]): | ||
for j in range(cm.shape[1]): | ||
ax.text(j, i, format(cm[i, j], fmt), | ||
ha="center", va="center", | ||
color="white" if cm[i, j] > thresh else "black") | ||
fig.tight_layout() | ||
plt.show() | ||
return ax | ||
|
||
|
||
# First, pass the path of the image | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
testPath = sys.argv[1] | ||
image_size = 128 | ||
num_channels = 3 | ||
images = [] | ||
succs = 0 | ||
classes = ['1', '2', '3', '4'] | ||
true = [] | ||
predicted = [] | ||
probs = np.array([]) | ||
path = os.path.join(testPath, '*g') | ||
files = glob.glob(path) | ||
for f in files: | ||
|
||
# Reading the image using OpenCV | ||
image = cv2.imread(f) | ||
# Resizing the image to our desired size and preprocessing will be done exactly as done during training | ||
image = cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) | ||
images = [] | ||
images.append(image) | ||
images = np.array(images, dtype=np.uint8) | ||
images = images.astype('float32') | ||
images = np.multiply(images, 1.0/255.0) | ||
# The input to the network is of shape [None image_size image_size num_channels]. Hence we reshape. | ||
x_batch = images.reshape(1, image_size, image_size, num_channels) | ||
|
||
# Let us restore the saved model | ||
sess = tf.Session() | ||
# Step-1: Recreate the network graph. At this step only graph is created. | ||
saver = tf.train.import_meta_graph(dir_path+'/'+sys.argv[3]+'.meta') | ||
# Step-2: Now let's load the weights saved using the restore method. | ||
saver.restore(sess, tf.train.latest_checkpoint(sys.argv[2])) | ||
|
||
# Accessing the default graph which we have restored | ||
graph = tf.get_default_graph() | ||
|
||
# Now, let's get hold of the op that we can be processed to get the output. | ||
# In the original network y_pred is the tensor that is the prediction of the network | ||
y_pred = graph.get_tensor_by_name("y_pred:0") | ||
|
||
# Let's feed the images to the input placeholders | ||
x = graph.get_tensor_by_name("x:0") | ||
y_true = graph.get_tensor_by_name("y_true:0") | ||
y_test_images = np.zeros((1, len(classes))) | ||
|
||
# Creating the feed_dict that is required to be fed to calculate y_pred | ||
feed_dict_testing = {x: x_batch, y_true: y_test_images} | ||
result = sess.run(y_pred, feed_dict=feed_dict_testing) | ||
|
||
probs = np.append(probs, result[0], axis=0) | ||
tempList = list(result[0]) | ||
predicted.append(str(tempList.index(max(tempList))+1)) | ||
# result is of this format [probabiliy_of_blind probability_of_normal] | ||
print('File: '+f+', class: ', str(tempList.index(max(tempList))+1)) |
Oops, something went wrong.