Skip to content

Commit

Permalink
Auto merge of rust-lang#134424 - 1c3t3a:null-checks, r=saethlin
Browse files Browse the repository at this point in the history
Insert null checks for pointer dereferences when debug assertions are enabled

Similar to how the alignment is already checked, this adds a check
for null pointer dereferences in debug mode. It is implemented similarly
to the alignment check as a `MirPass`.

This inserts checks in the same places as the `CheckAlignment` pass and additionally
inserts checks for `Borrows`, so code like
```rust
let ptr: *const u32 = std::ptr::null();
let val: &u32 = unsafe { &*ptr };
```
will have a check inserted on dereference. This is done because null references
are UB. The alignment check doesn't cover these places, because in `&(*ptr).field`,
the exact requirement is that the final reference must be aligned. This is something to
consider for further enhancements of the alignment check.

For now this is implemented as a separate `MirPass`, to make it easy to disable
this check if necessary.

This is related to a 2025H1 project goal for better UB checks in debug
mode: rust-lang/rust-project-goals#177.

r? `@saethlin`
  • Loading branch information
bors committed Jan 31, 2025
2 parents 6c1d960 + 67f4824 commit f89ffc1
Show file tree
Hide file tree
Showing 32 changed files with 551 additions and 165 deletions.
10 changes: 10 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,16 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) {
Some(source_info.span),
);
}
AssertKind::NullPointerDereference => {
let location = fx.get_caller_location(source_info).load_scalar(fx);

codegen_panic_inner(
fx,
rustc_hir::LangItem::PanicNullPointerDereference,
&[location],
Some(source_info.span),
)
}
_ => {
let location = fx.get_caller_location(source_info).load_scalar(fx);

Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,11 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicMisalignedPointerDereference, vec![required, found, location])
}
AssertKind::NullPointerDereference => {
// It's `fn panic_null_pointer_dereference()`,
// `#[track_caller]` adds an implicit argument.
(LangItem::PanicNullPointerDereference, vec![location])
}
_ => {
// It's `pub fn panic_...()` and `#[track_caller]` adds an implicit argument.
(msg.panic_function(), vec![location])
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ impl<'tcx> interpret::Machine<'tcx> for CompileTimeMachine<'tcx> {
found: eval_to_int(found)?,
}
}
NullPointerDereference => NullPointerDereference,
};
Err(ConstEvalErrKind::AssertFailure(err)).into()
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ language_item_table! {
PanicAsyncFnResumedPanic, sym::panic_const_async_fn_resumed_panic, panic_const_async_fn_resumed_panic, Target::Fn, GenericRequirement::None;
PanicAsyncGenFnResumedPanic, sym::panic_const_async_gen_fn_resumed_panic, panic_const_async_gen_fn_resumed_panic, Target::Fn, GenericRequirement::None;
PanicGenFnNonePanic, sym::panic_const_gen_fn_none_panic, panic_const_gen_fn_none_panic, Target::Fn, GenericRequirement::None;
PanicNullPointerDereference, sym::panic_null_pointer_dereference, panic_null_pointer_dereference, Target::Fn, GenericRequirement::None;
/// libstd panic entry point. Necessary for const eval to be able to catch it
BeginPanic, sym::begin_panic, begin_panic_fn, Target::Fn, GenericRequirement::None;

Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_middle/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ middle_assert_gen_resume_after_panic = `gen` fn or block cannot be further itera
middle_assert_misaligned_ptr_deref =
misaligned pointer dereference: address must be a multiple of {$required} but is {$found}
middle_assert_null_ptr_deref =
null pointer dereference occurred
middle_assert_op_overflow =
attempt to compute `{$left} {$op} {$right}`, which would overflow
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,7 @@ pub enum AssertKind<O> {
ResumedAfterReturn(CoroutineKind),
ResumedAfterPanic(CoroutineKind),
MisalignedPointerDereference { required: O, found: O },
NullPointerDereference,
}

#[derive(Clone, Debug, PartialEq, TyEncodable, TyDecodable, Hash, HashStable)]
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ impl<O> AssertKind<O> {
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => {
LangItem::PanicGenFnNonePanic
}
NullPointerDereference => LangItem::PanicNullPointerDereference,

BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
bug!("Unexpected AssertKind")
Expand Down Expand Up @@ -271,6 +272,7 @@ impl<O> AssertKind<O> {
"\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\", {required:?}, {found:?}"
)
}
NullPointerDereference => write!(f, "\"null pointer dereference occured\""),
ResumedAfterReturn(CoroutineKind::Coroutine(_)) => {
write!(f, "\"coroutine resumed after completion\"")
}
Expand Down Expand Up @@ -341,7 +343,7 @@ impl<O> AssertKind<O> {
ResumedAfterPanic(CoroutineKind::Coroutine(_)) => {
middle_assert_coroutine_resume_after_panic
}

NullPointerDereference => middle_assert_null_ptr_deref,
MisalignedPointerDereference { .. } => middle_assert_misaligned_ptr_deref,
}
}
Expand Down Expand Up @@ -374,7 +376,7 @@ impl<O> AssertKind<O> {
add!("left", format!("{left:#?}"));
add!("right", format!("{right:#?}"));
}
ResumedAfterReturn(_) | ResumedAfterPanic(_) => {}
ResumedAfterReturn(_) | ResumedAfterPanic(_) | NullPointerDereference => {}
MisalignedPointerDereference { required, found } => {
add!("required", format!("{required:#?}"));
add!("found", format!("{found:#?}"));
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ macro_rules! make_mir_visitor {
OverflowNeg(op) | DivisionByZero(op) | RemainderByZero(op) => {
self.visit_operand(op, location);
}
ResumedAfterReturn(_) | ResumedAfterPanic(_) => {
ResumedAfterReturn(_) | ResumedAfterPanic(_) | NullPointerDereference => {
// Nothing to visit
}
MisalignedPointerDereference { required, found } => {
Expand Down
198 changes: 38 additions & 160 deletions compiler/rustc_mir_transform/src/check_alignment.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use rustc_hir::lang_items::LangItem;
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{Ty, TyCtxt};
use rustc_session::Session;
use tracing::{debug, trace};

use crate::check_pointers::{BorrowCheckMode, PointerCheck, check_pointers};

pub(super) struct CheckAlignment;

Expand All @@ -19,166 +18,53 @@ impl<'tcx> crate::MirPass<'tcx> for CheckAlignment {
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// This pass emits new panics. If for whatever reason we do not have a panic
// implementation, running this pass may cause otherwise-valid code to not compile.
if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
return;
}

let typing_env = body.typing_env(tcx);
let basic_blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;

// This pass inserts new blocks. Each insertion changes the Location for all
// statements/blocks after. Iterating or visiting the MIR in order would require updating
// our current location after every insertion. By iterating backwards, we dodge this issue:
// The only Locations that an insertion changes have already been handled.
for block in (0..basic_blocks.len()).rev() {
let block = block.into();
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
let location = Location { block, statement_index };
let statement = &basic_blocks[block].statements[statement_index];
let source_info = statement.source_info;

let mut finder =
PointerFinder { tcx, local_decls, typing_env, pointers: Vec::new() };
finder.visit_statement(statement, location);

for (local, ty) in finder.pointers {
debug!("Inserting alignment check for {:?}", ty);
let new_block = split_block(basic_blocks, location);
insert_alignment_check(
tcx,
local_decls,
&mut basic_blocks[block],
local,
ty,
source_info,
new_block,
);
}
}
}
// Skip trivially aligned place types.
let excluded_pointees = [tcx.types.bool, tcx.types.i8, tcx.types.u8];

// We have to exclude borrows here: in `&x.field`, the exact
// requirement is that the final reference must be aligned, but
// `check_pointers` would check that `x` is aligned, which would be wrong.
check_pointers(
tcx,
body,
&excluded_pointees,
insert_alignment_check,
BorrowCheckMode::ExcludeBorrows,
);
}

fn is_required(&self) -> bool {
true
}
}

struct PointerFinder<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
local_decls: &'a mut LocalDecls<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
}

impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
// We want to only check reads and writes to Places, so we specifically exclude
// Borrow and RawBorrow.
match context {
PlaceContext::MutatingUse(
MutatingUseContext::Store
| MutatingUseContext::AsmOutput
| MutatingUseContext::Call
| MutatingUseContext::Yield
| MutatingUseContext::Drop,
) => {}
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
) => {}
_ => {
return;
}
}

if !place.is_indirect() {
return;
}

// Since Deref projections must come first and only once, the pointer for an indirect place
// is the Local that the Place is based on.
let pointer = Place::from(place.local);
let pointer_ty = self.local_decls[place.local].ty;

// We only want to check places based on unsafe pointers
if !pointer_ty.is_unsafe_ptr() {
trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place);
return;
}

let pointee_ty =
pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer");
// Ideally we'd support this in the future, but for now we are limited to sized types.
if !pointee_ty.is_sized(self.tcx, self.typing_env) {
debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty);
return;
}

// Try to detect types we are sure have an alignment of 1 and skip the check
// We don't need to look for str and slices, we already rejected unsized types above
let element_ty = match pointee_ty.kind() {
ty::Array(ty, _) => *ty,
_ => pointee_ty,
};
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) {
debug!("Trivially aligned place type: {:?}", pointee_ty);
return;
}

// Ensure that this place is based on an aligned pointer.
self.pointers.push((pointer, pointee_ty));

self.super_place(place, context, location);
}
}

fn split_block(
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
location: Location,
) -> BasicBlock {
let block_data = &mut basic_blocks[location.block];

// Drain every statement after this one and move the current terminator to a new basic block
let new_block = BasicBlockData {
statements: block_data.statements.split_off(location.statement_index),
terminator: block_data.terminator.take(),
is_cleanup: block_data.is_cleanup,
};

basic_blocks.push(new_block)
}

/// Inserts the actual alignment check's logic. Returns a
/// [AssertKind::MisalignedPointerDereference] on failure.
fn insert_alignment_check<'tcx>(
tcx: TyCtxt<'tcx>,
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
block_data: &mut BasicBlockData<'tcx>,
pointer: Place<'tcx>,
pointee_ty: Ty<'tcx>,
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
stmts: &mut Vec<Statement<'tcx>>,
source_info: SourceInfo,
new_block: BasicBlock,
) {
// Cast the pointer to a *const ()
) -> PointerCheck<'tcx> {
// Cast the pointer to a *const ().
let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit);
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
block_data
.statements
stmts
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });

// Transmute the pointer to a usize (equivalent to `ptr.addr()`)
// Transmute the pointer to a usize (equivalent to `ptr.addr()`).
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
block_data
.statements
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });

// Get the alignment of the pointee
let alignment =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty);
block_data.statements.push(Statement {
stmts.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((alignment, rvalue))),
});
Expand All @@ -191,7 +77,7 @@ fn insert_alignment_check<'tcx>(
user_ty: None,
const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), tcx.types.usize),
}));
block_data.statements.push(Statement {
stmts.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
alignment_mask,
Expand All @@ -202,7 +88,7 @@ fn insert_alignment_check<'tcx>(
// BitAnd the alignment mask with the pointer
let alignment_bits =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
block_data.statements.push(Statement {
stmts.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
alignment_bits,
Expand All @@ -220,29 +106,21 @@ fn insert_alignment_check<'tcx>(
user_ty: None,
const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
}));
block_data.statements.push(Statement {
stmts.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
is_ok,
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
))),
});

// Set this block's terminator to our assert, continuing to new_block if we pass
block_data.terminator = Some(Terminator {
source_info,
kind: TerminatorKind::Assert {
cond: Operand::Copy(is_ok),
expected: true,
target: new_block,
msg: Box::new(AssertKind::MisalignedPointerDereference {
required: Operand::Copy(alignment),
found: Operand::Copy(addr),
}),
// This calls panic_misaligned_pointer_dereference, which is #[rustc_nounwind].
// We never want to insert an unwind into unsafe code, because unwinding could
// make a failing UB check turn into much worse UB when we start unwinding.
unwind: UnwindAction::Unreachable,
},
});
// Emit a check that asserts on the alignment and otherwise triggers a
// AssertKind::MisalignedPointerDereference.
PointerCheck {
cond: Operand::Copy(is_ok),
assert_kind: Box::new(AssertKind::MisalignedPointerDereference {
required: Operand::Copy(alignment),
found: Operand::Copy(addr),
}),
}
}
Loading

0 comments on commit f89ffc1

Please sign in to comment.