From 67bfb0b9f835f772f3af2aa176745c6b6135373a Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Wed, 21 Aug 2024 15:34:46 +0800 Subject: [PATCH] feat(ui/circuit): show attention pattern --- ui/src/components/model/circuit.tsx | 6 ++++++ ui/src/types/model.ts | 1 + 2 files changed, 7 insertions(+) diff --git a/ui/src/components/model/circuit.tsx b/ui/src/components/model/circuit.tsx index 90f38e91..f5e21cb9 100644 --- a/ui/src/components/model/circuit.tsx +++ b/ui/src/components/model/circuit.tsx @@ -69,6 +69,10 @@ const NodeInfo = ({ node }: { node: Node }) => {
{node.data.tracingNode.key}
Score:
{node.data.tracingNode.activation.toFixed(3)}
+
Pattern:
+
+ {node.data.tracingNode.pattern.toFixed(3)} +
); @@ -211,6 +215,8 @@ export const CircuitViewer = memo( const getNodeClassNames = useCallback((node: TracingNode) => { if (node.type === "feature") { return cn(getAccentClassname(node.activation, node.maxActivation, "border")); + } else if (node.type === "attn-score") { + return cn(getAccentClassname(node.pattern, 1, "border")); } return ""; }, []); diff --git a/ui/src/types/model.ts b/ui/src/types/model.ts index 63a4a85e..e08ec9bb 100644 --- a/ui/src/types/model.ts +++ b/ui/src/types/model.ts @@ -26,6 +26,7 @@ const AttnScoreNodeSchema = z.object({ query: z.number(), key: z.number(), activation: z.number(), + pattern: z.number(), }); export const TracingNodeSchema = z.discriminatedUnion("type", [