diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 5fd1eb891a..9da4873957 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -10,6 +10,7 @@ use super::{get_name, get_names}; use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand}; use rspirv::spirv::{FunctionControl, Op, StorageClass, Word}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_errors::ErrorGuaranteed; use rustc_session::Session; use std::mem::take; @@ -17,9 +18,8 @@ type FunctionMap = FxHashMap; pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion - if module_has_recursion(sess, module) { - return Err(rustc_errors::ErrorReported); - } + deny_recursion_in_module(sess, module)?; + let functions = module .functions .iter() @@ -52,7 +52,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { let names = get_names(module); for f in inlined_dont_inlines { sess.warn(&format!( - "Function `{}` has `dont_inline` attribute, but need to be inlined because it has illegal argument or return types", + "function `{}` has `dont_inline` attribute, but need to be inlined because it has illegal argument or return types", get_name(&names, f) )); } @@ -81,7 +81,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { } // https://stackoverflow.com/a/53995651 -fn module_has_recursion(sess: &Session, module: &Module) -> bool { +fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> { let func_to_index: FxHashMap = module .functions .iter() @@ -90,7 +90,7 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool { .collect(); let mut discovered = vec![false; module.functions.len()]; let mut finished = vec![false; module.functions.len()]; - let mut has_recursion = false; + let mut has_recursion = None; for index in 0..module.functions.len() { if !discovered[index] && !finished[index] { visit( @@ -111,7 +111,7 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool { current: usize, discovered: &mut Vec, finished: &mut Vec, - has_recursion: &mut bool, + has_recursion: &mut Option, func_to_index: &FxHashMap, ) { discovered[current] = true; @@ -121,11 +121,10 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool { let names = get_names(module); let current_name = get_name(&names, module.functions[current].def_id().unwrap()); let next_name = get_name(&names, module.functions[next].def_id().unwrap()); - sess.err(&format!( + *has_recursion = Some(sess.err(&format!( "module has recursion, which is not allowed: `{}` calls `{}`", current_name, next_name - )); - *has_recursion = true; + ))); break; } @@ -159,7 +158,10 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool { }) } - has_recursion + match has_recursion { + Some(err) => Err(err), + None => Ok(()), + } } fn compute_disallowed_argument_and_return_types( diff --git a/crates/rustc_codegen_spirv/src/linker/inline_globals.rs b/crates/rustc_codegen_spirv/src/linker/inline_globals.rs index 707e765901..ab712ba11e 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline_globals.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline_globals.rs @@ -1,5 +1,5 @@ use rspirv::dr::{Instruction, Module, Operand}; -use rspirv::spirv::{Op}; +use rspirv::spirv::Op; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_session::Session; @@ -47,17 +47,17 @@ impl NormalizedInstructions { bound: &mut u32, new_root: u32, ) { - for op in &mut inst.operands { - match op { - Operand::IdRef(id) => match id_map.get(id) { - Some(new_id) => { - *id = *new_id; - } - _ => {} - }, + for op in &mut inst.operands { + match op { + Operand::IdRef(id) => match id_map.get(id) { + Some(new_id) => { + *id = *new_id; + } _ => {} - } + }, + _ => {} } + } if let Some(id) = &mut inst.result_id { if *id != root { id_map.insert(*id, *bound); @@ -144,7 +144,13 @@ fn inline_global_varaibles_rec(module: &mut Module) -> super::Result { match &inst.operands[i] { &Operand::IdRef(w) => match &function_args.get(&key) { None => { - match get_const_arg_insts(bound, &variables, &insts, &ref_stores, w) { + match get_const_arg_insts( + bound, + &variables, + &insts, + &ref_stores, + w, + ) { Some(insts) => { is_invalid = false; function_args.insert(key, FunctionArg::Insts(insts)); @@ -153,8 +159,13 @@ fn inline_global_varaibles_rec(module: &mut Module) -> super::Result { } } Some(FunctionArg::Insts(w2)) => { - let new_insts = - get_const_arg_insts(bound, &variables, &insts, &ref_stores, w); + let new_insts = get_const_arg_insts( + bound, + &variables, + &insts, + &ref_stores, + w, + ); match new_insts { Some(new_insts) => { is_invalid = new_insts != *w2; diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 5fd0ad0f05..d9f896ac0b 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -6,8 +6,8 @@ mod destructure_composites; mod duplicates; mod entry_interface; mod import_export_link; -mod inline_globals; mod inline; +mod inline_globals; mod ipo; mod mem2reg; mod param_weakening; @@ -153,7 +153,6 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result, opts: &Options) -> Result