AUC chart #2171

Draft
wants to merge 20 commits into
base: main
Changes from 15 commits
1 change: 1 addition & 0 deletions libs/core-ui/src/index.ts
@@ -53,6 +53,7 @@ export * from "./lib/util/getCommonStyles";
export * from "./lib/util/getCompositeFilterString";
export * from "./lib/util/getFeatureOptions";
export * from "./lib/util/getFilterBoundsArgs";
export * from "./lib/util/calculateAUCData";
export * from "./lib/util/calculateBoxData";
export * from "./lib/util/calculateConfusionMatrixData";
export * from "./lib/util/calculateLineData";
6 changes: 5 additions & 1 deletion libs/core-ui/src/lib/FeatureFlights.ts
@@ -2,10 +2,14 @@
// Licensed under the MIT License.

export const dataBalanceExperienceFlight = "dataBalanceExperience";
export const aucChartExperienceFlight = "aucChartExperience";
export const featureFlightSeparator = "&";

// add more entries for new feature flights
- export const featureFlights = [dataBalanceExperienceFlight];
+ export const featureFlights = [
+ dataBalanceExperienceFlight,
+ aucChartExperienceFlight
+ ];

export function parseFeatureFlights(featureFlights?: string): string[] {
if (featureFlights) {
7 changes: 7 additions & 0 deletions libs/core-ui/src/lib/Interfaces/IAUCData.ts
@@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

export interface IAUCData {
AUCData: number[][];
selectedLabels: string[];
}
1 change: 1 addition & 0 deletions libs/core-ui/src/lib/components/OverallMetricChartUtils.ts
@@ -220,6 +220,7 @@ export function generateMetricsList(
if (!chartProps) {
return [];
}
console.log(yAxisProperty);
hawestra marked this conversation as resolved.
if (chartProps.yAxis.property === cohortKey) {
const indexes = context.errorCohorts.map((errorCohort) =>
errorCohort.cohort.unwrap(JointDataset.IndexLabel)
1 change: 0 additions & 1 deletion libs/core-ui/src/lib/util/StatisticsUtils.ts
@@ -245,7 +245,6 @@ const generateImageStats: (
}
];
};

export const generateMetrics: (
jointDataset: JointDataset,
selectionIndexes: number[][],
148 changes: 148 additions & 0 deletions libs/core-ui/src/lib/util/calculateAUCData.test.ts
@@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { binarizeData, calculatePerClassROCData } from "./calculateAUCData";

describe("Test binarizeData", () => {
it("should binarize numbers", () => {
const result = binarizeData([1, 3, 4, 0], [0, 1, 2, 3, 4]);
expect(result).toEqual([
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0]
]);
});
it("should binarize strings", () => {
const result = binarizeData(
["one", "two", "three"],
["three", "one", "two"]
);
expect(result).toEqual([
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
]);
});
it("should binarize binary data", () => {
const result = binarizeData([1, 0, 1, 0], [0, 1]);
expect(result).toEqual([
[0, 1],
[1, 0],
[0, 1],
[1, 0]
]);
});
});
describe("Test calculatePerClassROCData", () => {
it("generate x,y data corresponding to fpr and tpr respectively", () => {
const result = calculatePerClassROCData(
[0.33, 0.32, 0.34, 0.29, 0.12, 0.41, 0.4, 0.39],
Contributor:

If you add zeros after the decimal point so that all the points sit closer together (0.032, 0.033, 0.034, for example), this will no longer produce the right result, will it?

[0, 1, 1, 0, 1, 0, 0, 0]
);
expect(result).toEqual({
points: [
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 1 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 1, y: 0.6666666666666666 },
{ x: 0.8, y: 0.6666666666666666 },
{ x: 0.8, y: 0.6666666666666666 },
{ x: 0.8, y: 0.6666666666666666 },
{ x: 0.8, y: 0.3333333333333333 },
{ x: 0.6, y: 0.3333333333333333 },
{ x: 0.6, y: 0 },
{ x: 0.6, y: 0 },
{ x: 0.6, y: 0 },
{ x: 0.6, y: 0 },
{ x: 0.6, y: 0 },
{ x: 0.4, y: 0 },
{ x: 0.2, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 },
{ x: 0, y: 0 }
]
});
});
});
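
The concern raised in the review comment on this test can be checked with a small sketch. It is an illustration under assumptions, not code from this PR: `distinctRocPoints` is a hypothetical helper that counts the distinct (FPR, TPR) pairs a list of thresholds produces, using the same rule as `calculatePerClassROCData` (a score counts as positive when it is not below the threshold). The grid is built by index division rather than with lodash `range`, so its floating-point values may differ slightly from the ones used in the PR.

```typescript
// Hypothetical helper, not part of this PR. It counts the distinct (FPR, TPR)
// pairs that a list of thresholds produces, treating a score as positive when
// it is not below the threshold.
function distinctRocPoints(
  scores: number[],
  labels: number[],
  thresholds: number[]
): number {
  const totalPositives = labels.filter((label) => label === 1).length;
  const totalNegatives = labels.length - totalPositives;
  const seen: { [key: string]: boolean } = {};
  for (const threshold of thresholds) {
    let truePositives = 0;
    let falsePositives = 0;
    scores.forEach((score, index) => {
      if (score >= threshold) {
        if (labels[index] === 1) {
          truePositives++;
        } else {
          falsePositives++;
        }
      }
    });
    // Same division-by-zero fallback as addROCPoint in the PR.
    const tpr = totalPositives === 0 ? 1 : truePositives / totalPositives;
    const fpr = totalNegatives === 0 ? 1 : falsePositives / totalNegatives;
    seen[`${fpr},${tpr}`] = true;
  }
  return Object.keys(seen).length;
}

// Four scores that all fall between the grid ticks 0.03 and 0.04.
const closeScores = [0.031, 0.032, 0.033, 0.034];
const closeLabels = [0, 1, 0, 1];

// Fixed grid 0, 0.01, ..., 0.99 (assumed stand-in for range(0, 1, 0.01)).
const fixedGrid: number[] = [];
for (let i = 0; i < 100; i++) {
  fixedGrid.push(i / 100);
}

const pointsFromGrid = distinctRocPoints(closeScores, closeLabels, fixedGrid);
const pointsFromScores = distinctRocPoints(
  closeScores,
  closeLabels,
  closeScores.concat([Infinity])
);
```

With these inputs the fixed grid only reaches the two corner points of the ROC plot, while thresholds taken from the scores themselves reach five distinct points.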
142 changes: 142 additions & 0 deletions libs/core-ui/src/lib/util/calculateAUCData.ts
@@ -0,0 +1,142 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { localization } from "@responsible-ai/localization";
import { SeriesOptionsType } from "highcharts";
import { range, unzip } from "lodash";

import { IDataset } from "../Interfaces/IDataset";

interface IPoint {
x: number;
y: number;
}
export interface IROCData {
points: IPoint[];
}

function getStaticROCData(): SeriesOptionsType[] {
return [
{
data: [
{ x: 0, y: 0 },
{ x: 0, y: 1 },
{ x: 1, y: 1 }
],
name: localization.Interpret.Charts.Ideal,
type: "line"
},
{
data: [
{ x: 0, y: 0 },
{ x: 1, y: 1 }
],
name: localization.Interpret.Charts.Random,
type: "line"
}
];
}

export function calculatePerClassROCData(
probabilityY: number[],
binY: number[]
): IROCData {
const rocData: IROCData = {
points: []
};
const thresholds = range(0, 1, 0.01);
Contributor:

That's not quite right, is it?
For example, if I have two points within one tick (say at 0.001 and 0.004), then none of the thresholds covers the case where they fall on opposite sides of the decision boundary. That could very well be the best FPR/TPR combination.

So I would sort the values first and then always take the midpoint between two neighboring values as a threshold, e.g.,

values: [1,5,4,2,3]
sorted: [1,2,3,4,5]
thresholds: [-inf, 1.5, 2.5, 3.5, 4.5, inf]

That gives 6 points for the ROC in total, but then you still need to calculate the convex hull of those points, if I'm not mistaken. How are you avoiding points below the convex hull, by the way?
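
A minimal sketch of the midpoint thresholds suggested in this comment, for illustration only. `midpointThresholds` is a hypothetical helper that does not exist in this PR. It also assumes that tied scores should be collapsed before taking midpoints, which the comment does not specify.

```typescript
// Hypothetical helper, not part of this PR: thresholds halfway between
// consecutive distinct scores, bracketed by -Infinity and +Infinity.
function midpointThresholds(scores: number[]): number[] {
  // Sort a copy in ascending order and drop repeated values (assumption:
  // tied scores can never be separated, so they need no threshold between them).
  const sorted = scores
    .slice()
    .sort((a, b) => a - b)
    .filter((value, index, all) => index === 0 || value !== all[index - 1]);

  const thresholds: number[] = [Number.NEGATIVE_INFINITY];
  for (let i = 0; i + 1 < sorted.length; i++) {
    thresholds.push((sorted[i] + sorted[i + 1]) / 2);
  }
  thresholds.push(Number.POSITIVE_INFINITY);
  return thresholds;
}

// The example from the comment above.
const exampleThresholds = midpointThresholds([1, 5, 4, 2, 3]);
```

For the example values this yields the six thresholds listed in the comment. The number of thresholds then grows with the number of distinct scores rather than being fixed at 100, which may matter for large datasets.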

let truePositives = 0;
let falsePositives = 0;
let trueNegatives = 0;
let falseNegatives = 0;

for (const threshold of thresholds) {
hawestra marked this conversation as resolved.
for (const [index, yProba] of probabilityY.entries()) {
// a sample is predicted positive when its probability for the positive label
// is at least the threshold, and negative otherwise; the true label (binY)
// then decides whether that prediction counts as true or false
if (yProba < threshold) {
if (binY[index]) {
falseNegatives++;
} else {
trueNegatives++;
}
} else if (binY[index]) {
truePositives++;
} else {
falsePositives++;
}
}
addROCPoint(
truePositives,
falsePositives,
trueNegatives,
falseNegatives,
rocData
);
truePositives = falsePositives = trueNegatives = falseNegatives = 0;
hawestra marked this conversation as resolved.
}

return rocData;
}
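
The convex hull question raised in the review of this function can be sketched as follows. This is an illustration under assumptions, not code from the PR: `rocConvexHull`, `IHullPoint` and `cross` are hypothetical names, the hull is computed with the monotone chain method restricted to the upper side, and the endpoints (0, 0) and (1, 1) are assumed to always belong to the curve.

```typescript
// Hypothetical types and helpers, not part of this PR.
interface IHullPoint {
  x: number; // false positive rate
  y: number; // true positive rate
}

// Cross product of (a - o) and (b - o): positive for a counter-clockwise
// turn, negative for a clockwise turn, zero when the points are collinear.
function cross(o: IHullPoint, a: IHullPoint, b: IHullPoint): number {
  return (a.x - o.x) * (b.y - o.y) - (a.y - o.y) * (b.x - o.x);
}

// Upper convex hull of the ROC points, ordered from (0, 0) to (1, 1).
function rocConvexHull(points: IHullPoint[]): IHullPoint[] {
  // Assumption: the trivial classifiers (0, 0) and (1, 1) are always included.
  const sorted = points
    .concat([
      { x: 0, y: 0 },
      { x: 1, y: 1 }
    ])
    .sort((a, b) => a.x - b.x || a.y - b.y);

  const hull: IHullPoint[] = [];
  for (const point of sorted) {
    // Walking left to right, the upper hull only makes clockwise turns.
    // Drop the previous point while it lies on or below the new segment.
    while (
      hull.length >= 2 &&
      cross(hull[hull.length - 2], hull[hull.length - 1], point) >= 0
    ) {
      hull.pop();
    }
    hull.push(point);
  }
  return hull;
}

// (0.5, 0.5) lies below the segment from (0, 0.5) to (0.5, 1) and is dropped.
const hullExample = rocConvexHull([
  { x: 0.5, y: 0.5 },
  { x: 0, y: 0.5 },
  { x: 1, y: 1 },
  { x: 0.5, y: 1 },
  { x: 0, y: 0 }
]);
```

Whether the chart should show the hull or the raw curve is a separate design decision. The sketch only shows how points below the hull could be filtered out.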

function addROCPoint(
truePositives: number,
falsePositives: number,
trueNegatives: number,
falseNegatives: number,
rocData: IROCData
): void {
// prevent division by 0
const totalNegatives = trueNegatives + falsePositives;
const totalPositives = truePositives + falseNegatives;
const tpr = totalPositives === 0 ? 1 : truePositives / totalPositives;
const fpr = totalNegatives === 0 ? 1 : falsePositives / totalNegatives;
rocData.points.push({ x: fpr, y: tpr });
}

export function binarizeData(
yData: string[] | number[] | number[][],
classes: string[] | number[]
): number[][] {
// binarize labels in a one-vs-all fashion according to
const yBinData: number[][] = [];
for (const yDatum of yData) {
const binaryData = classes.map((c) => {
return c === yDatum ? 1 : 0;
});
yBinData.push(binaryData);
}
return yBinData;
}

// based on https://msdata.visualstudio.com/Vienna/_git/AzureMlCli?path=/src/azureml-metrics/azureml/metrics/_classification.py&version=GBmaster
export function calculateAUCData(dataset: IDataset): SeriesOptionsType[] {
if (!dataset.probability_y || !dataset.class_names) {
// TODO: show warning message
return [...getStaticROCData()];
}

// temporary, replace with dataset.class_names
const cNames = [0, 1];
Contributor:

This should only run for binary classification, right? If so, we probably need a check somewhere to disable the component otherwise.

For multiclass one could do one-vs-all, but I don't know if anyone wants that.

Contributor Author (@hawestra, Aug 3, 2023):

In the studio there's an AUC chart for multiclass, so I assumed we'd want that (binarizeData is supposed to handle this case), but I'll discuss with Minsoo tomorrow.

Contributor:

That's probably one-vs-all, so that makes sense.
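
The one-vs-all grouping discussed in this thread can be sketched without lodash as follows. This is an assumption-laden illustration, not the PR's implementation: `transpose`, `oneVsAllPairs` and `IClassPair` are hypothetical names, and the probability columns are assumed to be in the same order as the class list.

```typescript
// Hypothetical names, not part of this PR.
interface IClassPair {
  className: string;
  labels: number[]; // 1 where the true label equals this class, else 0
  scores: number[]; // predicted probability of this class per row
}

// Turns a rows x classes matrix into a classes x rows matrix.
function transpose(matrix: number[][]): number[][] {
  if (matrix.length === 0) {
    return [];
  }
  return matrix[0].map((_, column) => matrix.map((row) => row[column]));
}

// Builds one binary problem per class (this class versus all others).
function oneVsAllPairs(
  trueY: Array<string | number>,
  probabilityY: number[][],
  classes: Array<string | number>
): IClassPair[] {
  const perClassScores = transpose(probabilityY);
  return classes.map((cls, index) => ({
    className: String(cls),
    labels: trueY.map((y) => (y === cls ? 1 : 0)),
    scores: perClassScores[index]
  }));
}

const pairs = oneVsAllPairs(
  ["cat", "dog", "bird"],
  [
    [0.1, 0.7, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1]
  ],
  ["bird", "cat", "dog"]
);
```

Each resulting pair could then be passed to a per-class ROC routine such as `calculatePerClassROCData(pair.scores, pair.labels)`. For binary data this produces two mirrored curves, so a check on `classes.length` may still be useful to plot only one of them.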

const binTrueY = binarizeData(dataset.true_y, cNames);
console.log(binTrueY);
// transpose in order to group class data together
const perClassBinY = unzip(binTrueY);
const perClassProba = unzip(dataset.probability_y);
const data = [];
// loop through each class to calculate roc data per class
for (const [i, element] of perClassBinY.entries()) {
const classROCData = calculatePerClassROCData(perClassProba[i], element);
const classData = {
data: classROCData.points,
// TODO: check class_names length earlier ?
name: cNames ? cNames[i] : "",
type: "line"
};
data.push(classData);
}

const allData = [...data, ...getStaticROCData()];
return allData as SeriesOptionsType[];
}
5 changes: 4 additions & 1 deletion libs/localization/src/lib/en.json
@@ -838,7 +838,9 @@
"rowIndex": "Row index",
"absoluteIndex": "Absolute index",
"xValue": "X-value",
- "yValue": "Y-value"
+ "yValue": "Y-value",
+ "Ideal": "Ideal",
+ "Random": "Random"
},
"Cohort": {
"_cohort.comment": "a subset of the data is called a cohort",
@@ -1818,6 +1820,7 @@
"regressionDistributionPivotItem": "Target distribution",
"metricsVisualizationsPivotItem": "Metrics visualizations",
"confusionMatrixPivotItem": "Confusion matrix",
"AUCPivotItem": "AUC Chart",
Contributor:

This is fine, but we'll probably need a short explainer somewhere, especially because it's an acronym.

Contributor Author:

Will discuss with Minsoo!

"disaggregatedAnalysisFeatureSelectionPlaceholder": "Select features to generate the feature-based analysis.",
"tableCountTooltip": "Cohort {0} contains {1} instances.",
"tableMetricTooltip": "The model's {0} on cohort {1} is {2}",