8585//! that contain `AllocId`s.
8686
8787use std:: borrow:: Cow ;
88+ use std:: cmp:: Ordering ;
8889use std:: hash:: { Hash , Hasher } ;
8990
9091use either:: Either ;
9192use hashbrown:: hash_table:: { Entry , HashTable } ;
92- use itertools:: Itertools as _;
93+ use itertools:: { Itertools as _, MinMaxResult } ;
9394use rustc_abi:: { self as abi, BackendRepr , FIRST_VARIANT , FieldIdx , Primitive , Size , VariantIdx } ;
9495use rustc_arena:: DroplessArena ;
9596use rustc_const_eval:: const_eval:: DummyMachine ;
@@ -107,6 +108,7 @@ use rustc_middle::mir::interpret::GlobalAlloc;
107108use rustc_middle:: mir:: visit:: * ;
108109use rustc_middle:: mir:: * ;
109110use rustc_middle:: ty:: layout:: HasTypingEnv ;
111+ use rustc_middle:: ty:: util:: IntTypeExt ;
110112use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
111113use rustc_span:: DUMMY_SP ;
112114use smallvec:: SmallVec ;
@@ -1367,17 +1369,18 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
13671369 }
13681370 }
13691371
1370- if let Some ( value) = self . simplify_binary_inner ( op, lhs_ty, lhs, rhs) {
1372+ let ty = op. ty ( self . tcx , lhs_ty, self . ty ( rhs) ) ;
1373+ if let Some ( value) = self . simplify_binary_inner ( op, ty, lhs_ty, lhs, rhs) {
13711374 return Some ( value) ;
13721375 }
1373- let ty = op. ty ( self . tcx , lhs_ty, self . ty ( rhs) ) ;
13741376 let value = Value :: BinaryOp ( op, lhs, rhs) ;
13751377 Some ( self . insert ( ty, value) )
13761378 }
13771379
13781380 fn simplify_binary_inner (
13791381 & mut self ,
13801382 op : BinOp ,
1383+ ty : Ty < ' tcx > ,
13811384 lhs_ty : Ty < ' tcx > ,
13821385 lhs : VnIndex ,
13831386 rhs : VnIndex ,
@@ -1403,9 +1406,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14031406 } ;
14041407
14051408 // Represent the values as `Left(bits)` or `Right(VnIndex)`.
1406- use Either :: { Left , Right } ;
1407- let a = as_bits ( lhs) . map_or ( Right ( lhs) , Left ) ;
1408- let b = as_bits ( rhs) . map_or ( Right ( rhs) , Left ) ;
1409+ use BitsOrIndex :: * ;
1410+ let a = as_bits ( lhs) . map_or ( Value ( lhs) , Bits ) ;
1411+ let b = as_bits ( rhs) . map_or ( Value ( rhs) , Bits ) ;
14091412
14101413 let result = match ( op, a, b) {
14111414 // Neutral elements.
@@ -1415,8 +1418,8 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14151418 | BinOp :: AddUnchecked
14161419 | BinOp :: BitOr
14171420 | BinOp :: BitXor ,
1418- Left ( 0 ) ,
1419- Right ( p) ,
1421+ Bits ( 0 ) ,
1422+ Value ( p) ,
14201423 )
14211424 | (
14221425 BinOp :: Add
@@ -1430,17 +1433,17 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14301433 | BinOp :: Offset
14311434 | BinOp :: Shl
14321435 | BinOp :: Shr ,
1433- Right ( p) ,
1434- Left ( 0 ) ,
1436+ Value ( p) ,
1437+ Bits ( 0 ) ,
14351438 )
1436- | ( BinOp :: Mul | BinOp :: MulWithOverflow | BinOp :: MulUnchecked , Left ( 1 ) , Right ( p) )
1439+ | ( BinOp :: Mul | BinOp :: MulWithOverflow | BinOp :: MulUnchecked , Bits ( 1 ) , Value ( p) )
14371440 | (
14381441 BinOp :: Mul | BinOp :: MulWithOverflow | BinOp :: MulUnchecked | BinOp :: Div ,
1439- Right ( p) ,
1440- Left ( 1 ) ,
1442+ Value ( p) ,
1443+ Bits ( 1 ) ,
14411444 ) => p,
14421445 // Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size.
1443- ( BinOp :: BitAnd , Right ( p) , Left ( ones) ) | ( BinOp :: BitAnd , Left ( ones) , Right ( p) )
1446+ ( BinOp :: BitAnd , Value ( p) , Bits ( ones) ) | ( BinOp :: BitAnd , Bits ( ones) , Value ( p) )
14441447 if ones == layout. size . truncate ( u128:: MAX )
14451448 || ( layout. ty . is_bool ( ) && ones == 1 ) =>
14461449 {
@@ -1450,9 +1453,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14501453 (
14511454 BinOp :: Mul | BinOp :: MulWithOverflow | BinOp :: MulUnchecked | BinOp :: BitAnd ,
14521455 _,
1453- Left ( 0 ) ,
1456+ Bits ( 0 ) ,
14541457 )
1455- | ( BinOp :: Rem , _, Left ( 1 ) )
1458+ | ( BinOp :: Rem , _, Bits ( 1 ) )
14561459 | (
14571460 BinOp :: Mul
14581461 | BinOp :: MulWithOverflow
@@ -1462,11 +1465,11 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14621465 | BinOp :: BitAnd
14631466 | BinOp :: Shl
14641467 | BinOp :: Shr ,
1465- Left ( 0 ) ,
1468+ Bits ( 0 ) ,
14661469 _,
14671470 ) => self . insert_scalar ( lhs_ty, Scalar :: from_uint ( 0u128 , layout. size ) ) ,
14681471 // Attempt to simplify `x | ALL_ONES` to `ALL_ONES`.
1469- ( BinOp :: BitOr , _, Left ( ones) ) | ( BinOp :: BitOr , Left ( ones) , _)
1472+ ( BinOp :: BitOr , _, Bits ( ones) ) | ( BinOp :: BitOr , Bits ( ones) , _)
14701473 if ones == layout. size . truncate ( u128:: MAX )
14711474 || ( layout. ty . is_bool ( ) && ones == 1 ) =>
14721475 {
@@ -1482,11 +1485,19 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14821485 // - if both operands can be computed as bits, just compare the bits;
14831486 // - if we proved that both operands have the same value, we can insert true/false;
14841487 // - otherwise, do nothing, as we do not try to prove inequality.
1485- ( BinOp :: Eq , Left ( a) , Left ( b) ) => self . insert_bool ( a == b) ,
1488+ ( BinOp :: Eq , Bits ( a) , Bits ( b) ) => self . insert_bool ( a == b) ,
14861489 ( BinOp :: Eq , a, b) if a == b => self . insert_bool ( true ) ,
1487- ( BinOp :: Ne , Left ( a) , Left ( b) ) => self . insert_bool ( a != b) ,
1490+ ( BinOp :: Ne , Bits ( a) , Bits ( b) ) => self . insert_bool ( a != b) ,
14881491 ( BinOp :: Ne , a, b) if a == b => self . insert_bool ( false ) ,
1489- _ => return None ,
1492+ // If we know the range of the value, we can compare them.
1493+ ( BinOp :: Lt | BinOp :: Le | BinOp :: Gt | BinOp :: Ge | BinOp :: Cmp , lhs, rhs)
1494+ if let Some ( result) = self . simplify_binary_range ( op, lhs_ty, lhs, rhs) =>
1495+ {
1496+ self . insert_scalar ( ty, result)
1497+ }
1498+ _ => {
1499+ return None ;
1500+ }
14901501 } ;
14911502
14921503 if op. is_overflowing ( ) {
@@ -1498,6 +1509,81 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
14981509 }
14991510 }
15001511
1512+ fn simplify_binary_range (
1513+ & self ,
1514+ op : BinOp ,
1515+ lhs_ty : Ty < ' tcx > ,
1516+ lhs : BitsOrIndex ,
1517+ rhs : BitsOrIndex ,
1518+ ) -> Option < Scalar > {
1519+ if !lhs_ty. is_integral ( ) {
1520+ return None ;
1521+ }
1522+ let layout = self . ecx . layout_of ( lhs_ty) . ok ( ) ?;
1523+ let range = |val : BitsOrIndex | match val {
1524+ BitsOrIndex :: Bits ( bits) => {
1525+ let value = ImmTy :: from_uint ( bits, layout) ;
1526+ Some ( Either :: Left ( value) )
1527+ }
1528+ BitsOrIndex :: Value ( value) => {
1529+ let value = if let Value :: Cast { kind : CastKind :: IntToInt , value } = self . get ( value)
1530+ {
1531+ value
1532+ } else {
1533+ value
1534+ } ;
1535+ let Value :: Discriminant ( discr) = self . get ( value) else {
1536+ return None ;
1537+ } ;
1538+ let ty:: Adt ( adt, _) = self . ty ( discr) . kind ( ) else {
1539+ return None ;
1540+ } ;
1541+ if !adt. is_enum ( ) {
1542+ return None ;
1543+ }
1544+ let discr_ty = adt. repr ( ) . discr_type ( ) . to_ty ( self . tcx ) ;
1545+ let discr_layout = self . ecx . layout_of ( discr_ty) . ok ( ) ?;
1546+ let MinMaxResult :: MinMax ( min, max) = adt
1547+ . discriminants ( self . tcx )
1548+ . map ( |( _, discr) | {
1549+ let val = ImmTy :: from_uint ( discr. val , discr_layout) ;
1550+ let val = self . ecx . int_to_int_or_float ( & val, layout) . discard_err ( ) . unwrap ( ) ;
1551+ val
1552+ } )
1553+ . minmax_by ( |x, y| {
1554+ let cmp = self . ecx . binary_op ( BinOp :: Cmp , x, y) . unwrap ( ) ;
1555+ let cmp = cmp. to_scalar_int ( ) . unwrap ( ) . to_i8 ( ) ;
1556+ match cmp {
1557+ -1 => Ordering :: Less ,
1558+ 0 => Ordering :: Equal ,
1559+ 1 => Ordering :: Greater ,
1560+ _ => unreachable ! ( ) ,
1561+ }
1562+ } )
1563+ else {
1564+ return None ;
1565+ } ;
1566+ Some ( Either :: Right ( ( min, max) ) )
1567+ }
1568+ } ;
1569+ let lhs = range ( lhs) ?;
1570+ let rhs = range ( rhs) ?;
1571+ match ( lhs, rhs) {
1572+ ( Either :: Left ( lhs) , Either :: Right ( ( rhs_min, rhs_max) ) ) => {
1573+ let cmp_min = self . ecx . binary_op ( op, & lhs, & rhs_min) . discard_err ( ) ?. to_scalar ( ) ;
1574+ let cmp_max = self . ecx . binary_op ( op, & lhs, & rhs_max) . discard_err ( ) ?. to_scalar ( ) ;
1575+ if cmp_min == cmp_max { Some ( cmp_min) } else { None }
1576+ }
1577+ ( Either :: Right ( ( lhs_min, lhs_max) ) , Either :: Left ( rhs) ) => {
1578+ let cmp_min = self . ecx . binary_op ( op, & lhs_min, & rhs) . discard_err ( ) ?. to_scalar ( ) ;
1579+ let cmp_max = self . ecx . binary_op ( op, & lhs_max, & rhs) . discard_err ( ) ?. to_scalar ( ) ;
1580+ if cmp_min == cmp_max { Some ( cmp_min) } else { None }
1581+ }
1582+ ( Either :: Left ( _) , Either :: Left ( _) ) => None ,
1583+ ( Either :: Right ( _) , Either :: Right ( _) ) => None ,
1584+ }
1585+ }
1586+
15011587 fn simplify_cast (
15021588 & mut self ,
15031589 initial_kind : & mut CastKind ,
@@ -1960,3 +2046,9 @@ impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> {
19602046 }
19612047 }
19622048}
2049+
2050+ #[ derive( Debug , PartialEq , Clone , Copy ) ]
2051+ enum BitsOrIndex {
2052+ Bits ( u128 ) ,
2053+ Value ( VnIndex ) ,
2054+ }
0 commit comments