diff --git a/README.md b/README.md index fa74d900..f5f3bf4c 100644 --- a/README.md +++ b/README.md @@ -57,11 +57,11 @@ together. You can also add a `--visualize` flag to visualize the results of the elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled ``` -If you just do `elk plot`, it will plot the results from the most recent sweep. -If you want to plot a specific sweep, you can do so with: +If you just do `elk plot`, it will plot the results of AUROC from the most recent sweep. +If you want to plot a specific sweep, with a specific metric type, you can do so with: ```bash -elk plot {sweep_name} +elk plot {sweep_name} --metric acc_estimate ``` ## Caching diff --git a/elk/plotting/command.py b/elk/plotting/command.py index e79dc165..db80319d 100644 --- a/elk/plotting/command.py +++ b/elk/plotting/command.py @@ -22,6 +22,9 @@ class Plot: overwrite: bool = False """Whether to overwrite existing plots.""" + metric_type: str = "auroc_estimate" + """Name of metric to plot""" + def execute(self): root_dir = sweeps_dir() @@ -47,4 +50,4 @@ def execute(self): if self.overwrite: shutil.rmtree(sweep_path / "viz") - visualize_sweep(sweep_path) + visualize_sweep(sweep_path, self.metric_type) diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index fa183e5a..eb2a0f81 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -47,7 +47,7 @@ def render( shared_yaxes=True, vertical_spacing=0.1, x_title="Layer", - y_title="AUROC", + y_title=f"{sweep.metric_type}", ) color_map = dict(zip(ensembles, qualitative.Plotly)) @@ -56,7 +56,7 @@ def render( if with_transfer: # TODO write tests ensemble_data = ensemble_data.groupby( ["eval_dataset", "layer", "ensembling"], as_index=False - ).agg({"auroc_estimate": "mean"}) + ).agg({f"{sweep.metric_type}": "mean"}) else: ensemble_data = ensemble_data[ ensemble_data["eval_dataset"] == ensemble_data["train_dataset"] @@ -75,7 +75,7 @@ def render( fig.add_trace( go.Scatter( x=dataset_data["layer"], - y=dataset_data["auroc_estimate"], + y=dataset_data[f"{sweep.metric_type}"], mode="lines", name=ensemble, showlegend=False @@ -95,7 +95,7 @@ def render( legend=dict( title="Ensembling", ), - title=f"AUROC Trend: {self.model_name}", + title=f"{sweep.metric_type} Trend: {self.model_name}", ) if write: fig.write_image( @@ -114,7 +114,7 @@ class TransferEvalHeatmap: """Class for generating heatmaps for transfer evaluation results.""" layer: int - score_type: str = "auroc_estimate" + metric_type: str = "" ensembling: str = "full" def render(self, df: pd.DataFrame) -> go.Figure: @@ -129,7 +129,7 @@ def render(self, df: pd.DataFrame) -> go.Figure: model_name = df["eval_dataset"].iloc[0] # infer model name # TODO: validate pivot = pd.pivot_table( - df, values=self.score_type, index="eval_dataset", columns="train_dataset" + df, values=self.metric_type, index="eval_dataset", columns="train_dataset" ) fig = px.imshow(pivot, color_continuous_scale="Viridis", text_auto=True) @@ -137,7 +137,8 @@ def render(self, df: pd.DataFrame) -> go.Figure: fig.update_layout( xaxis_title="Train Dataset", yaxis_title="Transfer Dataset", - title=f"AUROC Score Heatmap: {model_name} | Layer {self.layer}", + title=f"{self.metric_type} Score Heatmap: {model_name} \ + | Layer {self.layer}", ) return fig @@ -145,11 +146,11 @@ def render(self, df: pd.DataFrame) -> go.Figure: @dataclass class TransferEvalTrend: - """Class for generating line plots for the trend of AUROC scores in transfer + """Class for generating line plots for the trend of metric scores in transfer evaluation.""" dataset_names: list[str] | None - score_type: str = "auroc_estimate" + metric_type: str = "" def render(self, df: pd.DataFrame) -> go.Figure: """Render the trend plot visualization. @@ -164,14 +165,14 @@ def render(self, df: pd.DataFrame) -> go.Figure: if self.dataset_names is not None: df = self._filter_transfer_datasets(df, self.dataset_names) pivot = pd.pivot_table( - df, values=self.score_type, index="layer", columns="eval_dataset" + df, values=self.metric_type, index="layer", columns="eval_dataset" ) fig = px.line(pivot, color_discrete_sequence=px.colors.qualitative.Plotly) fig.update_layout( xaxis_title="Layer", - yaxis_title="AUROC Score", - title=f"AUROC Score Trend: {model_name}", + yaxis_title=f"{self.metric_type} Score", + title=f"{self.metric_type} Score Trend: {model_name}", ) avg = pivot.mean(axis=1) @@ -244,7 +245,6 @@ def render_and_save( self, sweep: "SweepVisualization", dataset_names: list[str] | None = None, - score_type="auroc_estimate", ensembling="full", ) -> None: """Render and save the visualization for the model. @@ -252,9 +252,9 @@ def render_and_save( Args: sweep: The SweepVisualization instance. dataset_names: List of dataset names to include in the visualization. - score_type: The type of score to display. ensembling: The ensembling option to consider. """ + metric_type = sweep.metric_type df = self.df model_name = self.model_name layer_min, layer_max = df["layer"].min(), df["layer"].max() @@ -264,10 +264,10 @@ def render_and_save( for layer in range(layer_min, layer_max + 1): filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)] fig = TransferEvalHeatmap( - layer, score_type=score_type, ensembling=ensembling + layer, metric_type=metric_type, ensembling=ensembling ).render(filtered) fig.write_image(file=model_path / f"{layer}.png") - fig = TransferEvalTrend(dataset_names).render(df) + fig = TransferEvalTrend(dataset_names, metric_type=metric_type).render(df) fig.write_image(file=model_path / "transfer_eval_trend.png") @staticmethod @@ -288,6 +288,7 @@ class SweepVisualization: path: Path datasets: list[str] models: dict[str, ModelVisualization] + metric_type: str def model_names(self) -> list[str]: """Get the names of all models in the sweep. @@ -323,7 +324,7 @@ def _get_model_paths(sweep_path: Path) -> list[Path]: return folders @classmethod - def collect(cls, sweep_path: Path) -> "SweepVisualization": + def collect(cls, sweep_path: Path, metric_type: str) -> "SweepVisualization": """Collect the evaluation data for a sweep. Args: @@ -348,7 +349,9 @@ def collect(cls, sweep_path: Path) -> "SweepVisualization": } df = pd.concat([model.df for model in models.values()], ignore_index=True) datasets = list(df["eval_dataset"].unique()) - return cls(sweep_name, df, sweep_viz_path, datasets, models) + return cls( + sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type + ) def render_and_save(self): """Render and save all visualizations for the sweep.""" @@ -368,14 +371,11 @@ def render_multiplots(self, write=False): for model in self.models ] - def render_table( - self, score_type="auroc_estimate", display=True, write=False - ) -> pd.DataFrame: + def render_table(self, display=True, write=False) -> pd.DataFrame: """Render and optionally write the score table. Args: layer: The layer number (from last layer) to include in the score table. - score_type: The type of score to include in the table. display: Flag indicating whether to display the table to stdout. write: Flag indicating whether to write the table to a file. @@ -387,7 +387,7 @@ def render_table( # For each model, we use the layer whose mean AUROC is the highest best_layers, model_dfs = [], [] for _, model_df in df.groupby("model_name"): - best_layer = model_df.groupby("layer").auroc_estimate.mean().argmax() + best_layer = model_df.groupby("layer")[self.metric_type].mean().argmax() best_layers.append(best_layer) model_dfs.append(model_df[model_df["layer"] == best_layer]) @@ -395,7 +395,7 @@ def render_table( pivot_table = pd.concat(model_dfs).pivot_table( index="eval_dataset", columns="model_name", - values=score_type, + values=self.metric_type, margins=True, margins_name="Mean", ) @@ -416,14 +416,14 @@ def render_table( console.print(table) if write: - pivot_table.to_csv(f"score_table_{score_type}.csv") + pivot_table.to_csv(f"score_table_{self.metric_type}.csv") return pivot_table -def visualize_sweep(sweep_path: Path): +def visualize_sweep(sweep_path: Path, metric_type: str): """Visualize a sweep by generating and saving the visualizations. Args: sweep_path: The path to the sweep data directory. """ - SweepVisualization.collect(sweep_path).render_and_save() + SweepVisualization.collect(sweep_path, metric_type).render_and_save() diff --git a/elk/training/sweep.py b/elk/training/sweep.py index a4e5c97a..6a58f643 100755 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -50,6 +50,9 @@ class Sweep: visualize: bool = False """Whether to generate visualizations of the results of the sweep.""" + metric_type: str = "auroc_estimate" + """Name of metric to plot""" + name: str | None = None # A bit of a hack to add all the command line arguments from Elicit @@ -176,4 +179,4 @@ def execute(self): eval.execute(highlight_color="green") if self.visualize: - visualize_sweep(sweep_dir) + visualize_sweep(sweep_dir, self.metric_type)