Skip to content
3 changes: 3 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ impl TypeTree {
}
Self(ints)
}
/// Wraps this tree in one level of pointer indirection.
///
/// Consumes `self` and returns a new tree whose only entry is a
/// `Kind::Pointer` at offset 0 with size 1, carrying the original tree as
/// its `child`.
pub fn add_indirection(self) -> Self {
    let pointer = Type { offset: 0, size: 1, kind: Kind::Pointer, child: self };
    Self(vec![pointer])
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ trait ArgAttributesExt {
callsite: &Value,
);
}
use crate::abi::ty::print::with_no_trimmed_paths;
use rustc_codegen_ssa::mir::operand::scalar_pair_component_field_ty;

const ABI_AFFECTING_ATTRIBUTES: [(ArgAttribute, llvm::AttributeKind); 1] =
[(ArgAttribute::InReg, llvm::AttributeKind::InReg)];
Expand Down Expand Up @@ -262,6 +264,16 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
let llscratch = bx.alloca(scratch_size, scratch_align);
bx.lifetime_start(llscratch, scratch_size);
// ...store the value...

let f0 = scalar_pair_component_field_ty(bx, dst.layout, 0);
let f1 = scalar_pair_component_field_ty(bx, dst.layout, 1);

// Debug aid: when both scalar-pair component types are available, dump them
// to stderr before the cast store below. The options are formatted directly;
// mapping each one through a closure that ignored its argument and unwrapped
// the outer option only rebuilt the same `Some(..)` value and left an unused
// closure parameter behind.
if f0.is_some() && f1.is_some() {
    with_no_trimmed_paths!({
        eprintln!("Cast of extractvalue 0 field = {:?}", f0);
        eprintln!("Cast of extractvalue 1 field = {:?}", f1);
    });
}
rustc_codegen_ssa::mir::store_cast(bx, cast, val, llscratch, scratch_align);
// ... and then memcpy it to the intended destination.
bx.memcpy(
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
let value = if output_types.len() == 1 {
result
} else {
self.extract_value(result, op_idx[&idx] as u64)
self.extract_value(result, op_idx[&idx] as u64, None)
};
let value =
llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance);
Expand Down
116 changes: 100 additions & 16 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::iter;
use std::ops::Deref;

use rustc_ast::expand::typetree::FncTree;
use rustc_ast::expand::typetree::{TypeTree, FncTree};
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;

Expand Down Expand Up @@ -38,6 +38,7 @@ use crate::llvm::{
ToLlvmBool, Type, Value,
};
use crate::type_of::LayoutLlvmExt;
use rustc_middle::ty::type_tree::typetree_from_ty;

#[must_use]
pub(crate) struct GenericBuilder<'a, 'll, CX: Borrow<SCx<'ll>>> {
Expand Down Expand Up @@ -181,11 +182,12 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
}

pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
unsafe {
let load = unsafe {
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
load
}
};
load
}
}

Expand Down Expand Up @@ -585,7 +587,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
let name = format!("llvm.{}{oop_str}.with.overflow", if signed { 's' } else { 'u' });

let res = self.call_intrinsic(name, &[self.type_ix(width)], &[lhs, rhs]);
(self.extract_value(res, 0), self.extract_value(res, 1))
(self.extract_value(res, 0, None), self.extract_value(res, 1, None))
}

fn from_immediate(&mut self, val: Self::Value) -> Self::Value {
Expand Down Expand Up @@ -627,13 +629,17 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
unsafe {
fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align, tt: Option<FncTree>) -> &'ll Value {
let load = unsafe {
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
let align = align.min(self.cx().tcx.sess.target.max_reliable_alignment());
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
load
};
if let Some(tt) = tt {
//crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, tt);
}
load
}

fn volatile_load(&mut self, ty: &'ll Type, ptr: &'ll Value) -> &'ll Value {
Expand Down Expand Up @@ -734,6 +740,70 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {

let llval = const_llval.unwrap_or_else(|| {
let load = self.load(llty, place.val.llval, place.val.align);
//let layout = place.layout.ty_and_layout_pointee_info_at(self.cx(), Size::ZERO).unwrap();
let ty = place.layout.ty;
let tt = typetree_from_ty(self.tcx, ty);
if tt != rustc_ast::expand::typetree::TypeTree::new() {
use rustc_middle::ty::print::with_no_trimmed_paths;
//dbg!("add_tt start!");
//dbg!(&load);
//dbg!(&tt);
//eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}")));
let fnc_tree = FncTree {
args: vec![TypeTree::new(), TypeTree::new()],
ret: tt,
};
// TODO: re-enable?
//crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree);
//dbg!("add_tt done!");
}
//eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}")));
// 25 general load of place = PlaceRef {
// 24 val: PlaceValue {
// 23 llval: (ptr: %3 = alloca [8 x i8], align 8),
// 22 llextra: None,
// 21 align: Align(8 bytes),
// 20 },
// 19 layout: TyAndLayout {
// 18 ty: &([f64; 3], [f64; 3]),
// 17 layout: Layout {
// 16 size: Size(8 bytes),
// 15 align: AbiAlign {
// 14 abi: Align(8 bytes),
// 13 },
// 12 backend_repr: Scalar(
// 11 Initialized {
// 10 value: Pointer(
// 9 AddressSpace(
// 8 0,
// 7 ),
// 6 ),
// 5 valid_range: 1..=18446744073709551615,
// 4 },
// 3 ),
// 2 fields: Primitive,
// 1 largest_niche: Some(
// 259 Niche {
// 1 offset: Size(0 bytes),
// 2 value: Pointer(
// 3 AddressSpace(
// 4 0,
// 5 ),
// 6 ),
// 7 valid_range: 1..=18446744073709551615,
// 8 },
// 9 ),
// 10 uninhabited: false,
// 11 variants: Single {
// 12 index: 0,
// 13 },
// 14 max_repr_align: None,
// 15 unadjusted_abi_align: Align(8 bytes),
// 16 randomization_seed: 281492156579847,
// 17 },
// 18 },
// 19 }

if let abi::BackendRepr::Scalar(scalar) = place.layout.backend_repr {
scalar_load_metadata(self, load, scalar, place.layout, Size::ZERO);
self.to_immediate_scalar(load, scalar)
Expand Down Expand Up @@ -1113,7 +1183,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, self.tcx, tt);
}
}

Expand All @@ -1125,11 +1195,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memmove = unsafe {
llvm::LLVMRustBuildMemMove(
self.llbuilder,
dst,
Expand All @@ -1138,7 +1209,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memmove, self.tcx, tt);
}
}

Expand All @@ -1149,18 +1223,22 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
size: &'ll Value,
align: Align,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memset not supported");
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memset = unsafe {
llvm::LLVMRustBuildMemSet(
self.llbuilder,
ptr,
align.bytes() as c_uint,
fill_byte,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memset, self.tcx, tt);
}
}

Expand Down Expand Up @@ -1191,9 +1269,15 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

fn extract_value(&mut self, agg_val: &'ll Value, idx: u64) -> &'ll Value {
fn extract_value(&mut self, agg_val: &'ll Value, idx: u64, tt: Option<FncTree>) -> &'ll Value {
assert_eq!(idx as c_uint as u64, idx);
unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) }
let ev = unsafe {
llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED)
};
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, ev, self.tcx, tt);
}
ev
}

fn insert_value(&mut self, agg_val: &'ll Value, elt: &'ll Value, idx: u64) -> &'ll Value {
Expand All @@ -1213,7 +1297,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
unsafe {
llvm::LLVMSetCleanup(landing_pad, llvm::TRUE);
}
(self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1))
(self.extract_value(landing_pad, 0, None), self.extract_value(landing_pad, 1, None))
}

fn filter_landing_pad(&mut self, pers_fn: &'ll Value) {
Expand Down Expand Up @@ -1314,8 +1398,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
llvm::FALSE, // SingleThreaded
);
llvm::LLVMSetWeak(value, weak.to_llvm_bool());
let val = self.extract_value(value, 0);
let success = self.extract_value(value, 1);
let val = self.extract_value(value, 0, None);
let success = self.extract_value(value, 1, None);
(val, success)
}
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_name: &str,
Expand Down Expand Up @@ -375,7 +376,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
);

if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, tcx, fnc_tree);
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ impl<'ll> OffloadKernelDims<'ll> {
builder: &mut Builder<'_, 'll, 'tcx>,
arr: &'ll Value,
) -> &'ll Value {
let x = builder.extract_value(arr, 0);
let y = builder.extract_value(arr, 1);
let z = builder.extract_value(arr, 2);
let x = builder.extract_value(arr, 0, None);
let y = builder.extract_value(arr, 1, None);
let z = builder.extract_value(arr, 2, None);

let xy = builder.mul(x, y);
builder.mul(xy, z)
Expand Down
Loading