Skip to content

Commit c4107bf

Browse files
authored
add tests to cover connecting to durable object via websocket using the hibernatable apis (#495)
1 parent b7068e2 commit c4107bf

File tree

6 files changed

+159
-2
lines changed

6 files changed

+159
-2
lines changed

‎Cargo.lock‎

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎worker-sandbox/Cargo.toml‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ rand = "0.8.5"
3737
uuid = { version = "1.3.3", features = ["v4", "serde"] }
3838
serde-wasm-bindgen = "0.6.1"
3939
md5 = "0.7.0"
40+
tokio-stream = "0.1.15"
4041

4142
[dependencies.axum]
4243
version = "0.7"

‎worker-sandbox/src/alarm.rs‎

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::time::Duration;
2+
use tokio_stream::{StreamExt, StreamMap};
23

34
use worker::*;
45

@@ -69,3 +70,60 @@ pub async fn handle_put_raw(req: Request, env: Env, _data: SomeSharedData) -> Re
6970
let stub = id.get_stub()?;
7071
stub.fetch_with_request(req).await
7172
}
73+
74+
#[worker::send]
75+
pub async fn handle_websocket(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
76+
// Accept / handle a websocket connection
77+
let pair = WebSocketPair::new()?;
78+
let server = pair.server;
79+
server.accept()?;
80+
81+
// Connect to Durable Object via WS
82+
let namespace = env
83+
.durable_object("COUNTER")
84+
.expect("failed to get namespace");
85+
let stub = namespace.id_from_name("A")?.get_stub()?;
86+
let mut req = Request::new("https://fake-host/ws", Method::Get)?;
87+
req.headers_mut()?.set("upgrade", "websocket")?;
88+
89+
let res = stub.fetch_with_request(req).await?;
90+
let do_ws = res.websocket().expect("server did not accept websocket");
91+
do_ws.accept()?;
92+
93+
wasm_bindgen_futures::spawn_local(async move {
94+
let event_stream = server.events().expect("could not open stream");
95+
let do_event_stream = do_ws.events().expect("could not open stream");
96+
97+
let mut map = StreamMap::new();
98+
map.insert("client", event_stream);
99+
map.insert("durable", do_event_stream);
100+
101+
while let Some((key, event)) = map.next().await {
102+
match key {
103+
"client" => match event.expect("received error in websocket") {
104+
WebsocketEvent::Message(msg) => {
105+
if let Some(text) = msg.text() {
106+
do_ws.send_with_str(text).expect("could not relay text");
107+
}
108+
}
109+
WebsocketEvent::Close(_) => {
110+
let _res = do_ws.close(Some(1000), Some("client closed".to_string()));
111+
}
112+
},
113+
"durable" => match event.expect("received error in websocket") {
114+
WebsocketEvent::Message(msg) => {
115+
if let Some(text) = msg.text() {
116+
server.send_with_str(text).expect("could not relay text");
117+
}
118+
}
119+
WebsocketEvent::Close(_) => {
120+
let _res = server.close(Some(1000), Some("durable closed".to_string()));
121+
}
122+
},
123+
_ => unreachable!(),
124+
}
125+
}
126+
});
127+
128+
Response::from_websocket(pair.client)
129+
}

‎worker-sandbox/src/counter.rs‎

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,27 @@ impl DurableObject for Counter {
1919
}
2020
}
2121

22-
async fn fetch(&mut self, _req: Request) -> Result<Response> {
22+
async fn fetch(&mut self, req: Request) -> Result<Response> {
2323
if !self.initialized {
2424
self.initialized = true;
2525
self.count = self.state.storage().get("count").await.unwrap_or(0);
2626
}
2727

28+
if req.path().eq("/ws") {
29+
let pair = WebSocketPair::new()?;
30+
let server = pair.server;
31+
// accept websocket with hibernation api
32+
self.state.accept_web_socket(&server);
33+
server
34+
.serialize_attachment("hello")
35+
.expect("failed to serialize attachment");
36+
37+
return Ok(Response::empty()
38+
.unwrap()
39+
.with_status(101)
40+
.with_websocket(Some(pair.client)));
41+
}
42+
2843
self.count += 10;
2944
self.state.storage().put("count", self.count).await?;
3045

@@ -34,4 +49,36 @@ impl DurableObject for Counter {
3449
self.env.secret("SOME_SECRET")?.to_string()
3550
))
3651
}
52+
53+
async fn websocket_message(
54+
&mut self,
55+
ws: WebSocket,
56+
_message: WebSocketIncomingMessage,
57+
) -> Result<()> {
58+
let _attach: String = ws
59+
.deserialize_attachment()?
60+
.expect("websockets should have an attachment");
61+
// get and increment storage by 10
62+
let mut count: usize = self.state.storage().get("count").await.unwrap_or(0);
63+
count += 10;
64+
self.state.storage().put("count", count).await?;
65+
// send value to client
66+
ws.send_with_str(format!("{}", count))
67+
.expect("failed to send value to client");
68+
Ok(())
69+
}
70+
71+
async fn websocket_close(
72+
&mut self,
73+
_ws: WebSocket,
74+
_code: usize,
75+
_reason: String,
76+
_was_clean: bool,
77+
) -> Result<()> {
78+
Ok(())
79+
}
80+
81+
async fn websocket_error(&mut self, _ws: WebSocket, _error: Error) -> Result<()> {
82+
Ok(())
83+
}
3784
}

‎worker-sandbox/src/router.rs‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ pub fn make_router(data: SomeSharedData, env: Env) -> axum::Router {
107107
.route("/durable/alarm", get(handler!(alarm::handle_alarm)))
108108
.route("/durable/:id", get(handler!(alarm::handle_id)))
109109
.route("/durable/put-raw", get(handler!(alarm::handle_put_raw)))
110+
.route("/durable/websocket", get(handler!(alarm::handle_websocket)))
110111
.route("/var", get(handler!(request::handle_var)))
111112
.route("/secret", get(handler!(request::handle_secret)))
112113
.route("/kv/:key/:value", post(handler!(kv::handle_post_key_value)))
@@ -241,6 +242,7 @@ pub fn make_router<'a>(data: SomeSharedData) -> Router<'a, SomeSharedData> {
241242
.get_async("/durable/alarm", handler!(alarm::handle_alarm))
242243
.get_async("/durable/:id", handler!(alarm::handle_id))
243244
.get_async("/durable/put-raw", handler!(alarm::handle_put_raw))
245+
.get_async("/durable/websocket", handler!(alarm::handle_websocket))
244246
.get_async("/secret", handler!(request::handle_secret))
245247
.get_async("/var", handler!(request::handle_var))
246248
.post_async("/kv/:key/:value", handler!(kv::handle_post_key_value))
Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,46 @@
1-
import { describe, test, expect } from "vitest";
1+
import {describe, test, expect, vi} from "vitest";
22
import { mf } from "./mf";
3+
import {MessageEvent} from "miniflare";
34

45
describe("durable", () => {
56
test("put-raw", async () => {
67
const resp = await mf.dispatchFetch("https://fake.host/durable/put-raw");
78
expect(await resp.text()).toBe("ok");
89
});
10+
11+
test("websocket-to-durable", async () => {
12+
const resp = await mf.dispatchFetch("http://fake.host/durable/websocket", {
13+
headers: {
14+
upgrade: "websocket",
15+
},
16+
});
17+
expect(resp.webSocket).not.toBeNull();
18+
19+
const socket = resp.webSocket!;
20+
socket.accept();
21+
22+
const handlers = {
23+
messageHandler: (event: MessageEvent) => {
24+
expect(event.data).toMatch(/^10|20|30$/);
25+
},
26+
close(event: CloseEvent) {},
27+
};
28+
29+
const messageHandlerWrapper = vi.spyOn(handlers, "messageHandler");
30+
const closeHandlerWrapper = vi.spyOn(handlers, "messageHandler");
31+
socket.addEventListener("message", handlers.messageHandler);
32+
socket.addEventListener("close", handlers.close);
33+
34+
socket.send("hi, can you ++?");
35+
await new Promise((resolve) => setTimeout(resolve, 500));
36+
expect(messageHandlerWrapper).toHaveBeenCalledTimes(1);
37+
38+
socket.send("hi again, more ++?");
39+
await new Promise((resolve) => setTimeout(resolve, 500));
40+
expect(messageHandlerWrapper).toHaveBeenCalledTimes(2);
41+
42+
socket.close();
43+
expect(closeHandlerWrapper).toBeCalled();
44+
});
945
});
46+

0 commit comments

Comments
 (0)