Skip to content

Commit e974123

Browse files
mscolnickmanzt
andauthored
Unify handling of widget state/buffers/bufferPaths (#7441)
This keeps the `state/buffer/bufferPaths` in sync when sending to/from the backend for the widget. This changes some typings and adds a few more tests. This is an improvement, but not the final state as we can hopefully simplify further. This does fix a few errors when sending back binary data. Prior work done by @manzt in 9426b13 --------- Co-authored-by: Trevor Manz <trevor.j.manz@gmail.com>
1 parent 89e15ea commit e974123

File tree

12 files changed

+1187
-225
lines changed

12 files changed

+1187
-225
lines changed

frontend/src/plugins/impl/anywidget/AnyWidgetPlugin.tsx

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
/* eslint-disable @typescript-eslint/no-explicit-any */
33

44
import type { AnyWidget, Experimental } from "@anywidget/types";
5-
import { get, isEqual, set } from "lodash-es";
5+
import { isEqual } from "lodash-es";
66
import { useEffect, useMemo, useRef } from "react";
7+
import useEvent from "react-use-event-hook";
78
import { z } from "zod";
89
import { MarimoIncomingMessageEvent } from "@/core/dom/events";
910
import { asRemoteURL } from "@/core/runtime/config";
@@ -17,10 +18,13 @@ import { createPlugin } from "@/plugins/core/builder";
1718
import { rpc } from "@/plugins/core/rpc";
1819
import type { IPluginProps } from "@/plugins/types";
1920
import {
20-
type Base64String,
21-
byteStringToBinary,
22-
typedAtob,
23-
} from "@/utils/json/base64";
21+
decodeFromWire,
22+
isWireFormat,
23+
serializeBuffersToBase64,
24+
type WireFormat,
25+
} from "@/utils/data-views";
26+
import { prettyError } from "@/utils/errors";
27+
import type { Base64String } from "@/utils/json/base64";
2428
import { Logger } from "@/utils/Logger";
2529
import { ErrorBanner } from "../common/error-banner";
2630
import { MODEL_MANAGER, Model } from "./model";
@@ -29,44 +33,56 @@ interface Data {
2933
jsUrl: string;
3034
jsHash: string;
3135
css?: string | null;
32-
bufferPaths?: (string | number)[][] | null;
33-
initialValue: T;
3436
}
3537

36-
type T = Record<string, any>;
38+
type T = Record<string, unknown>;
3739

38-
// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
3940
type PluginFunctions = {
40-
send_to_widget: <T>(req: { content?: any }) => Promise<null | undefined>;
41+
send_to_widget: <T>(req: {
42+
content: unknown;
43+
buffers: Base64String[];
44+
}) => Promise<null | undefined>;
4145
};
4246

43-
export const AnyWidgetPlugin = createPlugin<T>("marimo-anywidget")
47+
export const AnyWidgetPlugin = createPlugin<WireFormat<T>>("marimo-anywidget")
4448
.withData(
4549
z.object({
4650
jsUrl: z.string(),
4751
jsHash: z.string(),
4852
css: z.string().nullish(),
49-
bufferPaths: z
50-
.array(z.array(z.union([z.string(), z.number()])))
51-
.nullish(),
52-
initialValue: z.object({}).passthrough(),
5353
}),
5454
)
5555
.withFunctions<PluginFunctions>({
5656
send_to_widget: rpc
57-
.input(z.object({ content: z.any() }))
57+
.input(
58+
z.object({
59+
content: z.unknown(),
60+
buffers: z.array(z.string().transform((v) => v as Base64String)),
61+
}),
62+
)
5863
.output(z.null().optional()),
5964
})
6065
.renderer((props) => <AnyWidgetSlot {...props} />);
6166

62-
type Props = IPluginProps<T, Data, PluginFunctions>;
63-
64-
const AnyWidgetSlot = (props: Props) => {
65-
const { css, jsUrl, jsHash, bufferPaths } = props.data;
67+
const AnyWidgetSlot = (
68+
props: IPluginProps<WireFormat<T>, Data, PluginFunctions>,
69+
) => {
70+
const { css, jsUrl, jsHash } = props.data;
6671

72+
// Decode wire format { state, bufferPaths, buffers } to state with DataViews
6773
const valueWithBuffers = useMemo(() => {
68-
return resolveInitialValue(props.value, bufferPaths ?? []);
69-
}, [props.value, bufferPaths]);
74+
if (isWireFormat(props.value)) {
75+
const decoded = decodeFromWire(props.value);
76+
Logger.debug("AnyWidget decoded wire format:", {
77+
bufferPaths: props.value.bufferPaths,
78+
buffersCount: props.value.buffers?.length,
79+
decodedKeys: Object.keys(decoded),
80+
});
81+
return decoded;
82+
}
83+
Logger.warn("AnyWidget value is not wire format:", props.value);
84+
return props.value;
85+
}, [props.value]);
7086

7187
// JS is an ESM file with a render function on it
7288
// export function render({ model, el }) {
@@ -135,6 +151,12 @@ const AnyWidgetSlot = (props: Props) => {
135151
};
136152
}, [css, props.host]);
137153

154+
// Wrap setValue to serialize DataViews back to base64 before sending
155+
// Structure matches ipywidgets protocol: { state, bufferPaths, buffers }
156+
const wrappedSetValue = useEvent((partialValue: Partial<T>) =>
157+
props.setValue(serializeBuffersToBase64(partialValue)),
158+
);
159+
138160
if (error) {
139161
return <ErrorBanner error={error} />;
140162
}
@@ -162,6 +184,7 @@ const AnyWidgetSlot = (props: Props) => {
162184
key={key}
163185
{...props}
164186
widget={module.default}
187+
setValue={wrappedSetValue}
165188
value={valueWithBuffers}
166189
/>
167190
);
@@ -191,10 +214,19 @@ async function runAnyWidgetModule(
191214
const widget =
192215
typeof widgetDef === "function" ? await widgetDef() : widgetDef;
193216
await widget.initialize?.({ model, experimental });
194-
const unsub = await widget.render?.({ model, el, experimental });
195-
return () => {
196-
unsub?.();
197-
};
217+
try {
218+
const unsub = await widget.render?.({ model, el, experimental });
219+
return () => {
220+
unsub?.();
221+
};
222+
} catch (error) {
223+
Logger.error("Error rendering anywidget", error);
224+
el.classList.add("text-error");
225+
el.innerHTML = `Error rendering anywidget: ${prettyError(error)}`;
226+
return () => {
227+
// No-op
228+
};
229+
}
198230
}
199231

200232
function isAnyWidgetModule(mod: any): mod is { default: AnyWidget } {
@@ -218,6 +250,13 @@ function hasModelId(message: unknown): message is { model_id: string } {
218250
);
219251
}
220252

253+
interface Props
254+
extends Omit<IPluginProps<T, Data, PluginFunctions>, "setValue"> {
255+
widget: AnyWidget;
256+
value: T;
257+
setValue: (value: Partial<T>) => void;
258+
}
259+
221260
const LoadedSlot = ({
222261
value,
223262
setValue,
@@ -228,15 +267,9 @@ const LoadedSlot = ({
228267
}: Props & { widget: AnyWidget }) => {
229268
const htmlRef = useRef<HTMLDivElement>(null);
230269

270+
// value is already decoded from wire format
231271
const model = useRef<Model<T>>(
232-
new Model(
233-
// Merge the initial value with the current value
234-
// since we only send partial updates to the backend
235-
{ ...data.initialValue, ...value },
236-
setValue,
237-
functions.send_to_widget,
238-
getDirtyFields(value, data.initialValue),
239-
),
272+
new Model(value, setValue, functions.send_to_widget, new Set()),
240273
);
241274

242275
// Listen to incoming messages
@@ -289,16 +322,3 @@ export const visibleForTesting = {
289322
isAnyWidgetModule,
290323
getDirtyFields,
291324
};
292-
293-
export function resolveInitialValue(
294-
raw: Record<string, any>,
295-
bufferPaths: readonly (readonly (string | number)[])[],
296-
) {
297-
const out = structuredClone(raw);
298-
for (const bufferPath of bufferPaths) {
299-
const base64String: Base64String = get(raw, bufferPath);
300-
const bytes = byteStringToBinary(typedAtob(base64String));
301-
set(out, bufferPath, new DataView(bytes.buffer));
302-
}
303-
return out;
304-
}

frontend/src/plugins/impl/anywidget/__tests__/AnyWidgetPlugin.test.tsx

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest";
55
import { TestUtils } from "@/__tests__/test-helpers";
66
import type { UIElementId } from "@/core/cells/ids";
77
import { MarimoIncomingMessageEvent } from "@/core/dom/events";
8-
import {
9-
getDirtyFields,
10-
resolveInitialValue,
11-
visibleForTesting,
12-
} from "../AnyWidgetPlugin";
8+
import { getDirtyFields, visibleForTesting } from "../AnyWidgetPlugin";
139
import { Model } from "../model";
1410

1511
const { LoadedSlot } = visibleForTesting;
@@ -132,6 +128,7 @@ describe("LoadedSlot", () => {
132128
message: {
133129
method: "update",
134130
state: { count: 10 },
131+
buffer_paths: [],
135132
},
136133
buffers: [],
137134
},
@@ -183,55 +180,3 @@ describe("LoadedSlot", () => {
183180
});
184181
});
185182
});
186-
187-
describe("resolveInitialValue", () => {
188-
it("should convert base64 strings to DataView at specified paths", () => {
189-
const result = resolveInitialValue(
190-
{
191-
a: 10,
192-
b: "aGVsbG8=", // "hello" in base64
193-
c: [1, "d29ybGQ="], // "world" in base64
194-
d: {
195-
foo: "bWFyaW1vCg==", // "marimo" in base64
196-
baz: 20,
197-
},
198-
},
199-
[["b"], ["c", 1], ["d", "foo"]],
200-
);
201-
202-
expect(result).toMatchInlineSnapshot(`
203-
{
204-
"a": 10,
205-
"b": DataView [
206-
104,
207-
101,
208-
108,
209-
108,
210-
111,
211-
],
212-
"c": [
213-
1,
214-
DataView [
215-
119,
216-
111,
217-
114,
218-
108,
219-
100,
220-
],
221-
],
222-
"d": {
223-
"baz": 20,
224-
"foo": DataView [
225-
109,
226-
97,
227-
114,
228-
105,
229-
109,
230-
111,
231-
10,
232-
],
233-
},
234-
}
235-
`);
236-
});
237-
});

frontend/src/plugins/impl/anywidget/__tests__/model.test.ts

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
vi,
1010
} from "vitest";
1111
import { TestUtils } from "@/__tests__/test-helpers";
12+
import type { Base64String } from "@/utils/json/base64";
1213
import {
1314
type AnyWidgetMessage,
1415
handleWidgetMessage,
@@ -23,7 +24,10 @@ describe("Model", () => {
2324
// eslint-disable-next-line @typescript-eslint/no-explicit-any
2425
let onChange: (value: any) => void;
2526
// eslint-disable-next-line @typescript-eslint/no-explicit-any
26-
let sendToWidget: (req: { content?: any }) => Promise<null | undefined>;
27+
let sendToWidget: (req: {
28+
content: unknown;
29+
buffers: Base64String[];
30+
}) => Promise<null | undefined>;
2731

2832
beforeEach(() => {
2933
onChange = vi.fn();
@@ -72,7 +76,7 @@ describe("Model", () => {
7276
});
7377
});
7478

75-
it("should send all dirty fields", () => {
79+
it("should clear dirty fields after save", () => {
7680
model.set("foo", "new value");
7781
model.save_changes();
7882

@@ -83,14 +87,13 @@ describe("Model", () => {
8387
model.set("bar", 456);
8488
model.save_changes();
8589

90+
// After clearing, only the newly changed field is sent
8691
expect(onChange).toHaveBeenCalledWith({
87-
foo: "new value",
8892
bar: 456,
8993
});
9094
});
9195

92-
// Skip because we don't clear the dirty fields after save
93-
it.skip("should clear dirty fields after save", () => {
96+
it("should not call onChange when no dirty fields", () => {
9497
model.set("foo", "new value");
9598
model.save_changes();
9699
model.save_changes(); // Second save should not call onChange
@@ -144,21 +147,16 @@ describe("Model", () => {
144147
const callback = vi.fn();
145148
model.send({ test: true }, callback);
146149

147-
expect(sendToWidget).toHaveBeenCalledWith({ content: { test: true } });
150+
expect(sendToWidget).toHaveBeenCalledWith({
151+
content: {
152+
state: { test: true },
153+
bufferPaths: [],
154+
},
155+
buffers: [],
156+
});
148157
await TestUtils.nextTick(); // flush
149158
expect(callback).toHaveBeenCalledWith(null);
150159
});
151-
152-
it("should warn when buffers are provided", () => {
153-
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {
154-
// noop
155-
});
156-
model.send({ test: true }, null, [new ArrayBuffer(8)]);
157-
158-
expect(consoleSpy).toHaveBeenCalledWith(
159-
"buffers not supported in marimo anywidget.send",
160-
);
161-
});
162160
});
163161

164162
describe("widget_manager", () => {
@@ -228,7 +226,11 @@ describe("Model", () => {
228226
it("should handle update messages", () => {
229227
model.receiveCustomMessage({
230228
method: "update",
231-
state: { foo: "updated", bar: 789 },
229+
state: {
230+
foo: "updated",
231+
bar: 789,
232+
},
233+
buffer_paths: [],
232234
});
233235

234236
expect(model.get("foo")).toBe("updated");
@@ -333,7 +335,9 @@ describe("ModelManager", () => {
333335

334336
const updateMessage: AnyWidgetMessage = {
335337
method: "update",
336-
state: { count: 1 },
338+
state: {
339+
count: 1,
340+
},
337341
buffer_paths: [],
338342
};
339343

0 commit comments

Comments
 (0)