From 3cc5b7c7d2f16334fdbac2b29916d9ac1f4e3840 Mon Sep 17 00:00:00 2001 From: Hari Om Chadha Date: Fri, 2 Aug 2024 11:36:53 -0400 Subject: [PATCH] Update plots.py --- utils/plots.py | 112 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 2 deletions(-) diff --git a/utils/plots.py b/utils/plots.py index 52d8480..2c4bc92 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -50,7 +50,6 @@ def parity_plot_2(predictions, label_test, model_folder_path, tag): break return - def parity_plot_3(predictions, label_test, model_folder_path, tag): '''Plots the parity plot for the predictions and labels and saves the plot in the model folder''' # make the folder for the parity plots @@ -85,6 +84,29 @@ def parity_plot_3(predictions, label_test, model_folder_path, tag): plt.close() return s, m +def plot_colour(y, x): + plt.figure(figsize=(8, 6)) + hb = plt.hexbin(x, y, gridsize = 70, cmap = 'viridis', bins = 'log') + cb = plt.colorbar(hb, label = 'Number of points') + cb.set_label('Number of points', fontsize = 13) + cb.ax.tick_params(labelsize = 13, size = 10) + plt.plot(x, x, 'r') + plt.xlabel('Actual ln(Thermal Conductivity)', fontsize = 13) + plt.ylabel('Predicted ln(Thermal Conductivity)', fontsize = 13) + plt.xlim([np.min(x)-0.1, np.max(x)+0.1]) + plt.ylim([np.min(x)-0.1, np.max(x)+0.1]) + plt.tick_params(labelsize = 14) + plt.title("Parity Plot", fontsize=13) + srcc = scipy.stats.spearmanr(x[:, 3].flatten(), y[:, 3].flatten())[0] + mae = np.mean(np.abs(y[:, 3].flatten() - x[:, 3].flatten())) + textstr = f'MAE: {mae:.2f}\nSRCC: {srcc:.2f}' + props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) + plt.gca().text(0.05, 0.95, textstr, transform=plt.gca().transAxes, fontsize=14, verticalalignment='top', bbox=props) + plt.grid(True) + plt.savefig('plot2.png', transparent = True) + plt.close() + return + def plot_srcc_MAE(scratch, BT, s_std, BT_std, axis, tag, folder, tag2, tag3=None): '''Plots the SRCC and MAE values for the scratch and BT models and saves the plot in the folder. Tag1: Temperature @@ -126,4 +148,90 @@ def plot_srcc_MAE(scratch, BT, s_std, BT_std, axis, tag, folder, tag2, tag3=None plt.legend() os.makedirs(folder, exist_ok=True) plt.savefig(f"{folder}/mae_T{tag}.png") - plt.close() \ No newline at end of file + plt.close() + +def plot_srcc_MAE_2(scratch, s_std, axis, tag, folder, tag2, tag3=None): + plt.rcParams.update({'font.size': 12}) + if tag2 == 'srcc': + plt.plot(axis, scratch, label='random weights', color='red', linestyle='dashed', marker='o') + plt.fill_between(axis, scratch - s_std, scratch + s_std, color='red', alpha=0.2) + # plt.plot(axis, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o') + # plt.fill_between(axis, BT - BT_std, BT + BT_std, color='blue', alpha=0.2) + plt.xlabel('Number of data points') + if tag3 == 'avg': + plt.ylabel('Average SRCC') + plt.title(f'Average SRCC vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'srcc_avg_plots') + else: + plt.ylabel('SRCC') + plt.title(f'SRCC vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'srcc_plots') + plt.legend() + os.makedirs(folder, exist_ok=True) + plt.savefig(f"{folder}/srcc_T{tag}.png") + plt.close() + elif tag2 == 'mae': + plt.plot(axis, scratch, label='random weights', color='red', linestyle='dashed', marker='o') + plt.fill_between(axis, scratch - s_std, scratch + s_std, color='red', alpha=0.2) + # plt.plot(axis, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o') + # plt.fill_between(axis, BT - BT_std, BT + BT_std, color='blue', alpha=0.2) + plt.xlabel('Number of data points') + if tag3 == 'avg': + plt.ylabel('Average MAE') + plt.title(f'Average MAE vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'mae_avg_plots') + else: + plt.ylabel('MAE') + plt.title(f'MAE vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'mae_plots') + plt.legend() + os.makedirs(folder, exist_ok=True) + plt.savefig(f"{folder}/mae_T{tag}.png") + plt.close() + +def plot_srcc_MAE_3(scratch, BT, s_std, BT_std, axis2, scratch_2, BT_2, s_std_2, BT_std_2, axis, tag, folder, tag2, tag3=None): + plt.rcParams.update({'font.size': 12}) + if tag2 == 'srcc': + plt.plot(axis2, scratch, label='random weights', color='red', linestyle='dashed', marker='o') + plt.fill_between(axis2, scratch - s_std, scratch + s_std, color='red', alpha=0.2) + plt.plot(axis2, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o') + plt.fill_between(axis2, BT - BT_std, BT + BT_std, color='blue', alpha=0.2) + plt.plot(axis, scratch_2, label='random weights - GNN', color='green', linestyle='dashed', marker='o') + plt.fill_between(axis, scratch_2 - s_std_2, scratch_2 + s_std_2, color='green', alpha=0.2) + plt.plot(axis, BT_2, label='weights transfered from visc GNN', color='orange', linestyle='dashed', marker='o') + plt.fill_between(axis, BT_2 - BT_std_2, BT_2 + BT_std_2, color='orange', alpha=0.2) + plt.xlabel('Number of data points') + if tag3 == 'avg': + plt.ylabel('Average SRCC') + plt.title(f'Average SRCC vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'srcc_avg_plots') + else: + plt.ylabel('SRCC') + plt.title(f'SRCC vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'srcc_plots') + plt.legend() + os.makedirs(folder, exist_ok=True) + plt.savefig(f"{folder}/srcc_T{tag}.png") + plt.close() + elif tag2 == 'mae': + plt.plot(axis2, scratch, label='random weights', color='red', linestyle='dashed', marker='o') + plt.fill_between(axis2, scratch - s_std, scratch + s_std, color='red', alpha=0.2) + plt.plot(axis2, BT, label='weights transfered from visc ANN', color='blue', linestyle='dashed', marker='o') + plt.fill_between(axis2, BT - BT_std, BT + BT_std, color='blue', alpha=0.2) + plt.plot(axis, scratch_2, label='random weights - GNN', color='green', linestyle='dashed', marker='o') + plt.fill_between(axis, scratch_2 - s_std_2, scratch_2 + s_std_2, color='green', alpha=0.2) + plt.plot(axis, BT_2, label='weights transfered from visc GNN', color='orange', linestyle='dashed', marker='o') + plt.fill_between(axis, BT_2 - BT_std_2, BT_2 + BT_std_2, color='orange', alpha=0.2) + plt.xlabel('Number of data points') + if tag3 == 'avg': + plt.ylabel('Average MAE') + plt.title(f'Average MAE vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'mae_avg_plots') + else: + plt.ylabel('MAE') + plt.title(f'MAE vs Number of train data points at ' + f'$T_{tag}$') + folder = os.path.join(folder, 'mae_plots') + plt.legend() + os.makedirs(folder, exist_ok=True) + plt.savefig(f"{folder}/mae_T{tag}.png") + plt.close()