Skip to content

Commit 0e1da47

Browse files
authored
feat (core): add maxAutomaticRoundtrips setting to generateText (#1750)
1 parent 7f9fa63 commit 0e1da47

File tree

8 files changed

+485
-97
lines changed

8 files changed

+485
-97
lines changed

‎.changeset/sour-eagles-tease.md‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
feat (core): add maxAutomaticRoundtrips setting to generateText

‎content/docs/07-reference/ai-sdk-core/01-generate-text.mdx‎

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ const result = await generateText({
2121
});
2222
```
2323

24-
## Parameters
24+
## API Signature
25+
26+
### Parameters
2527

2628
<PropertiesTable
2729
content={[
@@ -317,10 +319,17 @@ const result = await generateText({
317319
description:
318320
'An optional abort signal that can be used to cancel the call.',
319321
},
322+
{
323+
name: 'maxAutomaticRoundtrips',
324+
type: 'number',
325+
isOptional: true,
326+
description:
327+
'Maximum number of automatic roundtrips for tool calls. An automatic tool call roundtrip is another LLM call with the tool call results when all tool calls of the last assistant message have results. A maximum number is required to prevent infinite loops in the case of misconfigured tools. By default, it is set to 0, which will disable the feature.',
328+
},
320329
]}
321330
/>
322331

323-
## Result Object
332+
### Returns
324333

325334
<PropertiesTable
326335
content={[

‎examples/ai-core/package.json‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"@ai-sdk/openai": "latest",
1111
"ai": "latest",
1212
"dotenv": "16.4.5",
13+
"mathjs": "12.4.2",
1314
"zod": "3.23.8",
1415
"zod-to-json-schema": "3.22.4"
1516
},
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { generateText, tool } from 'ai';
3+
import dotenv from 'dotenv';
4+
import * as mathjs from 'mathjs';
5+
import { z } from 'zod';
6+
7+
dotenv.config();
8+
9+
const problem =
10+
'A taxi driver earns $9461 per 1-hour work. ' +
11+
'If he works 12 hours a day and in 1 hour he uses 12-liters petrol with price $134 for 1-liter. ' +
12+
'How much money does he earn in one day?';
13+
14+
async function main() {
15+
console.log(`PROBLEM: ${problem}\n`);
16+
17+
await generateText({
18+
model: openai('gpt-4-turbo'),
19+
system:
20+
'You are solving math problems. ' +
21+
'Reason step by step. ' +
22+
'Use the calculator when necessary. ' +
23+
'The calculator can only do simple additions, subtractions, multiplications, and divisions. ' +
24+
'When you give the final answer, provide an explanation for how you got it.',
25+
prompt: problem,
26+
tools: {
27+
calculate: tool({
28+
description:
29+
'A tool for evaluating mathematical expressions. Example expressions: ' +
30+
"'1.2 * (2 + 4.5)', '12.7 cm to inch', 'sin(45 deg) ^ 2'.",
31+
parameters: z.object({ expression: z.string() }),
32+
execute: async ({ expression }) => mathjs.evaluate(expression),
33+
}),
34+
answer: tool({
35+
description: 'A tool for providing the final answer.',
36+
parameters: z.object({ answer: z.string() }),
37+
execute: async ({ answer }) => {
38+
console.log(`ANSWER: ${answer}`);
39+
process.exit(0);
40+
},
41+
}),
42+
},
43+
toolChoice: 'required',
44+
maxAutomaticRoundtrips: 10,
45+
});
46+
}
47+
48+
main().catch(console.error);

‎packages/core/core/generate-text/generate-text.test.ts‎

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,4 +268,211 @@ describe('result.responseMessages', () => {
268268
},
269269
]);
270270
});
271+
272+
it('should contain assistant response message and tool message from all roundtrips', async () => {
273+
let responseCount = 0;
274+
const result = await generateText({
275+
model: new MockLanguageModelV1({
276+
doGenerate: async ({ prompt, mode }) => {
277+
switch (responseCount++) {
278+
case 0:
279+
return {
280+
...dummyResponseValues,
281+
toolCalls: [
282+
{
283+
toolCallType: 'function',
284+
toolCallId: 'call-1',
285+
toolName: 'tool1',
286+
args: `{ "value": "value" }`,
287+
},
288+
],
289+
toolResults: [
290+
{
291+
toolCallId: 'call-1',
292+
toolName: 'tool1',
293+
args: { value: 'value' },
294+
result: 'result1',
295+
},
296+
],
297+
};
298+
case 1:
299+
return {
300+
...dummyResponseValues,
301+
text: 'Hello, world!',
302+
};
303+
default:
304+
throw new Error(`Unexpected response count: ${responseCount}`);
305+
}
306+
},
307+
}),
308+
tools: {
309+
tool1: {
310+
parameters: z.object({ value: z.string() }),
311+
execute: async args => {
312+
assert.deepStrictEqual(args, { value: 'value' });
313+
return 'result1';
314+
},
315+
},
316+
},
317+
prompt: 'test-input',
318+
maxAutomaticRoundtrips: 2,
319+
});
320+
321+
assert.deepStrictEqual(result.responseMessages, [
322+
{
323+
role: 'assistant',
324+
content: [
325+
{ type: 'text', text: '' },
326+
{
327+
type: 'tool-call',
328+
toolCallId: 'call-1',
329+
toolName: 'tool1',
330+
args: { value: 'value' },
331+
},
332+
],
333+
},
334+
{
335+
role: 'tool',
336+
content: [
337+
{
338+
type: 'tool-result',
339+
toolCallId: 'call-1',
340+
toolName: 'tool1',
341+
result: 'result1',
342+
},
343+
],
344+
},
345+
{
346+
role: 'assistant',
347+
content: [{ type: 'text', text: 'Hello, world!' }],
348+
},
349+
]);
350+
});
351+
});
352+
353+
describe('maxAutomaticRoundtrips', () => {
354+
it('should return text, tool calls and tool results from last roundtrip', async () => {
355+
let responseCount = 0;
356+
const result = await generateText({
357+
model: new MockLanguageModelV1({
358+
doGenerate: async ({ prompt, mode }) => {
359+
switch (responseCount++) {
360+
case 0:
361+
assert.deepStrictEqual(mode, {
362+
type: 'regular',
363+
toolChoice: { type: 'auto' },
364+
tools: [
365+
{
366+
type: 'function',
367+
name: 'tool1',
368+
description: undefined,
369+
parameters: {
370+
$schema: 'http://json-schema.org/draft-07/schema#',
371+
additionalProperties: false,
372+
properties: { value: { type: 'string' } },
373+
required: ['value'],
374+
type: 'object',
375+
},
376+
},
377+
],
378+
});
379+
assert.deepStrictEqual(prompt, [
380+
{
381+
role: 'user',
382+
content: [{ type: 'text', text: 'test-input' }],
383+
},
384+
]);
385+
return {
386+
...dummyResponseValues,
387+
toolCalls: [
388+
{
389+
toolCallType: 'function',
390+
toolCallId: 'call-1',
391+
toolName: 'tool1',
392+
args: `{ "value": "value" }`,
393+
},
394+
],
395+
toolResults: [
396+
{
397+
toolCallId: 'call-1',
398+
toolName: 'tool1',
399+
args: { value: 'value' },
400+
result: 'result1',
401+
},
402+
],
403+
};
404+
case 1:
405+
assert.deepStrictEqual(mode, {
406+
type: 'regular',
407+
toolChoice: { type: 'auto' },
408+
tools: [
409+
{
410+
type: 'function',
411+
name: 'tool1',
412+
description: undefined,
413+
parameters: {
414+
$schema: 'http://json-schema.org/draft-07/schema#',
415+
additionalProperties: false,
416+
properties: { value: { type: 'string' } },
417+
required: ['value'],
418+
type: 'object',
419+
},
420+
},
421+
],
422+
});
423+
assert.deepStrictEqual(prompt, [
424+
{
425+
role: 'user',
426+
content: [{ type: 'text', text: 'test-input' }],
427+
},
428+
{
429+
role: 'assistant',
430+
content: [
431+
{ type: 'text', text: '' },
432+
{
433+
type: 'tool-call',
434+
toolCallId: 'call-1',
435+
toolName: 'tool1',
436+
args: { value: 'value' },
437+
},
438+
],
439+
},
440+
{
441+
role: 'tool',
442+
content: [
443+
{
444+
type: 'tool-result',
445+
toolCallId: 'call-1',
446+
toolName: 'tool1',
447+
result: 'result1',
448+
},
449+
],
450+
},
451+
]);
452+
return {
453+
...dummyResponseValues,
454+
text: 'Hello, world!',
455+
};
456+
default:
457+
throw new Error(`Unexpected response count: ${responseCount}`);
458+
}
459+
},
460+
}),
461+
tools: {
462+
tool1: {
463+
parameters: z.object({ value: z.string() }),
464+
execute: async args => {
465+
assert.deepStrictEqual(args, { value: 'value' });
466+
return 'result1';
467+
},
468+
},
469+
},
470+
prompt: 'test-input',
471+
maxAutomaticRoundtrips: 2,
472+
});
473+
474+
assert.deepStrictEqual(result.text, 'Hello, world!');
475+
assert.deepStrictEqual(result.toolCalls, []);
476+
assert.deepStrictEqual(result.toolResults, []);
477+
});
271478
});

0 commit comments

Comments
 (0)