
Commit db61c53

feat (ai/core): middleware support (#2759)
1 parent 7ee8d32 commit db61c53

23 files changed (+886 −7 lines)

.changeset/many-yaks-relate.md

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): middleware support

content/docs/03-ai-sdk-core/40-provider-management.mdx

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@ description: Learn how to work with multiple providers

 # Provider Management

-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>

 When you work with multiple providers and models, it is often desirable to manage them in a central place
 and access the models through simple string ids.
Lines changed: 209 additions & 0 deletions

@@ -0,0 +1,209 @@
---
title: Language Model Middleware
description: Learn how to use middleware to enhance the behavior of language models
---

# Language Model Middleware

<Note type="warning">
  Language model middleware is an experimental feature.
</Note>

Language model middleware is a way to enhance the behavior of language models
by intercepting and modifying the calls to the language model.

It can be used to add features like guardrails, RAG, caching, and logging
in a language-model-agnostic way. Such middleware can be developed and
distributed independently from the language models that it is applied to.

## Using Language Model Middleware

You can use language model middleware with the `wrapLanguageModel` function.
It takes a language model and a language model middleware and returns a new
language model that incorporates the middleware.

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
});
```

The wrapped language model can be used just like any other language model, e.g. in `streamText`:

```ts highlight="2"
const result = await streamText({
  model: wrappedLanguageModel,
  prompt: 'What cities are in the United States?',
});
```

## Implementing Language Model Middleware

<Note>
  Implementing language model middleware is advanced functionality and requires
  a solid understanding of the [language model
  specification](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
</Note>

You can implement any of the following three functions to modify the behavior of the language model:

1. `transformParams`: Transforms the parameters before they are passed to the language model, for both `doGenerate` and `doStream`.
2. `wrapGenerate`: Wraps the `doGenerate` method of the [language model](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
   You can modify the parameters, call the language model, and modify the result.
3. `wrapStream`: Wraps the `doStream` method of the [language model](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
   You can modify the parameters, call the language model, and modify the result.
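As a sketch of the shape of these hooks, here is a pass-through middleware that delegates everything unchanged. The types below are simplified local stand-ins for illustration, not the SDK's actual `Experimental_LanguageModelV1Middleware` interface:

```typescript
// Simplified stand-ins for the SDK types, for illustration only.
type Params = Record<string, unknown>;
type GenerateResult = { text?: string };

interface MiddlewareSketch {
  transformParams?: (options: { params: Params }) => Promise<Params>;
  wrapGenerate?: (options: {
    doGenerate: () => Promise<GenerateResult>;
    params: Params;
  }) => Promise<GenerateResult>;
}

// A pass-through middleware: parameters and results flow through unchanged.
// Real middleware would modify params in transformParams, or inspect and
// alter the result in wrapGenerate.
export const passThroughMiddleware: MiddlewareSketch = {
  transformParams: async ({ params }) => params,
  wrapGenerate: async ({ doGenerate }) => doGenerate(),
};
```

Because every hook is optional, a middleware only needs to implement the hooks it actually uses.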

Here are some examples of how to implement language model middleware:

## Examples

<Note>
  These examples are not meant to be used in production. They are just to show
  how you can use middleware to enhance the behavior of language models.
</Note>

### Logging

This example shows how to log the parameters and generated text of a language model call.

```ts
import type {
  Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
  LanguageModelV1StreamPart,
} from 'ai';

export const yourLogMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('doGenerate called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const result = await doGenerate();

    console.log('doGenerate finished');
    console.log(`generated text: ${result.text}`);

    return result;
  },

  wrapStream: async ({ doStream, params }) => {
    console.log('doStream called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const { stream, ...rest } = await doStream();

    let generatedText = '';

    const transformStream = new TransformStream<
      LanguageModelV1StreamPart,
      LanguageModelV1StreamPart
    >({
      transform(chunk, controller) {
        if (chunk.type === 'text-delta') {
          generatedText += chunk.textDelta;
        }

        controller.enqueue(chunk);
      },

      flush() {
        console.log('doStream finished');
        console.log(`generated text: ${generatedText}`);
      },
    });

    return {
      stream: stream.pipeThrough(transformStream),
      ...rest,
    };
  },
};
```

### Caching

This example shows how to build a simple cache for the generated text of a language model call.

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

const cache = new Map<string, any>();

export const yourCacheMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    const cacheKey = JSON.stringify(params);

    if (cache.has(cacheKey)) {
      return cache.get(cacheKey);
    }

    const result = await doGenerate();

    cache.set(cacheKey, result);

    return result;
  },

  // here you would implement the caching logic for streaming
};
```
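For the streaming case, one possible approach is to record the stream parts on a cache miss and replay them on a hit. The sketch below uses simplified stand-in types rather than the SDK's `LanguageModelV1StreamPart`:

```typescript
// Simplified stand-in for the SDK's stream part type, for illustration.
type StreamPart = { type: string; textDelta?: string };
type StreamResult = { stream: ReadableStream<StreamPart> };

const streamCache = new Map<string, StreamPart[]>();

// Wraps a doStream-like function: replays recorded parts on a cache hit,
// otherwise forwards the stream while recording its parts.
export async function cachedDoStream(
  cacheKey: string,
  doStream: () => Promise<StreamResult>,
): Promise<StreamResult> {
  const cached = streamCache.get(cacheKey);

  if (cached != null) {
    // Replay the recorded parts as a fresh stream.
    return {
      stream: new ReadableStream<StreamPart>({
        start(controller) {
          for (const part of cached) controller.enqueue(part);
          controller.close();
        },
      }),
    };
  }

  const { stream, ...rest } = await doStream();
  const recorded: StreamPart[] = [];

  const recordingStream = new TransformStream<StreamPart, StreamPart>({
    transform(chunk, controller) {
      recorded.push(chunk); // keep a copy...
      controller.enqueue(chunk); // ...while forwarding the chunk downstream
    },
    flush() {
      // only cache once the stream has finished successfully
      streamCache.set(cacheKey, recorded);
    },
  });

  return { stream: stream.pipeThrough(recordingStream), ...rest };
}
```

As in the non-streaming example, `JSON.stringify(params)` could serve as the cache key inside a `wrapStream` hook.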

### Retrieval Augmented Generation (RAG)

This example shows how to use RAG as middleware.

<Note>
  Helper functions like `getLastUserMessageText` and `findSources` are not part
  of the AI SDK. They are just used in this example to illustrate the concept of
  RAG.
</Note>

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourRagMiddleware: LanguageModelV1Middleware = {
  transformParams: async ({ params }) => {
    const lastUserMessageText = getLastUserMessageText({
      prompt: params.prompt,
    });

    if (lastUserMessageText == null) {
      return params; // do not use RAG (send unmodified parameters)
    }

    const instruction =
      'Use the following information to answer the question:\n' +
      findSources({ text: lastUserMessageText })
        .map(chunk => JSON.stringify(chunk))
        .join('\n');

    return addToLastUserMessage({ params, text: instruction });
  },
};
```
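For reference, a hypothetical `getLastUserMessageText` could look like the sketch below. The message types are simplified; the real prompt format is defined by the language model specification:

```typescript
// Simplified message types for illustration; the real prompt format is
// defined by the language model specification.
type TextPart = { type: 'text'; text: string };
type MessagePart = TextPart | { type: string };
type Message = { role: string; content: MessagePart[] };

// Returns the concatenated text parts of the last message if it is a
// user message, and undefined otherwise.
export function getLastUserMessageText({
  prompt,
}: {
  prompt: Message[];
}): string | undefined {
  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return undefined;
  }

  return lastMessage.content
    .filter((part): part is TextPart => part.type === 'text')
    .map(part => part.text)
    .join('');
}
```

`findSources` would then query a vector store or search index with this text; its implementation depends entirely on your retrieval setup.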

### Guardrails

Guardrails are a way to ensure that the generated text of a language model call
is safe and appropriate. This example shows how to use guardrails as middleware.

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourGuardrailMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate }) => {
    const { text, ...rest } = await doGenerate();

    // filtering approach, e.g. for PII or other sensitive information:
    const cleanedText = text?.replace(/badword/g, '<REDACTED>');

    return { text: cleanedText, ...rest };
  },

  // here you would implement the guardrail logic for streaming
  // Note: streaming guardrails are difficult to implement, because
  // you do not know the full content of the stream until it's finished.
};
```
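One possible streaming approach (again with simplified stand-in types) is to redact within each text delta as it passes through. Note the caveat from the comment above: this naive per-chunk filter misses words that are split across two chunks:

```typescript
// Simplified stand-in for the SDK's stream part type, for illustration.
type StreamPart = { type: string; textDelta?: string };

// Redacts a blocked word inside each text delta as it streams through.
// Caveat: a word split across two chunks slips past this per-chunk filter;
// a production guardrail would need to buffer across chunk boundaries.
export function redactingTransform(): TransformStream<StreamPart, StreamPart> {
  return new TransformStream({
    transform(chunk, controller) {
      if (chunk.type === 'text-delta' && chunk.textDelta !== undefined) {
        controller.enqueue({
          ...chunk,
          textDelta: chunk.textDelta.replace(/badword/g, '<REDACTED>'),
        });
      } else {
        controller.enqueue(chunk);
      }
    },
  });
}
```

Inside a `wrapStream` hook, you would pipe the stream returned by `doStream` through this transform, mirroring the structure of the logging example above.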

content/docs/07-reference/ai-sdk-core/40-provider-registry.mdx

Lines changed: 3 additions & 3 deletions

@@ -1,11 +1,11 @@
 ---
-title: experimental_createProviderRegistry
+title: createProviderRegistry
 description: Registry for managing multiple providers and models (API Reference)
 ---

-# `experimental_createProviderRegistry()`
+# `createProviderRegistry()`

-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>

 When you work with multiple providers and models, it is often desirable to manage them
 in a central place and access the models through simple string ids.

content/docs/07-reference/ai-sdk-core/42-custom-provider.mdx

Lines changed: 3 additions & 3 deletions

@@ -1,11 +1,11 @@
 ---
-title: experimental_customProvider
+title: customProvider
 description: Custom provider that uses models from a different provider (API Reference)
 ---

-# `experimental_customProvider()`
+# `customProvider()`

-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>

 With a custom provider, you can map ids to any model.
 This allows you to set up custom model configurations, alias names, and more.
Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
---
title: wrapLanguageModel
description: Function for wrapping a language model with middleware (API Reference)
---

# `wrapLanguageModel()`

<Note type="warning">
  Language model middleware is an experimental feature.
</Note>

The `experimental_wrapLanguageModel` function provides a way to enhance the behavior
of language models by wrapping them with middleware.
See [Language Model Middleware](/docs/ai-sdk-core/middleware) for more information on middleware.

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
});
```

## Import

<Snippet
  text={`import { experimental_wrapLanguageModel as wrapLanguageModel } from "ai"`}
  prompt={false}
/>

## API Signature

### Parameters

<PropertiesTable
  content={[
    {
      name: 'model',
      type: 'LanguageModelV1',
      description: 'The original LanguageModelV1 instance to be wrapped.',
    },
    {
      name: 'middleware',
      type: 'Experimental_LanguageModelV1Middleware',
      description: 'The middleware to be applied to the language model.',
    },
    {
      name: 'modelId',
      type: 'string',
      description:
        "Optional custom model ID to override the original model's ID.",
    },
    {
      name: 'providerId',
      type: 'string',
      description:
        "Optional custom provider ID to override the original model's provider.",
    },
  ]}
/>

### Returns

A new `LanguageModelV1` instance with middleware applied.
Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
---
title: LanguageModelV1Middleware
description: Middleware for enhancing language model behavior (API Reference)
---

# `LanguageModelV1Middleware`

<Note type="warning">
  Language model middleware is an experimental feature.
</Note>

Language model middleware provides a way to enhance the behavior of language models
by intercepting and modifying the calls to the language model. It can be used to add
features like guardrails, RAG, caching, and logging in a language model agnostic way.

See [Language Model Middleware](/docs/ai-sdk-core/middleware) for more information.

## Import

<Snippet
  text={`import { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from "ai"`}
  prompt={false}
/>

## API Signature

<PropertiesTable
  content={[
    {
      name: 'transformParams',
      type: '({ type: "generate" | "stream", params: LanguageModelV1CallOptions }) => Promise<LanguageModelV1CallOptions>',
      description:
        'Transforms the parameters before they are passed to the language model.',
    },
    {
      name: 'wrapGenerate',
      type: '({ doGenerate: DoGenerateFunction, params: LanguageModelV1CallOptions, model: LanguageModelV1 }) => Promise<DoGenerateResult>',
      description: 'Wraps the generate operation of the language model.',
    },
    {
      name: 'wrapStream',
      type: '({ doStream: DoStreamFunction, params: LanguageModelV1CallOptions, model: LanguageModelV1 }) => Promise<DoStreamResult>',
      description: 'Wraps the stream operation of the language model.',
    },
  ]}
/>
Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
```ts
import { LanguageModelV1CallOptions } from 'ai';

export function addToLastUserMessage({
  text,
  params,
}: {
  text: string;
  params: LanguageModelV1CallOptions;
}): LanguageModelV1CallOptions {
  const { prompt, ...rest } = params;

  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return params;
  }

  return {
    ...rest,
    prompt: [
      ...prompt.slice(0, -1),
      {
        ...lastMessage,
        content: [{ type: 'text', text }, ...lastMessage.content],
      },
    ],
  };
}
```
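To illustrate what this helper does, here is a self-contained version over simplified types (the real `LanguageModelV1CallOptions` carries more fields): the new text part is prepended to the content of the last user message, and non-user messages are left untouched.

```typescript
// Simplified stand-ins for the SDK types, for illustration only.
type TextPart = { type: 'text'; text: string };
type Message = { role: string; content: TextPart[] };
type CallOptions = { prompt: Message[] };

// Same logic as the helper above, over the simplified types: prepend a
// text part to the last message if (and only if) it is a user message.
export function addToLastUserMessageSketch({
  text,
  params,
}: {
  text: string;
  params: CallOptions;
}): CallOptions {
  const { prompt, ...rest } = params;
  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return params;
  }

  return {
    ...rest,
    prompt: [
      ...prompt.slice(0, -1),
      {
        ...lastMessage,
        content: [{ type: 'text', text }, ...lastMessage.content],
      },
    ],
  };
}
```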
