Skip to content

Commit 0f6bc4e

Browse files
authored
feat (ai/core): add embed function (#1575)
1 parent 1009594 commit 0f6bc4e

26 files changed

+963
-22
lines changed

‎.changeset/witty-beds-sell.md‎

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
---
'@ai-sdk/provider': patch
'@ai-sdk/mistral': patch
'@ai-sdk/openai': patch
'ai': patch
---

feat (ai/core): add embed function
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
---
title: Embeddings
description: Learn how to embed values with the Vercel AI SDK.
---

# Embeddings

Embeddings are a way to represent words, phrases, or images as vectors in a high-dimensional space.
In this space, similar words are close to each other, and the distance between words can be used to measure their similarity.

## Embedding a Single Value

The Vercel AI SDK provides the `embed` function to embed single values, which is useful for tasks such as finding similar words
or phrases or clustering text. You can use it with embeddings models, e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`.

```tsx
import { embed } from 'ai';
import { openai } from '@ai-sdk/openai';

const { embedding } = await embed({
  model: openai.embedding('text-embedding-3-small'),
  value: 'sunny day at the beach',
});
```
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { mistral } from '@ai-sdk/mistral';
2+
import { embed } from 'ai';
3+
import dotenv from 'dotenv';
4+
5+
dotenv.config();
6+
7+
async function main() {
8+
const { embedding } = await embed({
9+
model: mistral.embedding('mistral-embed'),
10+
value: 'sunny day at the beach',
11+
});
12+
13+
console.log(embedding);
14+
}
15+
16+
main().catch(console.error);
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { embed } from 'ai';
3+
import dotenv from 'dotenv';
4+
5+
dotenv.config();
6+
7+
async function main() {
8+
const { embedding } = await embed({
9+
model: openai.embedding('text-embedding-3-small'),
10+
value: 'sunny day at the beach',
11+
});
12+
13+
console.log(embedding);
14+
}
15+
16+
main().catch(console.error);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import assert from 'node:assert';
2+
import { MockEmbeddingModelV1 } from '../test/mock-embedding-model-v1';
3+
import { embed } from './embed';
4+
5+
const dummyEmbedding = [0.1, 0.2, 0.3];
6+
const testValue = 'sunny day at the beach';
7+
8+
describe('result.embedding', () => {
9+
it('should generate embedding', async () => {
10+
const result = await embed({
11+
model: new MockEmbeddingModelV1({
12+
doEmbed: async ({ values }) => {
13+
assert.deepStrictEqual(values, [testValue]);
14+
15+
return {
16+
embeddings: [dummyEmbedding],
17+
};
18+
},
19+
}),
20+
value: testValue,
21+
});
22+
23+
assert.deepStrictEqual(result.embedding, dummyEmbedding);
24+
});
25+
});

‎packages/core/core/embed/embed.ts‎

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import { Embedding, EmbeddingModel } from '../types';
2+
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
3+
4+
/**
5+
Embed a value using an embedding model. The type of the value is defined by the embedding model.
6+
7+
@param model - The embedding model to use.
8+
@param value - The value that should be embedded.
9+
10+
@param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2.
11+
@param abortSignal - An optional abort signal that can be used to cancel the call.
12+
13+
@returns A result object that contains the embedding, the value, and additional information.
14+
*/
15+
export async function embed<VALUE>({
16+
model,
17+
value,
18+
maxRetries,
19+
abortSignal,
20+
}: {
21+
/**
22+
The embedding model to use.
23+
*/
24+
model: EmbeddingModel<VALUE>;
25+
26+
/**
27+
The value that should be embedded.
28+
*/
29+
value: VALUE;
30+
31+
/**
32+
Maximum number of retries per embedding model call. Set to 0 to disable retries.
33+
34+
@default 2
35+
*/
36+
maxRetries?: number;
37+
38+
/**
39+
Abort signal.
40+
*/
41+
abortSignal?: AbortSignal;
42+
}): Promise<EmbedResult<VALUE>> {
43+
const retry = retryWithExponentialBackoff({ maxRetries });
44+
45+
const modelResponse = await retry(() =>
46+
model.doEmbed({
47+
values: [value],
48+
abortSignal,
49+
}),
50+
);
51+
52+
return new EmbedResult({
53+
value,
54+
embedding: modelResponse.embeddings[0],
55+
rawResponse: modelResponse.rawResponse,
56+
});
57+
}
58+
59+
/**
60+
The result of a `embed` call.
61+
It contains the embedding, the value, and additional information.
62+
*/
63+
export class EmbedResult<VALUE> {
64+
/**
65+
The value that was embedded.
66+
*/
67+
readonly value: VALUE;
68+
69+
/**
70+
The embedding of the value.
71+
*/
72+
readonly embedding: Embedding;
73+
74+
/**
75+
Optional raw response data.
76+
*/
77+
readonly rawResponse?: {
78+
/**
79+
Response headers.
80+
*/
81+
headers?: Record<string, string>;
82+
};
83+
84+
constructor(options: {
85+
value: VALUE;
86+
embedding: Embedding;
87+
rawResponse?: {
88+
headers?: Record<string, string>;
89+
};
90+
}) {
91+
this.value = options.value;
92+
this.embedding = options.embedding;
93+
this.rawResponse = options.rawResponse;
94+
}
95+
}

‎packages/core/core/embed/index.ts‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
// Barrel file: re-exports the `embed` function and `EmbedResult` class
// from the sibling embed module.
export * from './embed';

‎packages/core/core/index.ts‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
export * from './embed';
12
export * from './generate-object';
23
export * from './generate-text';
34
export * from './prompt';
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import { EmbeddingModelV1 } from '@ai-sdk/provider';
2+
3+
export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
4+
readonly specificationVersion = 'v1';
5+
6+
readonly provider: EmbeddingModelV1<VALUE>['provider'];
7+
readonly modelId: EmbeddingModelV1<VALUE>['modelId'];
8+
readonly maxEmbeddingsPerCall: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
9+
readonly supportsParallelCalls: EmbeddingModelV1<VALUE>['supportsParallelCalls'];
10+
11+
doEmbed: EmbeddingModelV1<VALUE>['doEmbed'];
12+
13+
constructor({
14+
provider = 'mock-provider',
15+
modelId = 'mock-model-id',
16+
maxEmbeddingsPerCall = 1,
17+
supportsParallelCalls = false,
18+
doEmbed = notImplemented,
19+
}: {
20+
provider?: EmbeddingModelV1<VALUE>['provider'];
21+
modelId?: EmbeddingModelV1<VALUE>['modelId'];
22+
maxEmbeddingsPerCall?: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
23+
supportsParallelCalls?: EmbeddingModelV1<VALUE>['supportsParallelCalls'];
24+
doEmbed?: EmbeddingModelV1<VALUE>['doEmbed'];
25+
}) {
26+
this.provider = provider;
27+
this.modelId = modelId;
28+
this.maxEmbeddingsPerCall = maxEmbeddingsPerCall;
29+
this.supportsParallelCalls = supportsParallelCalls;
30+
this.doEmbed = doEmbed;
31+
}
32+
}
33+
34+
function notImplemented(): never {
35+
throw new Error('Not implemented');
36+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import { EmbeddingModelV1, EmbeddingModelV1Embedding } from '@ai-sdk/provider';
2+
3+
/**
Embedding model that is used by the AI SDK Core functions.

Currently a direct alias for the `EmbeddingModelV1` provider interface.
*/
export type EmbeddingModel<VALUE> = EmbeddingModelV1<VALUE>;

/**
Embedding.

A vector representation of an embedded value, aliased from the
provider-level `EmbeddingModelV1Embedding` type.
*/
export type Embedding = EmbeddingModelV1Embedding;

0 commit comments

Comments
 (0)