From 13c22af469ec7102402b72693beeb30cb3dd791e Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Thu, 4 Dec 2025 09:34:03 +0530 Subject: [PATCH 01/29] [sway-ir]: Add switch instruction --- .../src/asm_generation/evm/evm_asm_builder.rs | 16 ++++ .../asm_generation/fuel/fuel_asm_builder.rs | 15 ++++ sway-ir/src/analysis/memory_utils.rs | 3 + sway-ir/src/error.rs | 7 ++ sway-ir/src/instruction.rs | 79 +++++++++++++++++++ sway-ir/src/optimize/arg_mutability_tagger.rs | 4 +- sway-ir/src/optimize/cse.rs | 1 + sway-ir/src/optimize/fn_dedup.rs | 11 +++ sway-ir/src/optimize/inline.rs | 26 +++++- sway-ir/src/parser.rs | 55 ++++++++++++- sway-ir/src/printer.rs | 52 ++++++++++++ sway-ir/src/verify.rs | 34 ++++++++ sway-ir/tests/simplify_cfg/switches.ir | 20 +++++ 13 files changed, 319 insertions(+), 4 deletions(-) create mode 100644 sway-ir/tests/simplify_cfg/switches.ir diff --git a/sway-core/src/asm_generation/evm/evm_asm_builder.rs b/sway-core/src/asm_generation/evm/evm_asm_builder.rs index 811dde8d9ed..e6de8532e6e 100644 --- a/sway-core/src/asm_generation/evm/evm_asm_builder.rs +++ b/sway-core/src/asm_generation/evm/evm_asm_builder.rs @@ -319,6 +319,11 @@ impl<'ir, 'eng> EvmAsmBuilder<'ir, 'eng> { } => { self.compile_conditional_branch(handler, cond_value, true_block, false_block)? } + InstOp::Switch { + discriminant, + cases, + default, + } => self.compile_switch(handler, instr_val, discriminant, cases, default)?, InstOp::ContractCall { params, coins, @@ -443,6 +448,17 @@ impl<'ir, 'eng> EvmAsmBuilder<'ir, 'eng> { todo!(); } + fn compile_switch( + &mut self, + handler: &Handler, + instr_val: &Value, + discriminant: &Value, + cases: &[(u64, BranchToWithArgs)], + default: &BranchToWithArgs, + ) -> Result<(), ErrorEmitted> { + todo!(); + } + fn compile_branch_to_phi_value(&mut self, to_block: &BranchToWithArgs) { todo!(); } diff --git a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs index 0d9ca7af264..aa50aa40e5c 100644 --- a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs +++ b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs @@ -384,6 +384,11 @@ impl<'ir, 'eng> FuelAsmBuilder<'ir, 'eng> { true_block, false_block, } => self.compile_conditional_branch(cond_value, true_block, false_block), + InstOp::Switch { + discriminant, + cases, + default, + } => self.compile_switch(instr_val, discriminant, cases, default), InstOp::ContractCall { params, coins, @@ -1042,6 +1047,16 @@ impl<'ir, 'eng> FuelAsmBuilder<'ir, 'eng> { Ok(()) } + fn compile_switch( + &mut self, + _instr_val: &Value, + _discriminant: &Value, + _cases: &[(u64, BranchToWithArgs)], + _default: &BranchToWithArgs, + ) -> Result<(), CompileError> { + todo!() + } + fn compile_branch_to_phi_value( &mut self, to_block: &BranchToWithArgs, diff --git a/sway-ir/src/analysis/memory_utils.rs b/sway-ir/src/analysis/memory_utils.rs index c30a7152b7b..d7fa8c37505 100644 --- a/sway-ir/src/analysis/memory_utils.rs +++ b/sway-ir/src/analysis/memory_utils.rs @@ -433,6 +433,7 @@ fn compute_escaped_symbols(context: &Context, function: &Function) -> EscapedSym InstOp::CastPtr(ptr, _) => add_from_val(&mut result, ptr, &mut is_complete), InstOp::Cmp(_, _, _) => (), InstOp::ConditionalBranch { .. } => (), + InstOp::Switch { .. } => (), InstOp::ContractCall { params, .. } => { add_from_val(&mut result, params, &mut is_complete) } @@ -472,6 +473,7 @@ pub fn get_loaded_ptr_values(context: &Context, inst: Value) -> Vec { | InstOp::BitCast(_, _) | InstOp::Branch(_) | InstOp::ConditionalBranch { .. } + | InstOp::Switch { .. } | InstOp::Cmp(_, _, _) | InstOp::Nop | InstOp::CastPtr(_, _) @@ -561,6 +563,7 @@ pub fn get_stored_ptr_values(context: &Context, inst: Value) -> Vec { | InstOp::BitCast(_, _) | InstOp::Branch(_) | InstOp::ConditionalBranch { .. } + | InstOp::Switch { .. } | InstOp::Cmp(_, _, _) | InstOp::Nop | InstOp::PtrToInt(_, _) diff --git a/sway-ir/src/error.rs b/sway-ir/src/error.rs index 3e7b07363a1..d7bbafa292d 100644 --- a/sway-ir/src/error.rs +++ b/sway-ir/src/error.rs @@ -31,6 +31,7 @@ pub enum IrError { VerifyCmpTypeMismatch(String, String), VerifyCmpUnknownTypes, VerifyConditionExprNotABool, + VerifySwitchDiscriminantNotU64, VerifyContractCallBadTypes(String), VerifyGepElementTypeNonPointer, VerifyGepFromNonPointer(String, Option), @@ -204,6 +205,12 @@ impl fmt::Display for IrError { "Verification failed: Expression used for conditional is not a boolean." ) } + IrError::VerifySwitchDiscriminantNotU64 => { + write!( + f, + "Verification failed: Switch discriminant is not a u64 integer." + ) + } IrError::VerifyContractCallBadTypes(arg_name) => { write!( f, diff --git a/sway-ir/src/instruction.rs b/sway-ir/src/instruction.rs index f1377cc480b..f21b218d4dc 100644 --- a/sway-ir/src/instruction.rs +++ b/sway-ir/src/instruction.rs @@ -79,6 +79,11 @@ pub enum InstOp { true_block: BranchToWithArgs, false_block: BranchToWithArgs, }, + Switch { + discriminant: Value, + cases: Vec<(u64, BranchToWithArgs)>, + default: BranchToWithArgs, + }, /// A contract call with a list of arguments ContractCall { return_type: Type, @@ -436,6 +441,7 @@ impl InstOp { // These are all terminators which don't return, essentially. No type. InstOp::Branch(_) | InstOp::ConditionalBranch { .. } + | InstOp::Switch { .. } | InstOp::FuelVm( FuelVmInstruction::Revert(..) | FuelVmInstruction::JmpMem @@ -495,6 +501,18 @@ impl InstOp { v.extend_from_slice(&false_block.args); v } + InstOp::Switch { + discriminant, + cases, + default, + } => { + let mut v = vec![*discriminant]; + v.extend_from_slice(&default.args); + for case in cases { + v.extend_from_slice(&case.1.args); + } + v + } InstOp::ContractCall { return_type: _, name: _, @@ -691,6 +709,30 @@ impl InstOp { panic!("Invalid index for ConditionalBranch"); } } + InstOp::Switch { + discriminant, + cases, + default, + } => { + if idx == 0 { + *discriminant = replacement; + } else { + let mut cur_idx = 1; + if idx - cur_idx < default.args.len() { + default.args[idx - cur_idx] = replacement; + } else { + cur_idx += default.args.len(); + for case in cases.iter_mut() { + if idx - cur_idx < case.1.args.len() { + case.1.args[idx - cur_idx] = replacement; + return; + } + cur_idx += case.1.args.len(); + } + panic!("Invalid index for Switch"); + } + } + } InstOp::ContractCall { return_type: _, name: _, @@ -1039,6 +1081,17 @@ impl InstOp { true_block.args.iter_mut().for_each(replace); false_block.args.iter_mut().for_each(replace); } + InstOp::Switch { + discriminant, + cases, + default, + } => { + replace(discriminant); + default.args.iter_mut().for_each(replace); + for case in cases { + case.1.args.iter_mut().for_each(replace); + } + } InstOp::ContractCall { params, coins, @@ -1215,6 +1268,7 @@ impl InstOp { | InstOp::CastPtr { .. } | InstOp::Cmp(..) | InstOp::ConditionalBranch { .. } + | InstOp::Switch { .. } | InstOp::FuelVm(FuelVmInstruction::Gtf { .. }) | InstOp::FuelVm(FuelVmInstruction::ReadRegister(_)) | InstOp::FuelVm(FuelVmInstruction::StateLoadWord(_)) @@ -1241,6 +1295,7 @@ impl InstOp { self, InstOp::Branch(_) | InstOp::ConditionalBranch { .. } + | InstOp::Switch { .. } | InstOp::Ret(..) | InstOp::FuelVm( FuelVmInstruction::Revert(..) @@ -1526,6 +1581,30 @@ impl<'a, 'eng> InstructionInserter<'a, 'eng> { cbr_val } + pub fn switch( + self, + discriminant: Value, + default: BranchToWithArgs, + cases: Vec<(u64, BranchToWithArgs)>, + ) -> Value { + for case_block in std::iter::once(&default).chain(cases.iter().map(|c| &c.1)) { + case_block.block.add_pred(self.context, &self.block); + } + let switch_val = Value::new_instruction( + self.context, + self.block, + InstOp::Switch { + discriminant, + cases, + default, + }, + ); + self.context.blocks[self.block.0] + .instructions + .push(switch_val); + switch_val + } + pub fn contract_call( self, return_type: Type, diff --git a/sway-ir/src/optimize/arg_mutability_tagger.rs b/sway-ir/src/optimize/arg_mutability_tagger.rs index 6b73cd30bbc..10d84987be2 100644 --- a/sway-ir/src/optimize/arg_mutability_tagger.rs +++ b/sway-ir/src/optimize/arg_mutability_tagger.rs @@ -183,7 +183,9 @@ fn analyse_fn( match &ctx.values.get(value.0).unwrap().value { ValueDatum::Instruction(inst) => match &inst.op { - InstOp::ConditionalBranch { .. } | InstOp::Branch(_) => { + InstOp::ConditionalBranch { .. } + | InstOp::Branch(_) + | InstOp::Switch { .. } => { // Branch instructions do not mutate anything. // They do pass arguments to the next block, // but that is captured by that argument itself being diff --git a/sway-ir/src/optimize/cse.rs b/sway-ir/src/optimize/cse.rs index 5e8fbeb9de0..772d8e95525 100644 --- a/sway-ir/src/optimize/cse.rs +++ b/sway-ir/src/optimize/cse.rs @@ -99,6 +99,7 @@ fn instr_to_expr(context: &Context, vntable: &VNTable, instr: Value) -> Option None, + InstOp::Switch { .. } => None, InstOp::ContractCall { .. } => None, InstOp::FuelVm(_) => None, InstOp::GetLocal(_) => None, diff --git a/sway-ir/src/optimize/fn_dedup.rs b/sway-ir/src/optimize/fn_dedup.rs index e7b32f11d8f..8d7f50e8f4b 100644 --- a/sway-ir/src/optimize/fn_dedup.rs +++ b/sway-ir/src/optimize/fn_dedup.rs @@ -226,6 +226,17 @@ fn hash_fn( get_localised_id(true_block.block, localised_block_id).hash(state); get_localised_id(false_block.block, localised_block_id).hash(state); } + crate::InstOp::Switch { + discriminant: _, + cases, + default, + } => { + get_localised_id(default.block, localised_block_id).hash(state); + for (case_val, branch) in cases { + case_val.hash(state); + get_localised_id(branch.block, localised_block_id).hash(state); + } + } crate::InstOp::ContractCall { name, .. } => { name.hash(state); } diff --git a/sway-ir/src/optimize/inline.rs b/sway-ir/src/optimize/inline.rs index 31854228215..e1a02cd5bcd 100644 --- a/sway-ir/src/optimize/inline.rs +++ b/sway-ir/src/optimize/inline.rs @@ -18,7 +18,8 @@ use crate::{ metadata::{combine, MetadataIndex}, value::{Value, ValueContent, ValueDatum}, variable::LocalVar, - AnalysisResults, BlockArgument, Instruction, Module, Pass, PassMutability, ScopedPass, + AnalysisResults, BlockArgument, BranchToWithArgs, Instruction, Module, Pass, PassMutability, + ScopedPass, }; pub const FN_INLINE_NAME: &str = "inline"; @@ -486,6 +487,29 @@ fn inline_instruction( true_block.args.iter().map(|v| map_value(*v)).collect(), false_block.args.iter().map(|v| map_value(*v)).collect(), ), + InstOp::Switch { + discriminant, + cases, + default, + } => new_block.append(context).switch( + map_value(discriminant), + BranchToWithArgs { + block: map_block(default.block), + args: default.args.iter().map(|v| map_value(*v)).collect(), + }, + cases + .iter() + .map(|(val, branch)| { + ( + *val, + BranchToWithArgs { + block: map_block(branch.block), + args: branch.args.iter().map(|v| map_value(*v)).collect(), + }, + ) + }) + .collect(), + ), InstOp::ContractCall { return_type, name, diff --git a/sway-ir/src/parser.rs b/sway-ir/src/parser.rs index dd53117f8d1..729098065f0 100644 --- a/sway-ir/src/parser.rs +++ b/sway-ir/src/parser.rs @@ -256,6 +256,7 @@ mod ir_builder { / op_call() / op_cast_ptr() / op_cbr() + / op_switch() / op_cmp() / op_const() / op_contract_call() @@ -360,6 +361,19 @@ mod ir_builder { IrAstOperation::Cbr(cond, tblock, targs, fblock, fargs) } + rule switch_case() -> (u64, String, Vec) + = case_val:decimal() _ ":" _ case_block:id() + "(" _ case_args:(id() ** comma()) _ ")" { + (case_val, case_block, case_args) + } + + rule op_switch() -> IrAstOperation + = "switch" _ discrim:id() comma() "default" _ ":" _ dblock:id() + "(" _ dargs:(id() ** comma()) ")" _ + comma() "[" _ cases:(switch_case() ** comma()) _ "]" _ { + IrAstOperation::Switch(discrim, dblock, dargs, cases) + } + rule op_cmp() -> IrAstOperation = "cmp" _ p:cmp_pred() l:id() r:id() { IrAstOperation::Cmp(p, l, r) @@ -794,8 +808,8 @@ mod ir_builder { module::{Kind, Module}, value::Value, variable::LocalVar, - Backtrace, BinaryOpKind, BlockArgument, ConfigContent, Constant, GlobalVar, Instruction, - LogEventData, StorageKey, UnaryOpKind, B256, + Backtrace, BinaryOpKind, BlockArgument, BranchToWithArgs, ConfigContent, Constant, + GlobalVar, Instruction, LogEventData, StorageKey, UnaryOpKind, B256, }; #[derive(Debug)] @@ -870,6 +884,8 @@ mod ir_builder { Call(String, Vec), CastPtr(String, IrAstTy), Cbr(String, String, Vec, String, Vec), + // (descriminant, default_block, default_args, [(u64, block, args)]) + Switch(String, String, Vec, Vec<(u64, String, Vec)>), Cmp(Predicate, String, String), Const(IrAstTy, IrAstConst), ContractCall(IrAstTy, String, String, String, String, String), @@ -1433,6 +1449,41 @@ mod ir_builder { .collect(), ) .add_metadatum(context, opt_metadata), + IrAstOperation::Switch( + discrim_name, + default_block_name, + default_args, + cases, + ) => { + let descrim_val = *val_map.get(&discrim_name).unwrap(); + let default_block = named_blocks.get(&default_block_name).unwrap(); + let default_args_vals = default_args + .iter() + .map(|arg| *val_map.get(arg).unwrap()) + .collect(); + let case_blocks: Vec<(u64, BranchToWithArgs)> = cases + .into_iter() + .map(|(case_val, block_name, block_args)| { + let block = *named_blocks.get(&block_name).unwrap(); + let args = block_args + .into_iter() + .map(|arg| *val_map.get(&arg).unwrap()) + .collect(); + (case_val, BranchToWithArgs { block, args }) + }) + .collect(); + block + .append(context) + .switch( + descrim_val, + BranchToWithArgs { + block: *default_block, + args: default_args_vals, + }, + case_blocks, + ) + .add_metadatum(context, opt_metadata) + } IrAstOperation::Cmp(pred, lhs, rhs) => block .append(context) .cmp( diff --git a/sway-ir/src/printer.rs b/sway-ir/src/printer.rs index 939298b2190..5a085f5d03a 100644 --- a/sway-ir/src/printer.rs +++ b/sway-ir/src/printer.rs @@ -752,6 +752,58 @@ fn instruction_to_doc<'a>( ), )) } + InstOp::Switch { + discriminant, + cases, + default, + } => { + // Handle possibly constant values + let doc = maybe_constant_to_doc(context, md_namer, namer, discriminant); + let doc = std::iter::once(default) + .chain(cases.iter().map(|(_val, branch)| branch)) + .fold(doc, |doc, branch| { + branch.args.iter().fold(doc, |doc, param| { + doc.append(maybe_constant_to_doc(context, md_namer, namer, param)) + }) + }); + + let default_label = &context.blocks[default.block.0].label; + let default_args = Doc::in_parens_comma_sep( + default + .args + .iter() + .map(|arg_val| Doc::text(namer.name(context, arg_val))) + .collect(), + ); + let case_labels = cases + .iter() + .map(|(val, branch)| { + let label = &context.blocks[branch.block.0].label; + let args_doc = Doc::in_parens_comma_sep( + branch + .args + .iter() + .map(|arg_val| Doc::text(namer.name(context, arg_val))) + .collect(), + ); + Doc::text(format!("{val}: {label}")).append(args_doc) + }) + .collect::>(); + + doc.append( + Doc::line( + Doc::text(format!( + "switch {}, default: {default_label}", + namer.name(context, discriminant), + )) + .append(default_args) + .append(Doc::text(", [")) + .append(Doc::list_sep(case_labels, Doc::text(", "))) + .append(Doc::text("]")), + ) + .append(md_namer.md_idx_to_doc(context, metadata)), + ) + } InstOp::ContractCall { return_type, name, diff --git a/sway-ir/src/verify.rs b/sway-ir/src/verify.rs index 97552778063..5584fe2964c 100644 --- a/sway-ir/src/verify.rs +++ b/sway-ir/src/verify.rs @@ -278,6 +278,11 @@ impl InstructionVerifier<'_, '_> { true_block, false_block, } => self.verify_cbr(cond_value, true_block, false_block)?, + InstOp::Switch { + discriminant, + cases, + default, + } => self.verify_switch(discriminant, cases, default)?, InstOp::ContractCall { params, coins, @@ -711,6 +716,35 @@ impl InstructionVerifier<'_, '_> { } } + fn verify_switch( + &self, + descriminant: &Value, + cases: &[(u64, BranchToWithArgs)], + default: &BranchToWithArgs, + ) -> Result<(), IrError> { + if !descriminant + .get_type(self.context) + .is(Type::is_uint64, self.context) + { + return Err(IrError::VerifySwitchDiscriminantNotU64); + } + + for dest_block in std::iter::once(default).chain(cases.iter().map(|(_, branch)| branch)) { + if !self + .cur_function + .block_iter(self.context) + .contains(&dest_block.block) + { + return Err(IrError::VerifyBranchToMissingBlock( + self.context.blocks[dest_block.block.0].label.clone(), + )); + } + self.verify_dest_args(dest_block)?; + } + + Ok(()) + } + fn verify_cmp( &self, _pred: &Predicate, diff --git a/sway-ir/tests/simplify_cfg/switches.ir b/sway-ir/tests/simplify_cfg/switches.ir new file mode 100644 index 00000000000..e5f57c04b34 --- /dev/null +++ b/sway-ir/tests/simplify_cfg/switches.ir @@ -0,0 +1,20 @@ +// regex: VAR=v\d+v\d+ + +script { + fn main() -> u64 { + entry(): + v0 = const u64 5 + v2 = const u64 1 + br testblock() + + testblock(): + // check: switch + switch v0, default: defblock(v2), [5: block1()] + + defblock(v1: u64): + ret u64 v1 + + block1(): + ret u64 v0 + } +} From cf2e018ba89441251db43c044462ae98e647267b Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Fri, 12 Dec 2025 18:08:51 +0530 Subject: [PATCH 02/29] (incomplete) introduce Switch in abstract asm --- .../asm_generation/fuel/fuel_asm_builder.rs | 65 +++++++++++++++++-- sway-core/src/asm_lang/mod.rs | 31 +++++++++ 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs index aa50aa40e5c..a727a778655 100644 --- a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs +++ b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs @@ -1049,12 +1049,67 @@ impl<'ir, 'eng> FuelAsmBuilder<'ir, 'eng> { fn compile_switch( &mut self, - _instr_val: &Value, - _discriminant: &Value, - _cases: &[(u64, BranchToWithArgs)], - _default: &BranchToWithArgs, + instr_val: &Value, + discriminant: &Value, + cases: &[(u64, BranchToWithArgs)], + default: &BranchToWithArgs, ) -> Result<(), CompileError> { - todo!() + for dest_block in cases + .iter() + .map(|(_, bb)| bb) + .chain(std::iter::once(default)) + { + self.compile_branch_to_phi_value(dest_block)?; + } + + let default_label = self.block_to_label(&default.block); + if cases.len() == 0 { + // No cases, just jump to default. + self.cur_bytecode.push(Op::jump_to_label(default_label)); + return Ok(()); + } + + // TODO: Decide on a better limit for number of cases. + assert!(cases.len() < 20, "Too many switch cases to compile"); + + // Sort the cases by their values to make range checking easier. + let mut sorted_cases = cases.to_vec(); + sorted_cases.sort_by_key(|(val, _)| *val); + + // If the lowest case value isn't 0, we subtract the descriminant and each + // case value by that amount to make the lowest case 0. + let min_case_value = sorted_cases.first().unwrap().0; + let discrim_reg = self.value_to_register(discriminant)?; + if min_case_value > 0 { + self.cur_bytecode.push(Op { + opcode: Either::Left(VirtualOp::SUBI( + discrim_reg.clone(), + discrim_reg.clone(), + VirtualImmediate12::new(min_case_value), + )), + comment: format!( + "adjust switch discriminant by subtracting min case value {}", + min_case_value + ), + owning_span: self.md_mgr.val_to_span(self.context, *instr_val), + }); + sorted_cases + .iter_mut() + .for_each(|(val, _)| *val -= min_case_value); + } + + let cases = cases + .into_iter() + .map(|(val, BranchToWithArgs { block, args: _ })| (*val, self.block_to_label(block))) + .collect(); + self.cur_bytecode.push(Op::switch_comment( + discrim_reg, + cases, + default_label, + format!("switch with {} cases", sorted_cases.len()), + )); + + Ok(()) } fn compile_branch_to_phi_value( diff --git a/sway-core/src/asm_lang/mod.rs b/sway-core/src/asm_lang/mod.rs index 8a2f75fd761..7f7e4310043 100644 --- a/sway-core/src/asm_lang/mod.rs +++ b/sway-core/src/asm_lang/mod.rs @@ -281,6 +281,29 @@ impl Op { } } + /// Switch + pub(crate) fn switch(on: VirtualRegister, cases: Vec<(u64, Label)>, default: Label) -> Self { + Op { + opcode: Either::Right(OrganizationalOp::Switch { discriminant: on, cases, default }), + comment: String::new(), + owning_span: None, + } + } + + /// Switch with comment + pub(crate) fn switch_comment( + discriminant: VirtualRegister, + cases: Vec<(u64, Label)>, + default: Label, + comment: impl Into, + ) -> Self { + Op { + opcode: Either::Right(OrganizationalOp::Switch { discriminant, cases, default }), + comment: comment.into(), + owning_span: None, + } + } + pub(crate) fn parse_opcode( handler: &Handler, name: &Ident, @@ -1322,6 +1345,14 @@ pub(crate) enum ControlFlowOp { /// Jump type type_: JumpType, }, + Switch { + /// The register to switch on + discriminant: Reg, + /// Mapping of values to labels + cases: Vec<(u64, Label)>, + /// Default label if no cases match + default: Label, + }, // Placeholder for the offset into the configurables section. ConfigurablesOffsetPlaceholder, // placeholder for the DataSection offset From 1afcac8860511024858db4c61a11dd7a4e733f0f Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Wed, 7 Jan 2026 20:11:38 +0530 Subject: [PATCH 03/29] Assembly generation for switch --- forc-pkg/src/pkg.rs | 2 + sway-core/src/asm_generation/finalized_asm.rs | 13 +++ .../allocated_abstract_instruction_set.rs | 89 ++++++++++++++++++- .../src/asm_generation/fuel/data_section.rs | 32 +++++++ .../asm_generation/fuel/fuel_asm_builder.rs | 60 +++++++++++-- .../fuel/optimizations/constant_propagate.rs | 2 + sway-core/src/asm_lang/mod.rs | 50 +++++++++-- sway-ir/src/irtype.rs | 7 ++ 8 files changed, 236 insertions(+), 19 deletions(-) diff --git a/forc-pkg/src/pkg.rs b/forc-pkg/src/pkg.rs index 7175b6a7490..e0a7a7ab50f 100644 --- a/forc-pkg/src/pkg.rs +++ b/forc-pkg/src/pkg.rs @@ -1982,6 +1982,8 @@ fn report_assembly_information( } } + sway_core::asm_generation::Datum::WordArray(words) => (words.len() * 8) as u64, + sway_core::asm_generation::Datum::Collection(items) => { items.iter().map(calculate_entry_size).sum() } diff --git a/sway-core/src/asm_generation/finalized_asm.rs b/sway-core/src/asm_generation/finalized_asm.rs index fd289585ae5..1293faab5a4 100644 --- a/sway-core/src/asm_generation/finalized_asm.rs +++ b/sway-core/src/asm_generation/finalized_asm.rs @@ -323,6 +323,19 @@ fn to_bytecode_mut( } println!("\""); } + Datum::WordArray(ws) => { + print!(".words as hex ("); + + let mut first = true; + for w in ws { + if !first { + print!(", "); + } + first = false; + print!("{:02X?}", w.to_be_bytes()); + } + println!("), len i{}", ws.len()); + } Datum::Slice(bs) => { print!(".slice as hex ({bs:02X?}), len i{}, as ascii \"", bs.len()); diff --git a/sway-core/src/asm_generation/fuel/allocated_abstract_instruction_set.rs b/sway-core/src/asm_generation/fuel/allocated_abstract_instruction_set.rs index 785e00e3bf3..6cb4494384a 100644 --- a/sway-core/src/asm_generation/fuel/allocated_abstract_instruction_set.rs +++ b/sway-core/src/asm_generation/fuel/allocated_abstract_instruction_set.rs @@ -249,6 +249,77 @@ impl AllocatedAbstractInstructionSet { debug_assert_eq!(ops.len() as u64, op_size); realized_ops.extend(ops); } + ControlFlowOp::Switch { + discriminant, + cases, + } => { + let target_offsets = cases + .iter() + .map(|label| { + label_offsets + .get(label) + .map(|b| b.offs) + .expect("Switch case label not found in label_offsets") + }) + .collect::>(); + // Ensure that all target offsets are forward jumps. + target_offsets.iter().for_each(|&target_offset| { + if target_offset < curr_offset { + panic!("Switch case target offset is before current offset, which is not supported"); + } + }); + // Insert a data section entry for the switch targets. + let data_id = data_section.insert_data_value(Entry::new_word_array( + target_offsets, + EntryName::NonConfigurable, + None, + )); + realized_ops.push(RealizedOp { + opcode: AllocatedInstruction::AddrDataId( + AllocatedRegister::Constant(ConstantRegister::Scratch), + data_id, + ), + owning_span: owning_span.clone(), + comment: "load switch target table address".into(), + }); + // Multiply discriminant by 8 (since each address is 8 bytes) and add to the base address. + realized_ops.push(RealizedOp { + opcode: AllocatedInstruction::SLLI( + discriminant.clone(), + discriminant.clone(), + VirtualImmediate12::new(3), + ), + owning_span: owning_span.clone(), + comment: "multiply discriminant by 8".into(), + }); + realized_ops.push(RealizedOp { + opcode: AllocatedInstruction::ADD( + AllocatedRegister::Constant(ConstantRegister::Scratch), + discriminant, + AllocatedRegister::Constant(ConstantRegister::Scratch), + ), + owning_span: owning_span.clone(), + comment: "add discriminant to switch target table address".into(), + }); + realized_ops.push(RealizedOp { + opcode: AllocatedInstruction::LW( + AllocatedRegister::Constant(ConstantRegister::Scratch), + AllocatedRegister::Constant(ConstantRegister::Scratch), + VirtualImmediate12::new(0), + ), + owning_span: owning_span.clone(), + comment: "load switch target address".into(), + }); + // Finally, jump to the loaded address. + realized_ops.push(RealizedOp { + opcode: AllocatedInstruction::JMPF( + AllocatedRegister::Constant(ConstantRegister::Zero), + VirtualImmediate18::new(0), + ), + owning_span, + comment, + }); + } ControlFlowOp::DataSectionOffsetPlaceholder => { realized_ops.push(RealizedOp { opcode: AllocatedInstruction::DataSectionOffsetPlaceholder, @@ -332,6 +403,10 @@ impl AllocatedAbstractInstructionSet { JumpType::Call => 3, }, + // A switch expands to AddrDataId (2 opcodes) + scale descriminant by word size (1 opcode) + // + add descriminant (1 opcode) + load (1 opcode) + jump (1 opcode) = 6 opcodes + Either::Right(Switch { .. }) => 6, + Either::Right(Comment) => 0, Either::Right(DataSectionOffsetPlaceholder) => { @@ -389,6 +464,10 @@ impl AllocatedAbstractInstructionSet { // Far jumps must be handled separately, as they require two instructions. Either::Right(Jump { .. }) => 1, + // A switch expands to AddrDataId (2 opcodes) + scale descriminant by word size (1 opcode) + // + add descriminant (1 opcode) + load (1 opcode) + jump (1 opcode) = 6 opcodes + Either::Right(Switch { .. }) => 6, + Either::Right(Comment) => 0, Either::Right(DataSectionOffsetPlaceholder) => { @@ -427,7 +506,10 @@ impl AllocatedAbstractInstructionSet { for (op_idx, op) in self.ops.iter().enumerate() { // If we're seeing a control flow op then it's the end of the block. - if let Either::Right(ControlFlowOp::Label(_) | ControlFlowOp::Jump { .. }) = op.opcode { + if let Either::Right( + ControlFlowOp::Label(_) | ControlFlowOp::Jump { .. } | ControlFlowOp::Switch { .. }, + ) = op.opcode + { if let Some((lab, _idx, offs)) = cur_basic_block { // Insert the previous basic block. labelled_blocks.insert(lab, BasicBlock { offs }); @@ -510,7 +592,10 @@ impl AllocatedAbstractInstructionSet { for (op_idx, op) in self.ops.iter().enumerate() { // If we're seeing a control flow op then it's the end of the block. - if let Either::Right(ControlFlowOp::Label(_) | ControlFlowOp::Jump { .. }) = op.opcode { + if let Either::Right( + ControlFlowOp::Label(_) | ControlFlowOp::Jump { .. } | ControlFlowOp::Switch { .. }, + ) = op.opcode + { if let Some((lab, _idx, offs)) = cur_basic_block { // Insert the previous basic block. labelled_blocks.insert(lab, BasicBlock { offs }); diff --git a/sway-core/src/asm_generation/fuel/data_section.rs b/sway-core/src/asm_generation/fuel/data_section.rs index 592d4f54238..9a55b89f80b 100644 --- a/sway-core/src/asm_generation/fuel/data_section.rs +++ b/sway-core/src/asm_generation/fuel/data_section.rs @@ -34,6 +34,7 @@ pub enum Datum { Byte(u8), Word(u64), ByteArray(Vec), + WordArray(Vec), Slice(Vec), Collection(Vec), } @@ -67,6 +68,18 @@ impl Entry { } } + pub(crate) fn new_word_array( + words: Vec, + name: EntryName, + padding: Option, + ) -> Entry { + Entry { + padding: padding.unwrap_or(Padding::default_for_word_array(&words)), + value: Datum::WordArray(words), + name, + } + } + pub(crate) fn new_slice(bytes: Vec, name: EntryName, padding: Option) -> Entry { Entry { padding: padding.unwrap_or(Padding::default_for_byte_array(&bytes)), @@ -173,6 +186,7 @@ impl Entry { .copied() .take((bytes.len() + 7) & 0xfffffff8_usize) .collect(), + Datum::WordArray(words) => words.iter().flat_map(|w| w.to_be_bytes()).collect(), Datum::Collection(items) => items.iter().flat_map(|el| el.to_bytes()).collect(), }; @@ -400,6 +414,7 @@ impl fmt::Display for DataSection { Datum::Word(w) => format!(".word {w}"), Datum::ByteArray(bs) => display_bytes_for_data_section(bs, ".bytes"), Datum::Slice(bs) => display_bytes_for_data_section(bs, ".slice"), + Datum::WordArray(ws) => display_words_for_data_section(ws), Datum::Collection(els) => format!( ".collection {{ {} }}", els.iter() @@ -439,3 +454,20 @@ fn display_bytes_for_data_section(bs: &Vec, prefix: &str) -> String { } format!("{prefix}[{}] {hex_str} {chr_str}", bs.len()) } + +fn display_words_for_data_section(ws: &Vec) -> String { + let mut hex_str = String::new(); + let mut chr_str = String::new(); + for w in ws { + let bytes = w.to_be_bytes(); + for b in &bytes { + hex_str.push_str(format!("{b:02x} ").as_str()); + chr_str.push(if *b == b' ' || b.is_ascii_graphic() { + *b as char + } else { + '.' + }); + } + } + format!(".word_array [{}] {hex_str} {chr_str}", ws.len()) +} diff --git a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs index aa460025f80..7f7fba77124 100644 --- a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs +++ b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs @@ -1069,9 +1069,6 @@ impl<'ir, 'eng> FuelAsmBuilder<'ir, 'eng> { return Ok(()); } - // TODO: Decide on a better limit for number of cases. - assert!(cases.len() < 20, "Too many switch cases to compile"); - // Sort the cases by their values to make range checking easier. let mut sorted_cases = cases.to_vec(); sorted_cases.sort_by_key(|(val, _)| *val); @@ -1098,15 +1095,62 @@ impl<'ir, 'eng> FuelAsmBuilder<'ir, 'eng> { .for_each(|(val, _)| *val -= min_case_value); } - let cases = cases + let sorted_cases: Vec<_> = sorted_cases .into_iter() - .map(|(val, BranchToWithArgs { block, args: _ })| (*val, self.block_to_label(block))) + .map(|(val, BranchToWithArgs { block, args: _ })| (val, self.block_to_label(&block))) .collect(); + + assert!( + sorted_cases[0].0 == 0, + "Lowest case value must be zero after adjustment" + ); + + let max_case_value = sorted_cases.last().unwrap().0; + // TODO: Decide on a better limit. + assert!( + max_case_value < 20, + "Jump table too large to compile switch" + ); + + // If the descriminant is greater than the highest case value, jump to default. + // TODO: For matching on enums where the compiler ensures that all variants are + // covered, this check can be skipped. + { + let cond_reg = self.reg_seqr.next(); + self.immediate_to_reg( + max_case_value, + cond_reg.clone(), + None, + "max_case_value", + None, + ); + self.cur_bytecode + .push(Op::jump_if_not_zero(cond_reg, default_label)); + } + + // For holes in the case values, i.e. non-contiguous case values, + // insert a jump to default for those. + let mut filled_sorted_cases = Vec::new(); + let mut next = 0; + for (case_value, case_label) in sorted_cases { + while next < case_value { + filled_sorted_cases.push(default_label); + next += 1; + } + filled_sorted_cases.push(case_label); + next += 1; + } + + // So far we've ensured that + // - The lowest case value is 0 (by subtracting min_case_value) + // - The highest case value is max_case_value (and jump to default if discrim > max) + // - Any holes in the case values jump to default + // Now we can emit the switch instruction. + let num_cases = filled_sorted_cases.len(); self.cur_bytecode.push(Op::switch_comment( discrim_reg, - cases, - default_label, - format!("switch with {} cases", sorted_cases.len()), + filled_sorted_cases, + format!("switch with {} cases", num_cases), )); Ok(()) diff --git a/sway-core/src/asm_generation/fuel/optimizations/constant_propagate.rs b/sway-core/src/asm_generation/fuel/optimizations/constant_propagate.rs index 9cc74575677..d5637a23498 100644 --- a/sway-core/src/asm_generation/fuel/optimizations/constant_propagate.rs +++ b/sway-core/src/asm_generation/fuel/optimizations/constant_propagate.rs @@ -217,6 +217,8 @@ impl AbstractInstructionSet { JumpType::Call => ResetKnown::Full, _ => ResetKnown::Defs, }, + // Same as non-call jumps. + ControlFlowOp::Switch { .. } => ResetKnown::Defs, // These ops mark their outputs properly and cause no control-flow effects ControlFlowOp::Comment | ControlFlowOp::ConfigurablesOffsetPlaceholder diff --git a/sway-core/src/asm_lang/mod.rs b/sway-core/src/asm_lang/mod.rs index 7f7e4310043..0871da08cdf 100644 --- a/sway-core/src/asm_lang/mod.rs +++ b/sway-core/src/asm_lang/mod.rs @@ -282,9 +282,12 @@ impl Op { } /// Switch - pub(crate) fn switch(on: VirtualRegister, cases: Vec<(u64, Label)>, default: Label) -> Self { + pub(crate) fn switch(on: VirtualRegister, cases: Vec