Skip to content

Commit 76d7309

Browse files
Auto merge of #150820 - clubby789:br-weights, r=<try>
codegen: Use branch weights instead of `llvm.expect`
2 parents a3f2d5a + be17c64 commit 76d7309

File tree

3 files changed

+215
-111
lines changed

3 files changed

+215
-111
lines changed

‎compiler/rustc_codegen_llvm/src/builder.rs‎

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,32 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
359359
}
360360
}
361361

362+
fn cond_br_with_weight(
363+
&mut self,
364+
cond: Self::Value,
365+
then_llbb: Self::BasicBlock,
366+
else_llbb: Self::BasicBlock,
367+
then_cold: bool,
368+
) {
369+
if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No {
370+
self.cond_br(cond, then_llbb, else_llbb);
371+
return;
372+
}
373+
374+
let id = self.cx.create_metadata(b"branch_weights");
375+
376+
let cold_weight = llvm::LLVMValueAsMetadata(self.cx.const_u32(1));
377+
let hot_weight = llvm::LLVMValueAsMetadata(self.cx.const_u32(2000));
378+
let (then_weight, else_weight) =
379+
if then_cold { (cold_weight, hot_weight) } else { (hot_weight, cold_weight) };
380+
381+
let md: SmallVec<[&Metadata; 3]> = SmallVec::from_buf([id, then_weight, else_weight]);
382+
383+
let branch = unsafe { llvm::LLVMBuildCondBr(self.llbuilder, cond, then_llbb, else_llbb) };
384+
385+
self.cx.set_metadata_node(branch, llvm::MD_prof, &md);
386+
}
387+
362388
fn switch_with_weights(
363389
&mut self,
364390
v: Self::Value,

‎compiler/rustc_codegen_ssa/src/mir/block.rs‎

Lines changed: 174 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use rustc_session::config::OptLevel;
1515
use rustc_span::Span;
1616
use rustc_span::source_map::Spanned;
1717
use rustc_target::callconv::{ArgAbi, ArgAttributes, CastTarget, FnAbi, PassMode};
18+
use smallvec::SmallVec;
1819
use tracing::{debug, info};
1920

2021
use super::operand::OperandRef;
@@ -391,124 +392,70 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
391392
return;
392393
};
393394

394-
let mut target_iter = targets.iter();
395-
if target_iter.len() == 1 {
396-
// If there are two targets (one conditional, one fallback), emit `br` instead of
397-
// `switch`.
398-
let (test_value, target) = target_iter.next().unwrap();
399-
let otherwise = targets.otherwise();
400-
let lltarget = helper.llbb_with_cleanup(self, target);
401-
let llotherwise = helper.llbb_with_cleanup(self, otherwise);
402-
let target_cold = self.cold_blocks[target];
403-
let otherwise_cold = self.cold_blocks[otherwise];
404-
// If `target_cold == otherwise_cold`, the branches have the same weight
405-
// so there is no expectation. If they differ, the `target` branch is expected
406-
// when the `otherwise` branch is cold.
407-
let expect = if target_cold == otherwise_cold { None } else { Some(otherwise_cold) };
408-
if switch_ty == bx.tcx().types.bool {
409-
// Don't generate trivial icmps when switching on bool.
410-
match test_value {
411-
0 => {
412-
let expect = expect.map(|e| !e);
413-
bx.cond_br_with_expect(discr_value, llotherwise, lltarget, expect);
414-
}
415-
1 => {
416-
bx.cond_br_with_expect(discr_value, lltarget, llotherwise, expect);
417-
}
418-
_ => bug!(),
395+
let bool_ty = bx.tcx().types.bool;
396+
let normalized = self.normalize_switch_targets(
397+
targets,
398+
switch_ty == bool_ty,
399+
self.cx.sess().opts.optimize,
400+
);
401+
match normalized {
402+
NormalizedSwitch::BooleanBranch { then_bb, then_cold, else_bb, needs_trunc } => {
403+
let discr_value = if needs_trunc {
404+
let bool_llty = bx.immediate_backend_type(bx.layout_of(bool_ty));
405+
bx.unchecked_utrunc(discr_value, bool_llty)
406+
} else {
407+
discr_value
408+
};
409+
if let Some(then_cold) = then_cold {
410+
bx.cond_br_with_weight(
411+
discr_value,
412+
helper.llbb_with_cleanup(self, then_bb),
413+
helper.llbb_with_cleanup(self, else_bb),
414+
then_cold,
415+
)
416+
} else {
417+
bx.cond_br(
418+
discr_value,
419+
helper.llbb_with_cleanup(self, then_bb),
420+
helper.llbb_with_cleanup(self, else_bb),
421+
)
419422
}
420-
} else {
423+
}
424+
NormalizedSwitch::Branch { then_bb: (test_value, then_bb), then_cold, else_bb } => {
421425
let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
422426
let llval = bx.const_uint_big(switch_llty, test_value);
423427
let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
424-
bx.cond_br_with_expect(cmp, lltarget, llotherwise, expect);
428+
let lltarget = helper.llbb_with_cleanup(self, then_bb);
429+
let llotherwise = helper.llbb_with_cleanup(self, else_bb);
430+
431+
if let Some(then_cold) = then_cold {
432+
bx.cond_br_with_weight(cmp, lltarget, llotherwise, then_cold)
433+
} else {
434+
bx.cond_br(cmp, lltarget, llotherwise)
435+
}
425436
}
426-
} else if target_iter.len() == 2
427-
&& self.mir[targets.otherwise()].is_empty_unreachable()
428-
&& targets.all_values().contains(&Pu128(0))
429-
&& targets.all_values().contains(&Pu128(1))
430-
{
431-
// This is the really common case for `bool`, `Option`, etc.
432-
// By using `trunc nuw` we communicate that other values are
433-
// impossible without needing `switch` or `assume`s.
434-
let true_bb = targets.target_for_value(1);
435-
let false_bb = targets.target_for_value(0);
436-
let true_ll = helper.llbb_with_cleanup(self, true_bb);
437-
let false_ll = helper.llbb_with_cleanup(self, false_bb);
438-
439-
let expected_cond_value = if self.cx.sess().opts.optimize == OptLevel::No {
440-
None
441-
} else {
442-
match (self.cold_blocks[true_bb], self.cold_blocks[false_bb]) {
443-
// Same coldness, no expectation
444-
(true, true) | (false, false) => None,
445-
// Different coldness, expect the non-cold one
446-
(true, false) => Some(false),
447-
(false, true) => Some(true),
437+
NormalizedSwitch::Switch { values, targets, targets_cold } => {
438+
let (&else_bb, targets) = targets.split_last().unwrap();
439+
let else_llbb = helper.llbb_with_cleanup(self, else_bb);
440+
let cases = values
441+
.iter()
442+
.zip(targets)
443+
.map(|(&value, &target)| (value.get(), helper.llbb_with_cleanup(self, target)));
444+
if let Some(targets_cold) = targets_cold {
445+
let (&else_cold, targets_cold) = targets_cold.split_last().unwrap();
446+
bx.switch_with_weights(
447+
discr_value,
448+
else_llbb,
449+
else_cold,
450+
cases
451+
.zip(targets_cold)
452+
.map(|((value, target), &cold)| (value, target, cold)),
453+
);
454+
} else {
455+
bx.switch(discr_value, else_llbb, cases);
448456
}
449-
};
450-
451-
let bool_ty = bx.tcx().types.bool;
452-
let cond = if switch_ty == bool_ty {
453-
discr_value
454-
} else {
455-
let bool_llty = bx.immediate_backend_type(bx.layout_of(bool_ty));
456-
bx.unchecked_utrunc(discr_value, bool_llty)
457-
};
458-
bx.cond_br_with_expect(cond, true_ll, false_ll, expected_cond_value);
459-
} else if self.cx.sess().opts.optimize == OptLevel::No
460-
&& target_iter.len() == 2
461-
&& self.mir[targets.otherwise()].is_empty_unreachable()
462-
{
463-
// In unoptimized builds, if there are two normal targets and the `otherwise` target is
464-
// an unreachable BB, emit `br` instead of `switch`. This leaves behind the unreachable
465-
// BB, which will usually (but not always) be dead code.
466-
//
467-
// Why only in unoptimized builds?
468-
// - In unoptimized builds LLVM uses FastISel which does not support switches, so it
469-
// must fall back to the slower SelectionDAG isel. Therefore, using `br` gives
470-
// significant compile time speedups for unoptimized builds.
471-
// - In optimized builds the above doesn't hold, and using `br` sometimes results in
472-
// worse generated code because LLVM can no longer tell that the value being switched
473-
// on can only have two values, e.g. 0 and 1.
474-
//
475-
let (test_value1, target1) = target_iter.next().unwrap();
476-
let (_test_value2, target2) = target_iter.next().unwrap();
477-
let ll1 = helper.llbb_with_cleanup(self, target1);
478-
let ll2 = helper.llbb_with_cleanup(self, target2);
479-
let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
480-
let llval = bx.const_uint_big(switch_llty, test_value1);
481-
let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
482-
bx.cond_br(cmp, ll1, ll2);
483-
} else {
484-
let otherwise = targets.otherwise();
485-
let otherwise_cold = self.cold_blocks[otherwise];
486-
let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable();
487-
let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count();
488-
let none_cold = cold_count == 0;
489-
let all_cold = cold_count == targets.iter().len();
490-
if (none_cold && (!otherwise_cold || otherwise_unreachable))
491-
|| (all_cold && (otherwise_cold || otherwise_unreachable))
492-
{
493-
// All targets have the same weight,
494-
// or `otherwise` is unreachable and it's the only target with a different weight.
495-
bx.switch(
496-
discr_value,
497-
helper.llbb_with_cleanup(self, targets.otherwise()),
498-
target_iter
499-
.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
500-
);
501-
} else {
502-
// Targets have different weights
503-
bx.switch_with_weights(
504-
discr_value,
505-
helper.llbb_with_cleanup(self, targets.otherwise()),
506-
otherwise_cold,
507-
target_iter.map(|(value, target)| {
508-
(value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target])
509-
}),
510-
);
511457
}
458+
NormalizedSwitch::Jump { target } => bx.br(helper.llbb_with_cleanup(self, target)),
512459
}
513460
}
514461

@@ -2034,6 +1981,92 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
20341981
}
20351982
}
20361983
}
1984+
1985+
fn normalize_switch_targets(
1986+
&self,
1987+
targets: &SwitchTargets,
1988+
discr_boolean: bool,
1989+
opt_level: OptLevel,
1990+
) -> NormalizedSwitch {
1991+
let mut target_iter = targets.iter();
1992+
let target_iter_len = target_iter.len();
1993+
if target_iter_len == 0 {
1994+
return NormalizedSwitch::Jump { target: targets.otherwise() };
1995+
}
1996+
let mut targets_cold = targets.all_targets().iter().map(|&target| self.cold_blocks[target]);
1997+
let use_weights =
1998+
targets_cold.clone().any(|hot| hot) && targets_cold.clone().any(|hot| !hot);
1999+
2000+
// If there are more than two targets we need to switch.
2001+
// Additionally, if there are two targets, with an empty-unreachable other branch,
2002+
// and we are in an opt-level greater than 0, emit a switch.
2003+
// Why only in optimized builds?
2004+
// - In unoptimized builds LLVM uses FastISel which does not support switches, so it
2005+
// must fall back to the slower SelectionDAG isel. Therefore, using `br` gives
2006+
// significant compile time speedups for unoptimized builds.
2007+
// - In optimized builds the above doesn't hold, and using `br` sometimes results in
2008+
// worse generated code because LLVM can no longer tell that the value being switched
2009+
// on can only have two values, e.g. 0 and 1.
2010+
2011+
if target_iter_len > 2
2012+
|| target_iter_len == 2 && !self.mir[targets.otherwise()].is_empty_unreachable()
2013+
|| target_iter_len == 2
2014+
&& self.mir[targets.otherwise()].is_empty_unreachable()
2015+
&& opt_level != OptLevel::No
2016+
{
2017+
return NormalizedSwitch::Switch {
2018+
values: targets.all_values().into(),
2019+
targets: targets.all_targets().into(),
2020+
targets_cold: use_weights.then(|| targets_cold.collect()),
2021+
};
2022+
}
2023+
2024+
let then_bb = target_iter.next().unwrap();
2025+
let then_cold = targets_cold.next().unwrap();
2026+
let else_bb =
2027+
if target_iter_len == 1 { targets.otherwise() } else { target_iter.next().unwrap().1 };
2028+
2029+
if discr_boolean {
2030+
// Emit a `br i1`, swapping the argument order if required
2031+
let (test_value, then_bb) = then_bb;
2032+
let (then_bb, then_cold, else_bb) = if test_value == 1 {
2033+
(then_bb, then_cold, else_bb)
2034+
} else {
2035+
(else_bb, targets_cold.next().unwrap(), then_bb)
2036+
};
2037+
NormalizedSwitch::BooleanBranch {
2038+
needs_trunc: false,
2039+
then_bb,
2040+
then_cold: use_weights.then_some(then_cold),
2041+
else_bb,
2042+
}
2043+
} else if target_iter_len == 2
2044+
&& let &[Pu128(then_value), Pu128(else_value)] = targets.all_values()
2045+
&& ((then_value == 1 && else_value == 0) || (then_value == 0 && else_value == 1))
2046+
{
2047+
// Same as above, but with a non-boolean discriminant (e.g. an enum discriminant)
2048+
// This emits a `trunc nuw` to communicate that other values are impossible,
2049+
// without needing `switch` or `assume`/`expect`
2050+
let (then_bb, then_cold, else_bb) = if then_value == 1 {
2051+
(then_bb.1, then_cold, else_bb)
2052+
} else {
2053+
(else_bb, targets_cold.next().unwrap(), then_bb.1)
2054+
};
2055+
NormalizedSwitch::BooleanBranch {
2056+
needs_trunc: true,
2057+
then_bb,
2058+
then_cold: use_weights.then_some(then_cold),
2059+
else_bb,
2060+
}
2061+
} else {
2062+
// Need to emit an icmp and branch on it
2063+
NormalizedSwitch::Branch {
2064+
then_bb,
2065+
then_cold: use_weights.then_some(then_cold),
2066+
else_bb,
2067+
}
2068+
}
2069+
}
20372070
}
20382071

20392072
enum ReturnDest<'tcx, V> {
@@ -2090,3 +2123,33 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
20902123
bx.store(value, ptr, align);
20912124
};
20922125
}
2126+
2127+
/// Normalized version of [`SwitchTargets`].
///
/// Each variant describes one control-flow shape a switch can be lowered to. In the
/// branch variants, `then_cold == None` means no weight hint is attached and a plain
/// conditional branch is emitted.
2128+
enum NormalizedSwitch {
2129+
/// Discriminant is a bool with 1=>then, 0=>else
2130+
/// If `needs_trunc` is true, this is not a boolean but some other
2131+
/// value which may be 0 or 1, and therefore needs `trunc nuw`
2132+
BooleanBranch {
2133+
needs_trunc: bool,
2134+
then_bb: mir::BasicBlock,
2135+
then_cold: Option<bool>,
2136+
else_bb: mir::BasicBlock,
2137+
// There is no `else_cold` field: when `then_cold` is `Some`, `else_bb` is treated
// as having the inverse coldness of `then_bb`.
2138+
},
2139+
/// Two-way branch on an equality test: if the discriminant equals `then_bb.0`
/// control goes to `then_bb.1`, otherwise to `else_bb`.
2140+
Branch {
2141+
then_bb: (u128, mir::BasicBlock),
2142+
then_cold: Option<bool>,
2143+
else_bb: mir::BasicBlock,
2144+
// There is no `else_cold` field: when `then_cold` is `Some`, `else_bb` is treated
// as having the inverse coldness of `then_bb`.
2145+
},
2146+
/// Equivalent to [`SwitchTargets`], but known to have at least 3 targets
/// (at least two explicit values plus the fallback).
///
/// `targets` holds one block per entry of `values` followed by the fallback block as
/// its last element. `targets_cold`, when present, is parallel to `targets`; `None`
/// means the switch is emitted without weights.
2147+
Switch {
2148+
values: SmallVec<[Pu128; 2]>,
2149+
targets: SmallVec<[mir::BasicBlock; 3]>,
2150+
targets_cold: Option<SmallVec<[bool; 3]>>,
2151+
},
/// No explicit values at all: an unconditional jump to `target`.
2152+
Jump {
2153+
target: mir::BasicBlock,
2154+
},
2155+
}

‎compiler/rustc_codegen_ssa/src/traits/builder.rs‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,21 @@ pub trait BuilderMethods<'a, 'tcx>:
103103
self.cond_br(cond, then_llbb, else_llbb)
104104
}
105105

106+
// Conditional branch with a branch-weight hint.
107+
//
108+
// This function is opt-in for back ends. `_then_cold` says whether the `then_llbb`
// successor is the cold (unlikely) one; a back end that overrides this method can
// record that on the emitted branch.
109+
//
110+
// The default implementation ignores `then_cold` and calls `self.cond_br()`,
// so back ends that do not override it emit an unweighted conditional branch.
111+
fn cond_br_with_weight(
112+
&mut self,
113+
cond: Self::Value,
114+
then_llbb: Self::BasicBlock,
115+
else_llbb: Self::BasicBlock,
116+
_then_cold: bool,
117+
) {
118+
self.cond_br(cond, then_llbb, else_llbb)
119+
}
120+
106121
fn switch(
107122
&mut self,
108123
v: Self::Value,

0 commit comments

Comments
 (0)