Skip to content

Commit

Permalink
Update plots.py
Browse files Browse the repository at this point in the history
  • Loading branch information
HariOmChadha authored Aug 2, 2024
1 parent 1446155 commit 3cc5b7c
Showing 1 changed file with 110 additions and 2 deletions.
112 changes: 110 additions & 2 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def parity_plot_2(predictions, label_test, model_folder_path, tag):
break
return


def parity_plot_3(predictions, label_test, model_folder_path, tag):
'''Plots the parity plot for the predictions and labels and saves the plot in the model folder'''
# make the folder for the parity plots
Expand Down Expand Up @@ -85,6 +84,29 @@ def parity_plot_3(predictions, label_test, model_folder_path, tag):
plt.close()
return s, m

def plot_colour(y, x):
plt.figure(figsize=(8, 6))
hb = plt.hexbin(x, y, gridsize = 70, cmap = 'viridis', bins = 'log')
cb = plt.colorbar(hb, label = 'Number of points')
cb.set_label('Number of points', fontsize = 13)
cb.ax.tick_params(labelsize = 13, size = 10)
plt.plot(x, x, 'r')
plt.xlabel('Actual ln(Thermal Conductivity)', fontsize = 13)
plt.ylabel('Predicted ln(Thermal Conductivity)', fontsize = 13)
plt.xlim([np.min(x)-0.1, np.max(x)+0.1])
plt.ylim([np.min(x)-0.1, np.max(x)+0.1])
plt.tick_params(labelsize = 14)
plt.title("Parity Plot", fontsize=13)
srcc = scipy.stats.spearmanr(x[:, 3].flatten(), y[:, 3].flatten())[0]
mae = np.mean(np.abs(y[:, 3].flatten() - x[:, 3].flatten()))
textstr = f'MAE: {mae:.2f}\nSRCC: {srcc:.2f}'
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
plt.gca().text(0.05, 0.95, textstr, transform=plt.gca().transAxes, fontsize=14, verticalalignment='top', bbox=props)
plt.grid(True)
plt.savefig('plot2.png', transparent = True)
plt.close()
return

def plot_srcc_MAE(scratch, BT, s_std, BT_std, axis, tag, folder, tag2, tag3=None):
'''Plots the SRCC and MAE values for the scratch and BT models and saves the plot in the folder.
Tag1: Temperature
Expand Down Expand Up @@ -126,4 +148,90 @@ def plot_srcc_MAE(scratch, BT, s_std, BT_std, axis, tag, folder, tag2, tag3=None
plt.legend()
os.makedirs(folder, exist_ok=True)
plt.savefig(f"{folder}/mae_T{tag}.png")
plt.close()
plt.close()

def plot_srcc_MAE_2(scratch, s_std, axis, tag, folder, tag2, tag3=None):
plt.rcParams.update({'font.size': 12})
if tag2 == 'srcc':
plt.plot(axis, scratch, label='random weights', color='red', linestyle='dashed', marker='o')
plt.fill_between(axis, scratch - s_std, scratch + s_std, color='red', alpha=0.2)
# plt.plot(axis, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o')
# plt.fill_between(axis, BT - BT_std, BT + BT_std, color='blue', alpha=0.2)
plt.xlabel('Number of data points')
if tag3 == 'avg':
plt.ylabel('Average SRCC')
plt.title(f'Average SRCC vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'srcc_avg_plots')
else:
plt.ylabel('SRCC')
plt.title(f'SRCC vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'srcc_plots')
plt.legend()
os.makedirs(folder, exist_ok=True)
plt.savefig(f"{folder}/srcc_T{tag}.png")
plt.close()
elif tag2 == 'mae':
plt.plot(axis, scratch, label='random weights', color='red', linestyle='dashed', marker='o')
plt.fill_between(axis, scratch - s_std, scratch + s_std, color='red', alpha=0.2)
# plt.plot(axis, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o')
# plt.fill_between(axis, BT - BT_std, BT + BT_std, color='blue', alpha=0.2)
plt.xlabel('Number of data points')
if tag3 == 'avg':
plt.ylabel('Average MAE')
plt.title(f'Average MAE vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'mae_avg_plots')
else:
plt.ylabel('MAE')
plt.title(f'MAE vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'mae_plots')
plt.legend()
os.makedirs(folder, exist_ok=True)
plt.savefig(f"{folder}/mae_T{tag}.png")
plt.close()

def plot_srcc_MAE_3(scratch, BT, s_std, BT_std, axis2, scratch_2, BT_2, s_std_2, BT_std_2, axis, tag, folder, tag2, tag3=None):
plt.rcParams.update({'font.size': 12})
if tag2 == 'srcc':
plt.plot(axis2, scratch, label='random weights', color='red', linestyle='dashed', marker='o')
plt.fill_between(axis2, scratch - s_std, scratch + s_std, color='red', alpha=0.2)
plt.plot(axis2, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o')
plt.fill_between(axis2, BT - BT_std, BT + BT_std, color='blue', alpha=0.2)
plt.plot(axis, scratch_2, label='random weights - GNN', color='green', linestyle='dashed', marker='o')
plt.fill_between(axis, scratch_2 - s_std_2, scratch_2 + s_std_2, color='green', alpha=0.2)
plt.plot(axis, BT_2, label='weights transfered from visc GNN', color='orange', linestyle='dashed', marker='o')
plt.fill_between(axis, BT_2 - BT_std_2, BT_2 + BT_std_2, color='orange', alpha=0.2)
plt.xlabel('Number of data points')
if tag3 == 'avg':
plt.ylabel('Average SRCC')
plt.title(f'Average SRCC vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'srcc_avg_plots')
else:
plt.ylabel('SRCC')
plt.title(f'SRCC vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'srcc_plots')
plt.legend()
os.makedirs(folder, exist_ok=True)
plt.savefig(f"{folder}/srcc_T{tag}.png")
plt.close()
elif tag2 == 'mae':
plt.plot(axis2, scratch, label='random weights', color='red', linestyle='dashed', marker='o')
plt.fill_between(axis2, scratch - s_std, scratch + s_std, color='red', alpha=0.2)
plt.plot(axis2, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o')
plt.fill_between(axis2, BT - BT_std, BT + BT_std, color='blue', alpha=0.2)
plt.plot(axis, scratch_2, label='random weights - GNN', color='green', linestyle='dashed', marker='o')
plt.fill_between(axis, scratch_2 - s_std_2, scratch_2 + s_std_2, color='green', alpha=0.2)
plt.plot(axis, BT_2, label='weights transfered from visc GNN', color='orange', linestyle='dashed', marker='o')
plt.fill_between(axis, BT_2 - BT_std_2, BT_2 + BT_std_2, color='orange', alpha=0.2)
plt.xlabel('Number of data points')
if tag3 == 'avg':
plt.ylabel('Average MAE')
plt.title(f'Average MAE vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'mae_avg_plots')
else:
plt.ylabel('MAE')
plt.title(f'MAE vs Number of train data points at ' + f'$T_{tag}$')
folder = os.path.join(folder, 'mae_plots')
plt.legend()
os.makedirs(folder, exist_ok=True)
plt.savefig(f"{folder}/mae_T{tag}.png")
plt.close()

0 comments on commit 3cc5b7c

Please sign in to comment.