@@ -1112,6 +1112,8 @@ def visitModule(self, mod):
11121112 for dfn in mod .dfns :
11131113 self .visit (dfn )
11141114 self .file .write (textwrap .dedent ('''
1115+ state->recursion_depth = 0;
1116+ state->recursion_limit = 0;
11151117 state->initialized = 1;
11161118 return 1;
11171119 }
@@ -1259,8 +1261,14 @@ def func_begin(self, name):
12591261 self .emit ('if (!o) {' , 1 )
12601262 self .emit ("Py_RETURN_NONE;" , 2 )
12611263 self .emit ("}" , 1 )
1264+ self .emit ("if (++state->recursion_depth > state->recursion_limit) {" , 1 )
1265+ self .emit ("PyErr_SetString(PyExc_RecursionError," , 2 )
1266+ self .emit ('"maximum recursion depth exceeded during ast construction");' , 3 )
1267+ self .emit ("return 0;" , 2 )
1268+ self .emit ("}" , 1 )
12621269
12631270 def func_end (self ):
1271+ self .emit ("state->recursion_depth--;" , 1 )
12641272 self .emit ("return result;" , 1 )
12651273 self .emit ("failed:" , 0 )
12661274 self .emit ("Py_XDECREF(value);" , 1 )
@@ -1371,7 +1379,32 @@ class PartingShots(StaticVisitor):
13711379 if (state == NULL) {
13721380 return NULL;
13731381 }
1374- return ast2obj_mod(state, t);
1382+
1383+ int recursion_limit = Py_GetRecursionLimit();
1384+ int starting_recursion_depth;
1385+ /* Be careful here to prevent overflow. */
1386+ int COMPILER_STACK_FRAME_SCALE = 3;
1387+ PyThreadState *tstate = _PyThreadState_GET();
1388+ if (!tstate) {
1389+ return 0;
1390+ }
1391+ state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1392+ recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
1393+ int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
1394+ starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1395+ recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
1396+ state->recursion_depth = starting_recursion_depth;
1397+
1398+ PyObject *result = ast2obj_mod(state, t);
1399+
1400+ /* Check that the recursion depth counting balanced correctly */
1401+ if (result && state->recursion_depth != starting_recursion_depth) {
1402+ PyErr_Format(PyExc_SystemError,
1403+ "AST constructor recursion depth mismatch (before=%d, after=%d)",
1404+ starting_recursion_depth, state->recursion_depth);
1405+ return 0;
1406+ }
1407+ return result;
13751408}
13761409
13771410/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
@@ -1437,6 +1470,8 @@ def visit(self, object):
14371470def generate_ast_state (module_state , f ):
14381471 f .write ('struct ast_state {\n ' )
14391472 f .write (' int initialized;\n ' )
1473+ f .write (' int recursion_depth;\n ' )
1474+ f .write (' int recursion_limit;\n ' )
14401475 for s in module_state :
14411476 f .write (' PyObject *' + s + ';\n ' )
14421477 f .write ('};' )
0 commit comments