Skip to content

Commit f94fc7e

Browse files
authored
[mypyc] Implement CallC IR (#8880)
Relates to mypyc/mypyc#709 This PR adds a new IR op CallC to replace some PrimitiveOp that simply calls a C function. To demonstrate this prototype, str.join primitive is now switched from PrimitiveOp to CallC, with identical generated C code.
1 parent b3d4398 commit f94fc7e

File tree

7 files changed

+125
-9
lines changed

7 files changed

+125
-9
lines changed

‎mypyc/analysis.py‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Value, ControlOp,
99
BasicBlock, OpVisitor, Assign, LoadInt, LoadErrorValue, RegisterOp, Goto, Branch, Return, Call,
1010
Environment, Box, Unbox, Cast, Op, Unreachable, TupleGet, TupleSet, GetAttr, SetAttr,
11-
LoadStatic, InitStatic, PrimitiveOp, MethodCall, RaiseStandardError,
11+
LoadStatic, InitStatic, PrimitiveOp, MethodCall, RaiseStandardError, CallC
1212
)
1313

1414

@@ -195,6 +195,9 @@ def visit_cast(self, op: Cast) -> GenAndKill:
195195
def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
196196
return self.visit_register_op(op)
197197

198+
def visit_call_c(self, op: CallC) -> GenAndKill:
199+
return self.visit_register_op(op)
200+
198201

199202
class DefinedVisitor(BaseAnalysisVisitor):
200203
"""Visitor for finding defined registers.

‎mypyc/codegen/emitfunc.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
OpVisitor, Goto, Branch, Return, Assign, LoadInt, LoadErrorValue, GetAttr, SetAttr,
1212
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
1313
BasicBlock, Value, MethodCall, PrimitiveOp, EmitterInterface, Unreachable, NAMESPACE_STATIC,
14-
NAMESPACE_TYPE, NAMESPACE_MODULE, RaiseStandardError
14+
NAMESPACE_TYPE, NAMESPACE_MODULE, RaiseStandardError, CallC
1515
)
1616
from mypyc.ir.rtypes import RType, RTuple
1717
from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD
@@ -415,6 +415,11 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
415415
self.emitter.emit_line('PyErr_SetNone(PyExc_{});'.format(op.class_name))
416416
self.emitter.emit_line('{} = 0;'.format(self.reg(op)))
417417

418+
def visit_call_c(self, op: CallC) -> None:
419+
dest = self.get_dest_assign(op)
420+
args = ', '.join(self.reg(arg) for arg in op.args)
421+
self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args))
422+
418423
# Helpers
419424

420425
def label(self, label: BasicBlock) -> str:

‎mypyc/ir/ops.py‎

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,31 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
11381138
return visitor.visit_raise_standard_error(self)
11391139

11401140

1141+
class CallC(RegisterOp):
1142+
"""ret = func_call(arg0, arg1, ...)
1143+
1144+
A call to a C function
1145+
"""
1146+
1147+
error_kind = ERR_MAGIC
1148+
1149+
def __init__(self, function_name: str, args: List[Value], ret_type: RType, line: int) -> None:
1150+
super().__init__(line)
1151+
self.function_name = function_name
1152+
self.args = args
1153+
self.type = ret_type
1154+
1155+
def to_str(self, env: Environment) -> str:
1156+
args_str = ', '.join(env.format('%r', arg) for arg in self.args)
1157+
return env.format('%r = %s(%s)', self, self.function_name, args_str)
1158+
1159+
def sources(self) -> List[Value]:
1160+
return self.args
1161+
1162+
def accept(self, visitor: 'OpVisitor[T]') -> T:
1163+
return visitor.visit_call_c(self)
1164+
1165+
11411166
@trait
11421167
class OpVisitor(Generic[T]):
11431168
"""Generic visitor over ops (uses the visitor design pattern)."""
@@ -1228,6 +1253,10 @@ def visit_unbox(self, op: Unbox) -> T:
12281253
def visit_raise_standard_error(self, op: RaiseStandardError) -> T:
12291254
raise NotImplementedError
12301255

1256+
@abstractmethod
1257+
def visit_call_c(self, op: CallC) -> T:
1258+
raise NotImplementedError
1259+
12311260

12321261
# TODO: Should this live somewhere else?
12331262
LiteralsMap = Dict[Tuple[Type[object], Union[int, float, str, bytes, complex]], str]

‎mypyc/irbuild/ll_builder.py‎

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@
2020
BasicBlock, Environment, Op, LoadInt, Value, Register,
2121
Assign, Branch, Goto, Call, Box, Unbox, Cast, GetAttr,
2222
LoadStatic, MethodCall, PrimitiveOp, OpDescription, RegisterOp,
23-
NAMESPACE_TYPE, NAMESPACE_MODULE, LoadErrorValue,
23+
NAMESPACE_TYPE, NAMESPACE_MODULE, LoadErrorValue, CallC
2424
)
2525
from mypyc.ir.rtypes import (
2626
RType, RUnion, RInstance, optional_value_type, int_rprimitive, float_rprimitive,
27-
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive
27+
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive,
28+
void_rtype
2829
)
2930
from mypyc.ir.func_ir import FuncDecl, FuncSignature
3031
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
3132
from mypyc.common import (
3233
FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT,
3334
)
34-
from mypyc.primitives.registry import binary_ops, unary_ops, method_ops, func_ops
35+
from mypyc.primitives.registry import (
36+
binary_ops, unary_ops, method_ops, func_ops,
37+
c_method_call_ops, CFunctionDescription
38+
)
3539
from mypyc.primitives.list_ops import (
3640
list_extend_op, list_len_op, new_list_op
3741
)
@@ -644,6 +648,41 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
644648
value = self.primitive_op(bool_op, [value], value.line)
645649
self.add(Branch(value, true, false, Branch.BOOL_EXPR))
646650

651+
def call_c(self,
652+
function_name: str,
653+
args: List[Value],
654+
line: int,
655+
result_type: Optional[RType]) -> Value:
656+
# handle void function via singleton RVoid instance
657+
ret_type = void_rtype if result_type is None else result_type
658+
target = self.add(CallC(function_name, args, ret_type, line))
659+
return target
660+
661+
def matching_call_c(self,
662+
candidates: List[CFunctionDescription],
663+
args: List[Value],
664+
line: int,
665+
result_type: Optional[RType] = None) -> Optional[Value]:
666+
# TODO: this function is very similar to matching_primitive_op
667+
# we should remove the old one or refactor both them into only as we move forward
668+
matching = None # type: Optional[CFunctionDescription]
669+
for desc in candidates:
670+
if len(desc.arg_types) != len(args):
671+
continue
672+
if all(is_subtype(actual.type, formal)
673+
for actual, formal in zip(args, desc.arg_types)):
674+
if matching:
675+
assert matching.priority != desc.priority, 'Ambiguous:\n1) %s\n2) %s' % (
676+
matching, desc)
677+
if desc.priority > matching.priority:
678+
matching = desc
679+
else:
680+
matching = desc
681+
if matching:
682+
target = self.call_c(matching.c_function_name, args, line, result_type)
683+
return target
684+
return None
685+
647686
# Internal helpers
648687

649688
def decompose_union_helper(self,
@@ -728,6 +767,11 @@ def translate_special_method_call(self,
728767
Return None if no translation found; otherwise return the target register.
729768
"""
730769
ops = method_ops.get(name, [])
770+
call_c_ops_candidates = c_method_call_ops.get(name, [])
771+
call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args, line,
772+
result_type=result_type)
773+
if call_c_op is not None:
774+
return call_c_op
731775
return self.matching_primitive_op(ops, [base_reg] + args, line, result_type=result_type)
732776

733777
def translate_eq_cmp(self,

‎mypyc/primitives/registry.py‎

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,20 @@
3535
optimized implementations of all ops.
3636
"""
3737

38-
from typing import Dict, List, Optional
38+
from typing import Dict, List, Optional, NamedTuple
3939

4040
from mypyc.ir.ops import (
4141
OpDescription, EmitterInterface, EmitCallback, StealsDescription, short_name
4242
)
4343
from mypyc.ir.rtypes import RType, bool_rprimitive
4444

45+
CFunctionDescription = NamedTuple(
46+
'CFunctionDescription', [('name', str),
47+
('arg_types', List[RType]),
48+
('result_type', Optional[RType]),
49+
('c_function_name', str),
50+
('error_kind', int),
51+
('priority', int)])
4552

4653
# Primitive binary ops (key is operator such as '+')
4754
binary_ops = {} # type: Dict[str, List[OpDescription]]
@@ -58,6 +65,8 @@
5865
# Primitive ops for reading module attributes (key is name such as 'builtins.None')
5966
name_ref_ops = {} # type: Dict[str, OpDescription]
6067

68+
c_method_call_ops = {} # type: Dict[str, List[CFunctionDescription]]
69+
6170

6271
def simple_emit(template: str) -> EmitCallback:
6372
"""Construct a simple PrimitiveOp emit callback function.
@@ -312,6 +321,18 @@ def custom_op(arg_types: List[RType],
312321
emit, steals, is_borrowed, 0)
313322

314323

324+
def c_method_op(name: str,
325+
arg_types: List[RType],
326+
result_type: Optional[RType],
327+
c_function_name: str,
328+
error_kind: int,
329+
priority: int = 1) -> None:
330+
ops = c_method_call_ops.setdefault(name, [])
331+
desc = CFunctionDescription(name, arg_types, result_type,
332+
c_function_name, error_kind, priority)
333+
ops.append(desc)
334+
335+
315336
# Import various modules that set up global state.
316337
import mypyc.primitives.int_ops # noqa
317338
import mypyc.primitives.str_ops # noqa

‎mypyc/primitives/str_ops.py‎

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from mypyc.primitives.registry import (
1010
func_op, binary_op, simple_emit, name_ref_op, method_op, call_emit, name_emit,
11+
c_method_op
1112
)
1213

1314

@@ -33,12 +34,13 @@
3334
emit=call_emit('PyUnicode_Concat'))
3435

3536
# str.join(obj)
36-
method_op(
37+
c_method_op(
3738
name='join',
3839
arg_types=[str_rprimitive, object_rprimitive],
3940
result_type=str_rprimitive,
40-
error_kind=ERR_MAGIC,
41-
emit=call_emit('PyUnicode_Join'))
41+
c_function_name='PyUnicode_Join',
42+
error_kind=ERR_MAGIC
43+
)
4244

4345
# str[index] (for an int index)
4446
method_op(

‎mypyc/test-data/irbuild-basic.test‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3383,3 +3383,15 @@ L0:
33833383
r5 = None
33843384
return r5
33853385

3386+
[case testCallCWithStrJoin]
3387+
from typing import List
3388+
def f(x: str, y: List[str]) -> str:
3389+
return x.join(y)
3390+
[out]
3391+
def f(x, y):
3392+
x :: str
3393+
y :: list
3394+
r0 :: str
3395+
L0:
3396+
r0 = PyUnicode_Join(x, y)
3397+
return r0

0 commit comments

Comments
 (0)