Skip to content

Commit 37bc379

Browse files
Auto merge of #148443 - dianqk:gvn-cmp-range, r=<try>
GVN: Compare discriminants with constant
2 parents 90b6588 + 578a817 commit 37bc379

File tree

1 file changed

+113
-21
lines changed
  • compiler/rustc_mir_transform/src

1 file changed

+113
-21
lines changed

‎compiler/rustc_mir_transform/src/gvn.rs‎

Lines changed: 113 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,12 @@
8585
//! that contain `AllocId`s.
8686
8787
use std::borrow::Cow;
88+
use std::cmp::Ordering;
8889
use std::hash::{Hash, Hasher};
8990

9091
use either::Either;
9192
use hashbrown::hash_table::{Entry, HashTable};
92-
use itertools::Itertools as _;
93+
use itertools::{Itertools as _, MinMaxResult};
9394
use rustc_abi::{self as abi, BackendRepr, FIRST_VARIANT, FieldIdx, Primitive, Size, VariantIdx};
9495
use rustc_arena::DroplessArena;
9596
use rustc_const_eval::const_eval::DummyMachine;
@@ -107,6 +108,7 @@ use rustc_middle::mir::interpret::GlobalAlloc;
107108
use rustc_middle::mir::visit::*;
108109
use rustc_middle::mir::*;
109110
use rustc_middle::ty::layout::HasTypingEnv;
111+
use rustc_middle::ty::util::IntTypeExt;
110112
use rustc_middle::ty::{self, Ty, TyCtxt};
111113
use rustc_span::DUMMY_SP;
112114
use smallvec::SmallVec;
@@ -1367,17 +1369,18 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
13671369
}
13681370
}
13691371

1370-
if let Some(value) = self.simplify_binary_inner(op, lhs_ty, lhs, rhs) {
1372+
let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs));
1373+
if let Some(value) = self.simplify_binary_inner(op, ty, lhs_ty, lhs, rhs) {
13711374
return Some(value);
13721375
}
1373-
let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs));
13741376
let value = Value::BinaryOp(op, lhs, rhs);
13751377
Some(self.insert(ty, value))
13761378
}
13771379

13781380
fn simplify_binary_inner(
13791381
&mut self,
13801382
op: BinOp,
1383+
ty: Ty<'tcx>,
13811384
lhs_ty: Ty<'tcx>,
13821385
lhs: VnIndex,
13831386
rhs: VnIndex,
@@ -1403,9 +1406,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14031406
};
14041407

14051408
// Represent the values as `Left(bits)` or `Right(VnIndex)`.
1406-
use Either::{Left, Right};
1407-
let a = as_bits(lhs).map_or(Right(lhs), Left);
1408-
let b = as_bits(rhs).map_or(Right(rhs), Left);
1409+
use BitsOrIndex::*;
1410+
let a = as_bits(lhs).map_or(Value(lhs), Bits);
1411+
let b = as_bits(rhs).map_or(Value(rhs), Bits);
14091412

14101413
let result = match (op, a, b) {
14111414
// Neutral elements.
@@ -1415,8 +1418,8 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14151418
| BinOp::AddUnchecked
14161419
| BinOp::BitOr
14171420
| BinOp::BitXor,
1418-
Left(0),
1419-
Right(p),
1421+
Bits(0),
1422+
Value(p),
14201423
)
14211424
| (
14221425
BinOp::Add
@@ -1430,17 +1433,17 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14301433
| BinOp::Offset
14311434
| BinOp::Shl
14321435
| BinOp::Shr,
1433-
Right(p),
1434-
Left(0),
1436+
Value(p),
1437+
Bits(0),
14351438
)
1436-
| (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Left(1), Right(p))
1439+
| (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Bits(1), Value(p))
14371440
| (
14381441
BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::Div,
1439-
Right(p),
1440-
Left(1),
1442+
Value(p),
1443+
Bits(1),
14411444
) => p,
14421445
// Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size.
1443-
(BinOp::BitAnd, Right(p), Left(ones)) | (BinOp::BitAnd, Left(ones), Right(p))
1446+
(BinOp::BitAnd, Value(p), Bits(ones)) | (BinOp::BitAnd, Bits(ones), Value(p))
14441447
if ones == layout.size.truncate(u128::MAX)
14451448
|| (layout.ty.is_bool() && ones == 1) =>
14461449
{
@@ -1450,9 +1453,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14501453
(
14511454
BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::BitAnd,
14521455
_,
1453-
Left(0),
1456+
Bits(0),
14541457
)
1455-
| (BinOp::Rem, _, Left(1))
1458+
| (BinOp::Rem, _, Bits(1))
14561459
| (
14571460
BinOp::Mul
14581461
| BinOp::MulWithOverflow
@@ -1462,11 +1465,11 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14621465
| BinOp::BitAnd
14631466
| BinOp::Shl
14641467
| BinOp::Shr,
1465-
Left(0),
1468+
Bits(0),
14661469
_,
14671470
) => self.insert_scalar(lhs_ty, Scalar::from_uint(0u128, layout.size)),
14681471
// Attempt to simplify `x | ALL_ONES` to `ALL_ONES`.
1469-
(BinOp::BitOr, _, Left(ones)) | (BinOp::BitOr, Left(ones), _)
1472+
(BinOp::BitOr, _, Bits(ones)) | (BinOp::BitOr, Bits(ones), _)
14701473
if ones == layout.size.truncate(u128::MAX)
14711474
|| (layout.ty.is_bool() && ones == 1) =>
14721475
{
@@ -1482,11 +1485,19 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14821485
// - if both operands can be computed as bits, just compare the bits;
14831486
// - if we proved that both operands have the same value, we can insert true/false;
14841487
// - otherwise, do nothing, as we do not try to prove inequality.
1485-
(BinOp::Eq, Left(a), Left(b)) => self.insert_bool(a == b),
1488+
(BinOp::Eq, Bits(a), Bits(b)) => self.insert_bool(a == b),
14861489
(BinOp::Eq, a, b) if a == b => self.insert_bool(true),
1487-
(BinOp::Ne, Left(a), Left(b)) => self.insert_bool(a != b),
1490+
(BinOp::Ne, Bits(a), Bits(b)) => self.insert_bool(a != b),
14881491
(BinOp::Ne, a, b) if a == b => self.insert_bool(false),
1489-
_ => return None,
1492+
// If we know the range of the value, we can compare them.
1493+
(BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge | BinOp::Cmp, lhs, rhs)
1494+
if let Some(result) = self.simplify_binary_range(op, lhs_ty, lhs, rhs) =>
1495+
{
1496+
self.insert_scalar(ty, result)
1497+
}
1498+
_ => {
1499+
return None;
1500+
}
14901501
};
14911502

14921503
if op.is_overflowing() {
@@ -1498,6 +1509,81 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14981509
}
14991510
}
15001511

1512+
fn simplify_binary_range(
1513+
&self,
1514+
op: BinOp,
1515+
lhs_ty: Ty<'tcx>,
1516+
lhs: BitsOrIndex,
1517+
rhs: BitsOrIndex,
1518+
) -> Option<Scalar> {
1519+
if !lhs_ty.is_integral() {
1520+
return None;
1521+
}
1522+
let layout = self.ecx.layout_of(lhs_ty).ok()?;
1523+
let range = |val: BitsOrIndex| match val {
1524+
BitsOrIndex::Bits(bits) => {
1525+
let value = ImmTy::from_uint(bits, layout);
1526+
Some(Either::Left(value))
1527+
}
1528+
BitsOrIndex::Value(value) => {
1529+
let value = if let Value::Cast { kind: CastKind::IntToInt, value } = self.get(value)
1530+
{
1531+
value
1532+
} else {
1533+
value
1534+
};
1535+
let Value::Discriminant(discr) = self.get(value) else {
1536+
return None;
1537+
};
1538+
let ty::Adt(adt, _) = self.ty(discr).kind() else {
1539+
return None;
1540+
};
1541+
if !adt.is_enum() {
1542+
return None;
1543+
}
1544+
let discr_ty = adt.repr().discr_type().to_ty(self.tcx);
1545+
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
1546+
let MinMaxResult::MinMax(min, max) = adt
1547+
.discriminants(self.tcx)
1548+
.map(|(_, discr)| {
1549+
let val = ImmTy::from_uint(discr.val, discr_layout);
1550+
let val = self.ecx.int_to_int_or_float(&val, layout).discard_err().unwrap();
1551+
val
1552+
})
1553+
.minmax_by(|x, y| {
1554+
let cmp = self.ecx.binary_op(BinOp::Cmp, x, y).unwrap();
1555+
let cmp = cmp.to_scalar_int().unwrap().to_i8();
1556+
match cmp {
1557+
-1 => Ordering::Less,
1558+
0 => Ordering::Equal,
1559+
1 => Ordering::Greater,
1560+
_ => unreachable!(),
1561+
}
1562+
})
1563+
else {
1564+
return None;
1565+
};
1566+
Some(Either::Right((min, max)))
1567+
}
1568+
};
1569+
let lhs = range(lhs)?;
1570+
let rhs = range(rhs)?;
1571+
match (lhs, rhs) {
1572+
(Either::Left(lhs), Either::Right((rhs_min, rhs_max))) => {
1573+
let cmp_min = self.ecx.binary_op(op, &lhs, &rhs_min).discard_err()?.to_scalar();
1574+
let cmp_max = self.ecx.binary_op(op, &lhs, &rhs_max).discard_err()?.to_scalar();
1575+
if cmp_min == cmp_max { Some(cmp_min) } else { None }
1576+
}
1577+
(Either::Right((lhs_min, lhs_max)), Either::Left(rhs)) => {
1578+
let cmp_min = self.ecx.binary_op(op, &lhs_min, &rhs).discard_err()?.to_scalar();
1579+
let cmp_max = self.ecx.binary_op(op, &lhs_max, &rhs).discard_err()?.to_scalar();
1580+
if cmp_min == cmp_max { Some(cmp_min) } else { None }
1581+
}
1582+
(Either::Left(_), Either::Left(_)) => None,
1583+
(Either::Right(_), Either::Right(_)) => None,
1584+
}
1585+
}
1586+
15011587
fn simplify_cast(
15021588
&mut self,
15031589
initial_kind: &mut CastKind,
@@ -1960,3 +2046,9 @@ impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> {
19602046
}
19612047
}
19622048
}
2049+
2050+
#[derive(Debug, PartialEq, Clone, Copy)]
2051+
enum BitsOrIndex {
2052+
Bits(u128),
2053+
Value(VnIndex),
2054+
}

0 commit comments

Comments
 (0)