Skip to content

Commit 30c1a3d

Browse files
Auto merge of #150820 - clubby789:br-weights, r=<try>
codegen: Use branch weights instead of `llvm.expect`
2 parents 85d0cdf + af884ea commit 30c1a3d

File tree

3 files changed

+190
-111
lines changed

3 files changed

+190
-111
lines changed

‎compiler/rustc_codegen_llvm/src/builder.rs‎

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,32 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
359359
}
360360
}
361361

362+
fn cond_br_with_weight(
363+
&mut self,
364+
cond: Self::Value,
365+
then_llbb: Self::BasicBlock,
366+
else_llbb: Self::BasicBlock,
367+
then_cold: bool,
368+
) {
369+
if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No {
370+
self.cond_br(cond, then_llbb, else_llbb);
371+
return;
372+
}
373+
374+
let id = self.cx.create_metadata(b"branch_weights");
375+
376+
let cold_weight = llvm::LLVMValueAsMetadata(self.cx.const_u32(1));
377+
let hot_weight = llvm::LLVMValueAsMetadata(self.cx.const_u32(2000));
378+
let (then_weight, else_weight) =
379+
if then_cold { (cold_weight, hot_weight) } else { (hot_weight, cold_weight) };
380+
381+
let md: SmallVec<[&Metadata; 3]> = SmallVec::from_buf([id, then_weight, else_weight]);
382+
383+
let branch = unsafe { llvm::LLVMBuildCondBr(self.llbuilder, cond, then_llbb, else_llbb) };
384+
385+
self.cx.set_metadata_node(branch, llvm::MD_prof, &md);
386+
}
387+
362388
fn switch_with_weights(
363389
&mut self,
364390
v: Self::Value,

‎compiler/rustc_codegen_ssa/src/mir/block.rs‎

Lines changed: 149 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use rustc_session::config::OptLevel;
1515
use rustc_span::Span;
1616
use rustc_span::source_map::Spanned;
1717
use rustc_target::callconv::{ArgAbi, ArgAttributes, CastTarget, FnAbi, PassMode};
18+
use smallvec::SmallVec;
1819
use tracing::{debug, info};
1920

2021
use super::operand::OperandRef;
@@ -391,124 +392,66 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
391392
return;
392393
};
393394

394-
let mut target_iter = targets.iter();
395-
if target_iter.len() == 1 {
396-
// If there are two targets (one conditional, one fallback), emit `br` instead of
397-
// `switch`.
398-
let (test_value, target) = target_iter.next().unwrap();
399-
let otherwise = targets.otherwise();
400-
let lltarget = helper.llbb_with_cleanup(self, target);
401-
let llotherwise = helper.llbb_with_cleanup(self, otherwise);
402-
let target_cold = self.cold_blocks[target];
403-
let otherwise_cold = self.cold_blocks[otherwise];
404-
// If `target_cold == otherwise_cold`, the branches have the same weight
405-
// so there is no expectation. If they differ, the `target` branch is expected
406-
// when the `otherwise` branch is cold.
407-
let expect = if target_cold == otherwise_cold { None } else { Some(otherwise_cold) };
408-
if switch_ty == bx.tcx().types.bool {
409-
// Don't generate trivial icmps when switching on bool.
410-
match test_value {
411-
0 => {
412-
let expect = expect.map(|e| !e);
413-
bx.cond_br_with_expect(discr_value, llotherwise, lltarget, expect);
414-
}
415-
1 => {
416-
bx.cond_br_with_expect(discr_value, lltarget, llotherwise, expect);
417-
}
418-
_ => bug!(),
395+
let bool_ty = bx.tcx().types.bool;
396+
let normalized = self.normalize_switch_targets(targets, switch_ty == bool_ty);
397+
match normalized {
398+
NormalizedSwitch::BooleanBranch { then_bb, then_cold, else_bb, needs_trunc } => {
399+
let discr_value = if needs_trunc {
400+
let bool_llty = bx.immediate_backend_type(bx.layout_of(bool_ty));
401+
bx.unchecked_utrunc(discr_value, bool_llty)
402+
} else {
403+
discr_value
404+
};
405+
if let Some(then_cold) = then_cold {
406+
bx.cond_br_with_weight(
407+
discr_value,
408+
helper.llbb_with_cleanup(self, then_bb),
409+
helper.llbb_with_cleanup(self, else_bb),
410+
then_cold,
411+
)
412+
} else {
413+
bx.cond_br(
414+
discr_value,
415+
helper.llbb_with_cleanup(self, then_bb),
416+
helper.llbb_with_cleanup(self, else_bb),
417+
)
419418
}
420-
} else {
419+
}
420+
NormalizedSwitch::Branch { then_bb: (test_value, then_bb), then_cold, else_bb } => {
421421
let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
422422
let llval = bx.const_uint_big(switch_llty, test_value);
423423
let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
424-
bx.cond_br_with_expect(cmp, lltarget, llotherwise, expect);
424+
let lltarget = helper.llbb_with_cleanup(self, then_bb);
425+
let llotherwise = helper.llbb_with_cleanup(self, else_bb);
426+
427+
if let Some(then_cold) = then_cold {
428+
bx.cond_br_with_weight(cmp, lltarget, llotherwise, then_cold)
429+
} else {
430+
bx.cond_br(cmp, lltarget, llotherwise)
431+
}
425432
}
426-
} else if target_iter.len() == 2
427-
&& self.mir[targets.otherwise()].is_empty_unreachable()
428-
&& targets.all_values().contains(&Pu128(0))
429-
&& targets.all_values().contains(&Pu128(1))
430-
{
431-
// This is the really common case for `bool`, `Option`, etc.
432-
// By using `trunc nuw` we communicate that other values are
433-
// impossible without needing `switch` or `assume`s.
434-
let true_bb = targets.target_for_value(1);
435-
let false_bb = targets.target_for_value(0);
436-
let true_ll = helper.llbb_with_cleanup(self, true_bb);
437-
let false_ll = helper.llbb_with_cleanup(self, false_bb);
438-
439-
let expected_cond_value = if self.cx.sess().opts.optimize == OptLevel::No {
440-
None
441-
} else {
442-
match (self.cold_blocks[true_bb], self.cold_blocks[false_bb]) {
443-
// Same coldness, no expectation
444-
(true, true) | (false, false) => None,
445-
// Different coldness, expect the non-cold one
446-
(true, false) => Some(false),
447-
(false, true) => Some(true),
433+
NormalizedSwitch::Switch { values, targets, targets_cold } => {
434+
let (&else_bb, targets) = targets.split_last().unwrap();
435+
let else_llbb = helper.llbb_with_cleanup(self, else_bb);
436+
let cases = values
437+
.iter()
438+
.zip(targets)
439+
.map(|(&value, &target)| (value.get(), helper.llbb_with_cleanup(self, target)));
440+
if let Some(targets_cold) = targets_cold {
441+
let (&else_cold, targets_cold) = targets_cold.split_last().unwrap();
442+
bx.switch_with_weights(
443+
discr_value,
444+
else_llbb,
445+
else_cold,
446+
cases
447+
.zip(targets_cold)
448+
.map(|((value, target), &cold)| (value, target, cold)),
449+
);
450+
} else {
451+
bx.switch(discr_value, else_llbb, cases);
448452
}
449-
};
450-
451-
let bool_ty = bx.tcx().types.bool;
452-
let cond = if switch_ty == bool_ty {
453-
discr_value
454-
} else {
455-
let bool_llty = bx.immediate_backend_type(bx.layout_of(bool_ty));
456-
bx.unchecked_utrunc(discr_value, bool_llty)
457-
};
458-
bx.cond_br_with_expect(cond, true_ll, false_ll, expected_cond_value);
459-
} else if self.cx.sess().opts.optimize == OptLevel::No
460-
&& target_iter.len() == 2
461-
&& self.mir[targets.otherwise()].is_empty_unreachable()
462-
{
463-
// In unoptimized builds, if there are two normal targets and the `otherwise` target is
464-
// an unreachable BB, emit `br` instead of `switch`. This leaves behind the unreachable
465-
// BB, which will usually (but not always) be dead code.
466-
//
467-
// Why only in unoptimized builds?
468-
// - In unoptimized builds LLVM uses FastISel which does not support switches, so it
469-
// must fall back to the slower SelectionDAG isel. Therefore, using `br` gives
470-
// significant compile time speedups for unoptimized builds.
471-
// - In optimized builds the above doesn't hold, and using `br` sometimes results in
472-
// worse generated code because LLVM can no longer tell that the value being switched
473-
// on can only have two values, e.g. 0 and 1.
474-
//
475-
let (test_value1, target1) = target_iter.next().unwrap();
476-
let (_test_value2, target2) = target_iter.next().unwrap();
477-
let ll1 = helper.llbb_with_cleanup(self, target1);
478-
let ll2 = helper.llbb_with_cleanup(self, target2);
479-
let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
480-
let llval = bx.const_uint_big(switch_llty, test_value1);
481-
let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
482-
bx.cond_br(cmp, ll1, ll2);
483-
} else {
484-
let otherwise = targets.otherwise();
485-
let otherwise_cold = self.cold_blocks[otherwise];
486-
let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable();
487-
let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count();
488-
let none_cold = cold_count == 0;
489-
let all_cold = cold_count == targets.iter().len();
490-
if (none_cold && (!otherwise_cold || otherwise_unreachable))
491-
|| (all_cold && (otherwise_cold || otherwise_unreachable))
492-
{
493-
// All targets have the same weight,
494-
// or `otherwise` is unreachable and it's the only target with a different weight.
495-
bx.switch(
496-
discr_value,
497-
helper.llbb_with_cleanup(self, targets.otherwise()),
498-
target_iter
499-
.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
500-
);
501-
} else {
502-
// Targets have different weights
503-
bx.switch_with_weights(
504-
discr_value,
505-
helper.llbb_with_cleanup(self, targets.otherwise()),
506-
otherwise_cold,
507-
target_iter.map(|(value, target)| {
508-
(value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target])
509-
}),
510-
);
511453
}
454+
NormalizedSwitch::Jump { target } => bx.br(helper.llbb_with_cleanup(self, target)),
512455
}
513456
}
514457

@@ -2034,6 +1977,71 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
20341977
}
20351978
}
20361979
}
1980+
1981+
fn normalize_switch_targets(
1982+
&self,
1983+
targets: &SwitchTargets,
1984+
discr_boolean: bool,
1985+
) -> NormalizedSwitch {
1986+
let mut target_iter = targets.iter();
1987+
let target_iter_len = target_iter.len();
1988+
if target_iter_len == 0 {
1989+
return NormalizedSwitch::Jump { target: targets.otherwise() };
1990+
}
1991+
let mut targets_cold = targets.all_targets().iter().map(|&target| self.cold_blocks[target]);
1992+
let use_weights =
1993+
targets_cold.clone().any(|hot| hot) && targets_cold.clone().any(|hot| !hot);
1994+
if target_iter_len > 2
1995+
|| target_iter_len == 2 && !self.mir[targets.otherwise()].is_empty_unreachable()
1996+
{
1997+
return NormalizedSwitch::Switch {
1998+
values: targets.all_values().into(),
1999+
targets: targets.all_targets().into(),
2000+
targets_cold: use_weights.then(|| targets_cold.collect()),
2001+
};
2002+
}
2003+
2004+
let then_bb = target_iter.next().unwrap();
2005+
let then_cold = targets_cold.next().unwrap();
2006+
let else_bb =
2007+
if target_iter_len == 1 { targets.otherwise() } else { target_iter.next().unwrap().1 };
2008+
2009+
if discr_boolean {
2010+
let (test_value, then_bb) = then_bb;
2011+
let (then_bb, then_cold, else_bb) = if test_value == 1 {
2012+
(then_bb, then_cold, else_bb)
2013+
} else {
2014+
(else_bb, targets_cold.next().unwrap(), then_bb)
2015+
};
2016+
NormalizedSwitch::BooleanBranch {
2017+
needs_trunc: false,
2018+
then_bb,
2019+
then_cold: use_weights.then_some(then_cold),
2020+
else_bb,
2021+
}
2022+
} else if target_iter_len == 2
2023+
&& let &[Pu128(then_value), Pu128(else_value)] = targets.all_values()
2024+
&& ((then_value == 1 && else_value == 0) || (then_value == 0 && else_value == 1))
2025+
{
2026+
let (then_bb, then_cold, else_bb) = if then_value == 1 {
2027+
(then_bb.1, then_cold, else_bb)
2028+
} else {
2029+
(else_bb, targets_cold.next().unwrap(), then_bb.1)
2030+
};
2031+
NormalizedSwitch::BooleanBranch {
2032+
needs_trunc: true,
2033+
then_bb,
2034+
then_cold: use_weights.then_some(then_cold),
2035+
else_bb,
2036+
}
2037+
} else {
2038+
NormalizedSwitch::Branch {
2039+
then_bb,
2040+
then_cold: use_weights.then_some(then_cold),
2041+
else_bb,
2042+
}
2043+
}
2044+
}
20372045
}
20382046

20392047
enum ReturnDest<'tcx, V> {
@@ -2090,3 +2098,33 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
20902098
bx.store(value, ptr, align);
20912099
};
20922100
}
2101+
2102+
/// Normalized version of [`SwitchTargets`].
2103+
enum NormalizedSwitch {
2104+
/// Discriminant is a bool with 1=>then, 0=>else
2105+
/// If `needs_trunc` is true, this is not a boolean but some other
2106+
/// value which may be 0 or 1, and therefore needs `trunc nuw`
2107+
BooleanBranch {
2108+
needs_trunc: bool,
2109+
then_bb: mir::BasicBlock,
2110+
then_cold: Option<bool>,
2111+
else_bb: mir::BasicBlock,
2112+
// else_cold is the inverse of then_cold
2113+
},
2114+
/// If discr==then_bb.0=>then, otherwise=>else
2115+
Branch {
2116+
then_bb: (u128, mir::BasicBlock),
2117+
then_cold: Option<bool>,
2118+
else_bb: mir::BasicBlock,
2119+
// else_cold is the inverse of then_cold
2120+
},
2121+
/// Equivalent to [`SwitchTargets`], but known to have at least 3 targets
2122+
Switch {
2123+
values: SmallVec<[Pu128; 2]>,
2124+
targets: SmallVec<[mir::BasicBlock; 3]>,
2125+
targets_cold: Option<SmallVec<[bool; 3]>>,
2126+
},
2127+
Jump {
2128+
target: mir::BasicBlock,
2129+
},
2130+
}

‎compiler/rustc_codegen_ssa/src/traits/builder.rs‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,21 @@ pub trait BuilderMethods<'a, 'tcx>:
103103
self.cond_br(cond, then_llbb, else_llbb)
104104
}
105105

106+
// Conditional with weights.
107+
//
108+
// This function is opt-in for back ends.
109+
//
110+
// The default implementation ignores `then_cold` and calls `self.cond_br()`
111+
fn cond_br_with_weight(
112+
&mut self,
113+
cond: Self::Value,
114+
then_llbb: Self::BasicBlock,
115+
else_llbb: Self::BasicBlock,
116+
_then_cold: bool,
117+
) {
118+
self.cond_br(cond, then_llbb, else_llbb)
119+
}
120+
106121
fn switch(
107122
&mut self,
108123
v: Self::Value,

0 commit comments

Comments
 (0)