Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not review - Interpret qa individual feature importance beta #2146

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/core-ui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ export * from "./lib/Interfaces/IModelExplanationData";
export * from "./lib/Interfaces/IConfusionMatrixData";
export * from "./lib/Interfaces/IVisionModelExplanationData";
export * from "./lib/Interfaces/IWeightedDropdownContext";
export * from "./lib/Interfaces/ITokenDropdownContext";
export * from "./lib/Interfaces/IFilter";
export * from "./lib/Interfaces/IPreBuiltFilter";
export * from "./lib/Interfaces/ICohort";
Expand Down
17 changes: 17 additions & 0 deletions libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { IComboBox, IComboBoxOption } from "@fluentui/react";
import React from "react";

export type TokenOption = number


export interface ITokenDropdownContext {
options: IComboBoxOption[];
selectedKey: TokenOption;
onSelection: (
event: React.FormEvent<IComboBox>,
item?: IComboBoxOption
) => void;
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ export enum RadioKeys {
Neg = "neg"
}

export enum QAExplanationType {
Start = "start",
End = "end"
}

export class Utils {
public static argsort(toSort: number[]): number[] {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import {
Text
} from "@fluentui/react";
import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
import { ClassImportanceWeights } from "@responsible-ai/interpret";
import { ClassImportanceWeights, TokenImportance } from "@responsible-ai/interpret";
import { localization } from "@responsible-ai/localization";
import React from "react";

import { RadioKeys, Utils } from "../../CommonUtils";
import { RadioKeys, QAExplanationType, Utils } from "../../CommonUtils";
import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps";
import { BarChart } from "../BarChart/BarChart";
import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
Expand All @@ -30,8 +30,13 @@ export interface ITextExplanationViewState {
maxK: number;
topK: number;
radio: string;
qaRadio?: string;
importances: number[];
singleTokenImportances: number[];
selectedToken: number;
tokenIndexes: number[];
text: string[];
isQA: boolean // temporal flag for identifying QA
}

const options: IChoiceGroupOption[] = [
Expand All @@ -43,6 +48,15 @@ const options: IChoiceGroupOption[] = [
{ key: RadioKeys.Neg, text: localization.InterpretText.View.negButton }
];

const qaOptions: IChoiceGroupOption[] = [
/*
* Creates the choices for the QA prediction radio button(local testing)
* TODO: move text under localization.InterpretText.View
*/
{ key: QAExplanationType.Start, text: "STARTING POSITION" },
{ key: QAExplanationType.End, text: "ENDING POSITION" },
];

const componentStackTokens: IStackTokens = {
childrenGap: "m",
padding: "m"
Expand All @@ -59,19 +73,28 @@ export class TextExplanationView extends React.PureComponent<
* Initializes the text view with its state
*/
super(props);
const weightVector = this.props.selectedWeightVector;
const importances = this.computeImportancesForWeightVector(

const isQA = false; // FIXME: temporally hardcode the flag, should use prop instead

const importances = this.computeImportancesForAllTokens(
this.props.dataSummary.localExplanations,
weightVector
);

const selectedToken = 0; // default to the first token
const singleTokenImportances = this.getImportanceForSingleToken(selectedToken);
const maxK = this.calculateMaxKImportances(importances);
const topK = this.calculateTopKImportances(importances);
this.state = {
importances,
importances: importances,
singleTokenImportances: singleTokenImportances,
selectedToken: selectedToken,
tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index),
maxK,
radio: RadioKeys.All,
qaRadio: "starting",
text: this.props.dataSummary.text,
topK
topK,
isQA
};
}

Expand All @@ -81,16 +104,37 @@ export class TextExplanationView extends React.PureComponent<
this.props.dataSummary.localExplanations !==
prevProps.dataSummary.localExplanations
) {
this.updateImportances(this.props.selectedWeightVector);
if (this.state.isQA) {
this.setState({ //update token dropdown
tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index),
selectedToken: 0
},
() => {
this.updateTokenImportances();
this.updateSingleTokenImportances();
})
} else {
this.updateImportances(this.props.selectedWeightVector);
}
}
}

public render(): React.ReactNode {
const classNames = textExplanationDashboardStyles();
const qaDescription = 'The left text box and the bar chart display the predictions of the model.' +
'The right textbox shows the feature importance associated with a selected token. Positive feature ' +
'importances represent the extent that the words were important towards marking the selected token' +
'as the starting/ending position of the answer.'

return (
<Stack>
<Stack tokens={componentStackTokens} horizontal>
<Text>{localization.InterpretText.View.legendText}</Text>
{
this.state.isQA?
<Text>{qaDescription}</Text> :
<Text>{localization.InterpretText.View.legendText}</Text>
}

</Stack>
<Stack tokens={componentStackTokens} horizontal>
<Stack.Item grow disableShrink>
Expand All @@ -103,6 +147,8 @@ export class TextExplanationView extends React.PureComponent<
</Stack.Item>
<Stack.Item grow className={classNames.chartRight}>
<Stack tokens={componentStackTokens}>

{ !this.state.isQA && ( // classfication
<Stack.Item>
<Text variant={"xLarge"}>
{localization.InterpretText.View.label +
Expand All @@ -113,9 +159,23 @@ export class TextExplanationView extends React.PureComponent<
)}
</Text>
</Stack.Item>
)}

{ this.state.isQA && ( // select starting/ending for QA
<Stack.Item id="TextChoiceGroup">
<ChoiceGroup
defaultSelectedKey="starting"
options={qaOptions}
onChange={this.switchQAprediction}
required
/>
</Stack.Item>
)}

<Stack.Item>
<Label>{localization.InterpretText.View.importantWords}</Label>
</Stack.Item>

<Stack.Item id="TextTopKSlider">
<Slider
min={1}
Expand All @@ -126,14 +186,28 @@ export class TextExplanationView extends React.PureComponent<
onChange={this.setTopK}
/>
</Stack.Item>
<Stack.Item>

{ this.state.isQA?
(<Stack.Item>
<TokenImportance
onTokenChange={this.onSelectedTokenChange}
selectedToken={this.state.selectedToken}
tokenOptions={this.state.tokenIndexes}
tokenLabels={this.state.text}
/>
</Stack.Item>)
:
(<Stack.Item>
<ClassImportanceWeights
onWeightChange={this.onWeightVectorChange}
selectedWeightVector={this.props.selectedWeightVector}
weightOptions={this.props.weightOptions}
weightLabels={this.props.weightLabels}
/>
</Stack.Item>
)
}

{this.props.selectedWeightVector !== WeightVectors.AbsAvg && (
<Stack.Item id="TextChoiceGroup">
<ChoiceGroup
Expand All @@ -147,6 +221,7 @@ export class TextExplanationView extends React.PureComponent<
</Stack>
</Stack.Item>
</Stack>

<Stack tokens={componentStackTokens} horizontal>
<Stack.Item
align="stretch"
Expand All @@ -161,10 +236,31 @@ export class TextExplanationView extends React.PureComponent<
radio={this.state.radio}
/>
</Stack.Item>

{ this.state.isQA && (
<Stack.Item
align="stretch"
grow
disableShrink
className={classNames.textHighlighting}
>
<TextHighlighting
text={this.state.text}
localExplanations={this.state.singleTokenImportances}
topK={
// keep all importances for single token(set topK to length)
this.state.singleTokenImportances.length
}
radio={this.state.radio}
/>
</Stack.Item>
)}

<Stack.Item align="end">
<TextFeatureLegend />
</Stack.Item>
</Stack>

</Stack>
);
}
Expand All @@ -174,11 +270,21 @@ export class TextExplanationView extends React.PureComponent<
this.props.onWeightChange(weightOption);
};

private onSelectedTokenChange = (newIndex: number): void => {

this.setState(
{ selectedToken: newIndex },
() => {
this.updateSingleTokenImportances();
});
};

private updateImportances(weightOption: WeightVectorOption): void {
const importances = this.computeImportancesForWeightVector(
this.props.dataSummary.localExplanations,
weightOption
);

const topK = this.calculateTopKImportances(importances);
const maxK = this.calculateMaxKImportances(importances);
this.setState({
Expand All @@ -189,6 +295,28 @@ export class TextExplanationView extends React.PureComponent<
});
}


// for QA
private updateTokenImportances(): void {

const importances = this.computeImportancesForAllTokens(
this.props.dataSummary.localExplanations,
);
const topK = this.calculateTopKImportances(importances);
const maxK = this.calculateMaxKImportances(importances);
this.setState({
importances,
maxK,
topK,
text: this.props.dataSummary.text
});
}

private updateSingleTokenImportances(): void {
const singleTokenImportances = this.getImportanceForSingleToken(this.state.selectedToken);
this.setState({singleTokenImportances: singleTokenImportances});
}

private calculateTopKImportances(importances: number[]): number {
return Math.min(
MaxImportantWords,
Expand Down Expand Up @@ -222,6 +350,27 @@ export class TextExplanationView extends React.PureComponent<
);
}

private computeImportancesForAllTokens(
importances: number[][]
): number[] {
/*
* sum the tokens importance
* TODO: add base values?
*/

const sumImportances = importances[0].map((_, index) =>
importances.reduce((sum, row) => sum + row[index], 0)
);

return sumImportances;
}

private getImportanceForSingleToken(
index: number
): number[] {
return this.props.dataSummary.localExplanations.map(row => row[index]);
}

private setTopK = (newNumber: number): void => {
/*
* Changes the state of K
Expand All @@ -240,4 +389,18 @@ export class TextExplanationView extends React.PureComponent<
this.setState({ radio: item.key });
}
};

private switchQAprediction = (
_event?: React.FormEvent,
item?: IChoiceGroupOption
): void => {
/*
* switch to the target predictions(starting or ending)
* TODO: add logic for switching explanation data
*/
if (item?.key !== undefined) {
this.setState({ qaRadio: item.key });
}
};

}
1 change: 1 addition & 0 deletions libs/interpret/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export * from "./lib/MLIDashboard/NewExplanationDashboard";
export * from "./lib/MLIDashboard/Interfaces/IExplanationDashboardProps";
export * from "./lib/MLIDashboard/Interfaces/IStringsParam";
export * from "./lib/MLIDashboard/Controls/ClassImportanceWeights/ClassImportanceWeights";
export * from "./lib/MLIDashboard/Controls/TokenImportance/TokenImportance";
export * from "./lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab";
export * from "./lib/MLIDashboard/Controls/GlobalExplanationTab/IGlobalSeries";
export * from "./lib/MLIDashboard/Controls/ModelPerformanceTab/ModelPerformanceTab";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { IStyle, mergeStyleSets, IProcessedStyleSet } from "@fluentui/react";

export interface ILabelWithCalloutStyles {
tokenLabel: IStyle;
tokenLabelText: IStyle;
}

export const tokenImportanceStyles: () => IProcessedStyleSet<ILabelWithCalloutStyles> =
() => {
return mergeStyleSets<ILabelWithCalloutStyles>({
tokenLabel: {
display: "inline-flex",
paddingTop: "10px"
},
tokenLabelText: {
fontWeight: "600",
paddingTop: "5px"
}
});
};
Loading