diff --git a/clarity/src/vm/analysis/type_checker/v2_1/mod.rs b/clarity/src/vm/analysis/type_checker/v2_1/mod.rs index e8300c0f1..7caf775c1 100644 --- a/clarity/src/vm/analysis/type_checker/v2_1/mod.rs +++ b/clarity/src/vm/analysis/type_checker/v2_1/mod.rs @@ -37,6 +37,7 @@ use crate::vm::costs::{ analysis_typecheck_cost, cost_functions, runtime_cost, ClarityCostFunctionReference, CostErrors, CostOverflowingMath, CostTracker, ExecutionCost, LimitedCostTracker, }; +use crate::vm::diagnostic::Diagnostic; use crate::vm::functions::define::DefineFunctionsParsed; use crate::vm::functions::NativeFunctions; use crate::vm::representations::SymbolicExpressionType::{ @@ -151,7 +152,130 @@ impl TypeChecker<'_, '_> { pub type TypeResult = CheckResult; +pub fn compute_typecheck_cost( + track: &mut T, + t1: &TypeSignature, + t2: &TypeSignature, +) -> Result { + let t1_size = t1.type_size().map_err(|_| CostErrors::CostOverflow)?; + let t2_size = t2.type_size().map_err(|_| CostErrors::CostOverflow)?; + track.compute_cost( + ClarityCostFunction::AnalysisTypeCheck, + &[std::cmp::max(t1_size, t2_size).into()], + ) +} + +pub fn check_argument_len(expected: usize, args_len: usize) -> Result<(), CheckErrors> { + if args_len != expected { + Err(CheckErrors::IncorrectArgumentCount(expected, args_len)) + } else { + Ok(()) + } +} + impl FunctionType { + pub fn check_args_visitor_2_1( + &self, + accounting: &mut T, + arg_type: &TypeSignature, + arg_index: usize, + accumulated_type: Option<&TypeSignature>, + ) -> ( + Option>, + CheckResult>, + ) { + match self { + // variadic stops checking cost at the first error... + FunctionType::Variadic(expected_type, _) => { + let cost = Some(compute_typecheck_cost(accounting, expected_type, arg_type)); + let admitted = match expected_type.admits_type(&StacksEpochId::Epoch21, arg_type) { + Ok(admitted) => admitted, + Err(e) => return (cost, Err(e.into())), + }; + if !admitted { + return ( + cost, + Err(CheckErrors::TypeError(expected_type.clone(), arg_type.clone()).into()), + ); + } + (cost, Ok(None)) + } + FunctionType::ArithmeticVariadic => { + let cost = Some(compute_typecheck_cost( + accounting, + &TypeSignature::IntType, + arg_type, + )); + if arg_index == 0 { + let return_type = match arg_type { + TypeSignature::IntType => Ok(Some(TypeSignature::IntType)), + TypeSignature::UIntType => Ok(Some(TypeSignature::UIntType)), + _ => Err(CheckErrors::UnionTypeError( + vec![TypeSignature::IntType, TypeSignature::UIntType], + arg_type.clone(), + ) + .into()), + }; + (cost, return_type) + } else { + let return_type = accumulated_type + .ok_or_else(|| CheckErrors::Expects("Failed to set accumulated type for arg indices >= 1 in variadic arithmetic".into()).into()); + let check_result = return_type.and_then(|return_type| { + if arg_type != return_type { + Err( + CheckErrors::TypeError(return_type.clone(), arg_type.clone()) + .into(), + ) + } else { + Ok(None) + } + }); + (cost, check_result) + } + } + // For the fixed function types, the visitor will just + // tell the processor that any results greater than the args len + // do not need to be stored, because an error will occur before + // further checking anyways + FunctionType::Fixed(FixedFunction { + args: arg_types, .. + }) => { + if arg_index >= arg_types.len() { + // note: argument count will be wrong? + return ( + None, + Err(CheckErrors::IncorrectArgumentCount(arg_types.len(), arg_index).into()), + ); + } + return (None, Ok(None)); + } + // For the following function types, the visitor will just + // tell the processor that any results greater than len 1 or 2 + // do not need to be stored, because an error will occur before + // further checking anyways + FunctionType::ArithmeticUnary | FunctionType::UnionArgs(..) => { + if arg_index >= 1 { + return ( + None, + Err(CheckErrors::IncorrectArgumentCount(1, arg_index).into()), + ); + } + return (None, Ok(None)); + } + FunctionType::ArithmeticBinary + | FunctionType::ArithmeticComparison + | FunctionType::Binary(..) => { + if arg_index >= 2 { + return ( + None, + Err(CheckErrors::IncorrectArgumentCount(2, arg_index).into()), + ); + } + return (None, Ok(None)); + } + } + } + pub fn check_args_2_1( &self, accounting: &mut T, @@ -1017,17 +1141,23 @@ impl<'a, 'b> TypeChecker<'a, 'b> { args: &[SymbolicExpression], context: &TypingContext, ) -> TypeResult { - let mut types_returned = self.type_check_all(args, context)?; - - let last_return = types_returned - .pop() - .ok_or(CheckError::new(CheckErrors::CheckerImplementationFailure))?; - - for type_return in types_returned.iter() { - if type_return.is_response_type() { - return Err(CheckErrors::UncheckedIntermediaryResponses.into()); + let mut last_return = None; + let mut return_failure = Ok(()); + for ix in 0..args.len() { + let type_return = self.type_check(&args[ix], context)?; + if ix + 1 < args.len() { + if type_return.is_response_type() { + return_failure = Err(CheckErrors::UncheckedIntermediaryResponses); + } + } else { + last_return = Some(type_return); } } + + let last_return = last_return + .ok_or_else(|| CheckError::new(CheckErrors::CheckerImplementationFailure))?; + return_failure?; + Ok(last_return) } @@ -1052,8 +1182,56 @@ impl<'a, 'b> TypeChecker<'a, 'b> { epoch: StacksEpochId, clarity_version: ClarityVersion, ) -> TypeResult { - let typed_args = self.type_check_all(args, context)?; - func_type.check_args(self, &typed_args, epoch, clarity_version) + if epoch <= StacksEpochId::Epoch2_05 { + let typed_args = self.type_check_all(args, context)?; + return func_type.check_args(self, &typed_args, epoch, clarity_version); + } + // use func_type visitor pattern + let mut accumulated_type = None; + let mut total_costs = vec![]; + let mut check_result = Ok(()); + let mut accumulated_types = Vec::new(); + for (arg_ix, arg_expr) in args.iter().enumerate() { + let arg_type = self.type_check(arg_expr, context)?; + if check_result.is_ok() { + let (costs, result) = func_type.check_args_visitor_2_1( + self, + &arg_type, + arg_ix, + accumulated_type.as_ref(), + ); + // add the accumulated type and total cost *before* + // checking for an error: we want the subsequent error handling + // to account for this cost + accumulated_types.push(arg_type); + total_costs.extend(costs); + + match result { + Ok(Some(returned_type)) => { + accumulated_type = Some(returned_type); + } + Ok(None) => {} + Err(e) => { + check_result = Err(e); + } + }; + } + } + if let Err(mut check_error) = check_result { + if let CheckErrors::IncorrectArgumentCount(expected, _actual) = check_error.err { + check_error.err = CheckErrors::IncorrectArgumentCount(expected, args.len()); + check_error.diagnostic = Diagnostic::err(&check_error.err) + } + // accumulate the checking costs + // the reason we do this now (instead of within the loop) is for backwards compatibility + for cost in total_costs.into_iter() { + self.add_cost(cost?)?; + } + + return Err(check_error); + } + // otherwise, just invoke the normal checking routine + func_type.check_args(self, &accumulated_types, epoch, clarity_version) } fn get_function_type(&self, function_name: &str) -> Option { diff --git a/clarity/src/vm/analysis/type_checker/v2_1/natives/mod.rs b/clarity/src/vm/analysis/type_checker/v2_1/natives/mod.rs index c5aefb65e..e5fc32c67 100644 --- a/clarity/src/vm/analysis/type_checker/v2_1/natives/mod.rs +++ b/clarity/src/vm/analysis/type_checker/v2_1/natives/mod.rs @@ -17,8 +17,8 @@ use stacks_common::types::StacksEpochId; use super::{ - check_argument_count, check_arguments_at_least, check_arguments_at_most, no_type, TypeChecker, - TypeResult, TypingContext, + check_argument_count, check_arguments_at_least, check_arguments_at_most, + compute_typecheck_cost, no_type, TypeChecker, TypeResult, TypingContext, }; use crate::vm::analysis::errors::{CheckError, CheckErrors, CheckResult}; use crate::vm::costs::cost_functions::ClarityCostFunction; @@ -61,14 +61,43 @@ fn check_special_list_cons( args: &[SymbolicExpression], context: &TypingContext, ) -> TypeResult { - let typed_args = checker.type_check_all(args, context)?; - for type_arg in typed_args.iter() { - runtime_cost( - ClarityCostFunction::AnalysisListItemsCheck, - checker, - type_arg.type_size()?, - )?; + let mut result = Vec::with_capacity(args.len()); + let mut entries_size: Option = Some(0); + let mut costs = Vec::with_capacity(args.len()); + + for arg in args.iter() { + // don't use map here, since type_check has side-effects. + let checked = checker.type_check(arg, context)?; + let cost = checked.type_size().and_then(|ty_size| { + checker + .compute_cost( + ClarityCostFunction::AnalysisListItemsCheck, + &[ty_size.into()], + ) + .map_err(CheckErrors::from) + }); + costs.push(cost); + + if let Some(cur_size) = entries_size.clone() { + entries_size = cur_size.checked_add(checked.size()?); + } + if let Some(cur_size) = entries_size { + if cur_size > MAX_VALUE_SIZE { + entries_size = None; + } + } + if entries_size.is_some() { + result.push(checked); + } } + + for cost in costs.into_iter() { + checker.add_cost(cost?)?; + } + if entries_size.is_none() { + return Err(CheckErrors::ValueTooLarge.into()); + } + let typed_args = result; TypeSignature::parent_list_type(&typed_args) .map_err(|x| x.into()) .map(TypeSignature::from) @@ -202,6 +231,9 @@ pub fn check_special_tuple_cons( args.len(), )?; + let mut type_size = 0u32; + let mut cons_error = Ok(()); + handle_binding_list(args, |var_name, var_sexp| { checker.type_check(var_sexp, context).and_then(|var_type| { runtime_cost( @@ -209,11 +241,21 @@ pub fn check_special_tuple_cons( checker, var_type.type_size()?, )?; - tuple_type_data.push((var_name.clone(), var_type)); + if type_size < MAX_VALUE_SIZE { + type_size = type_size + .saturating_add(var_name.len() as u32) + .saturating_add(var_name.len() as u32) + .saturating_add(var_type.type_size()?) + .saturating_add(var_type.size()?); + tuple_type_data.push((var_name.clone(), var_type)); + } else { + cons_error = Err(CheckErrors::BadTupleConstruction); + } Ok(()) }) })?; + cons_error?; let tuple_signature = TupleTypeSignature::try_from(tuple_type_data) .map_err(|_e| CheckErrors::BadTupleConstruction)?; @@ -338,15 +380,33 @@ fn check_special_equals( ) -> TypeResult { check_arguments_at_least(1, args)?; - let arg_types = checker.type_check_all(args, context)?; + let mut arg_type = None; + let mut costs = Vec::with_capacity(args.len()); - let mut arg_type = arg_types[0].clone(); - for x_type in arg_types.into_iter() { - analysis_typecheck_cost(checker, &x_type, &arg_type)?; - arg_type = TypeSignature::least_supertype(&StacksEpochId::Epoch21, &x_type, &arg_type) - .map_err(|_| CheckErrors::TypeError(x_type, arg_type))?; + for arg in args.iter() { + let x_type = checker.type_check(arg, context)?; + if arg_type.is_none() { + arg_type = Some(Ok(x_type.clone())); + } + if let Some(Ok(cur_type)) = arg_type { + let cost = compute_typecheck_cost(checker, &x_type, &cur_type); + costs.push(cost); + arg_type = Some( + TypeSignature::least_supertype(&StacksEpochId::Epoch21, &x_type, &cur_type) + .map_err(|_| CheckErrors::TypeError(x_type, cur_type)), + ); + } } + for cost in costs.into_iter() { + checker.add_cost(cost?)?; + } + + // check if there was a least supertype failure. + arg_type.ok_or_else(|| { + CheckErrors::Expects("Arg type should be set because arguments checked for >= 1".into()) + })??; + Ok(TypeSignature::BoolType) } diff --git a/clarity/src/vm/analysis/type_checker/v2_1/natives/sequences.rs b/clarity/src/vm/analysis/type_checker/v2_1/natives/sequences.rs index 090b259a2..c1b3aabb1 100644 --- a/clarity/src/vm/analysis/type_checker/v2_1/natives/sequences.rs +++ b/clarity/src/vm/analysis/type_checker/v2_1/natives/sequences.rs @@ -22,7 +22,8 @@ use crate::vm::analysis::type_checker::v2_1::{ TypeResult, TypingContext, }; use crate::vm::costs::cost_functions::ClarityCostFunction; -use crate::vm::costs::{analysis_typecheck_cost, cost_functions, runtime_cost}; +use crate::vm::costs::{analysis_typecheck_cost, cost_functions, runtime_cost, CostTracker}; +use crate::vm::diagnostic::Diagnostic; use crate::vm::functions::NativeFunctions; use crate::vm::representations::{SymbolicExpression, SymbolicExpressionType}; pub use crate::vm::types::signatures::{BufferLength, ListTypeData, StringUTF8Length, BUFF_1}; @@ -73,9 +74,15 @@ pub fn check_special_map( )?; let iter = args[1..].iter(); - let mut func_args = Vec::with_capacity(iter.len()); let mut min_args = u32::MAX; - for arg in iter { + + // use func_type visitor pattern + let mut accumulated_type = None; + let mut total_costs = vec![]; + let mut check_result = Ok(()); + let mut accumulated_types = Vec::new(); + + for (arg_ix, arg) in iter.enumerate() { let argument_type = checker.type_check(arg, context)?; let entry_type = match argument_type { TypeSignature::SequenceType(sequence) => { @@ -101,11 +108,52 @@ pub fn check_special_map( return Err(CheckErrors::ExpectedSequence(argument_type).into()); } }; - func_args.push(entry_type); + + if check_result.is_ok() { + let (costs, result) = function_type.check_args_visitor_2_1( + checker, + &entry_type, + arg_ix, + accumulated_type.as_ref(), + ); + // add the accumulated type and total cost *before* + // checking for an error: we want the subsequent error handling + // to account for this cost + accumulated_types.push(entry_type); + total_costs.extend(costs); + + match result { + Ok(Some(returned_type)) => { + accumulated_type = Some(returned_type); + } + Ok(None) => {} + Err(e) => { + check_result = Err(e); + } + }; + } } - let mapped_type = - function_type.check_args(checker, &func_args, context.epoch, context.clarity_version)?; + if let Err(mut check_error) = check_result { + if let CheckErrors::IncorrectArgumentCount(expected, _actual) = check_error.err { + check_error.err = + CheckErrors::IncorrectArgumentCount(expected, args.len().saturating_sub(1)); + check_error.diagnostic = Diagnostic::err(&check_error.err) + } + // accumulate the checking costs + for cost in total_costs.into_iter() { + checker.add_cost(cost?)?; + } + + return Err(check_error); + } + + let mapped_type = function_type.check_args( + checker, + &accumulated_types, + context.epoch, + context.clarity_version, + )?; TypeSignature::list_of(mapped_type, min_args) .map_err(|_| CheckErrors::ConstructedListTooLarge.into()) } diff --git a/clarity/src/vm/types/signatures.rs b/clarity/src/vm/types/signatures.rs index c9971f97a..293c36fd5 100644 --- a/clarity/src/vm/types/signatures.rs +++ b/clarity/src/vm/types/signatures.rs @@ -15,8 +15,10 @@ // along with this program. If not, see . use std::collections::btree_map::Entry; -use std::collections::{hash_map, BTreeMap, HashMap}; +use std::collections::{hash_map, BTreeMap}; use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; use std::{cmp, fmt}; // TypeSignatures @@ -76,7 +78,36 @@ impl AssetIdentifier { #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TupleTypeSignature { - type_map: HashMap, + #[serde(with = "tuple_type_map_serde")] + type_map: Arc>, +} + +mod tuple_type_map_serde { + use std::collections::BTreeMap; + use std::ops::Deref; + use std::sync::Arc; + + use serde::{Deserializer, Serializer}; + + use super::TypeSignature; + use crate::vm::ClarityName; + + pub fn serialize( + map: &Arc>, + ser: S, + ) -> Result { + serde::Serialize::serialize(map.deref(), ser) + } + + pub fn deserialize<'de, D>( + deser: D, + ) -> Result>, D::Error> + where + D: Deserializer<'de>, + { + let map = serde::Deserialize::deserialize(deser)?; + Ok(Arc::new(map)) + } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] @@ -787,12 +818,12 @@ impl TypeSignature { inner_type.1.canonicalize_v2_1(), ))), TupleType(ref tuple_sig) => { - let mut canonicalized_fields = HashMap::new(); + let mut canonicalized_fields = BTreeMap::new(); for (field_name, field_type) in tuple_sig.get_type_map() { canonicalized_fields.insert(field_name.clone(), field_type.canonicalize_v2_1()); } TypeSignature::from(TupleTypeSignature { - type_map: canonicalized_fields, + type_map: Arc::new(canonicalized_fields), }) } TraitReferenceType(trait_id) => CallableType(CallableSubtype::Trait(trait_id.clone())), @@ -851,9 +882,9 @@ impl TryFrom> for TupleTypeSignature { return Err(CheckErrors::EmptyTuplesNotAllowed); } - let mut type_map = HashMap::new(); + let mut type_map = BTreeMap::new(); for (name, type_info) in type_data.into_iter() { - if let hash_map::Entry::Vacant(e) = type_map.entry(name.clone()) { + if let Entry::Vacant(e) = type_map.entry(name.clone()) { e.insert(type_info); } else { return Err(CheckErrors::NameAlreadyUsed(name.into())); @@ -874,30 +905,7 @@ impl TryFrom> for TupleTypeSignature { return Err(CheckErrors::TypeSignatureTooDeep); } } - let type_map = type_map.into_iter().collect(); - let result = TupleTypeSignature { type_map }; - let would_be_size = result - .inner_size()? - .ok_or_else(|| CheckErrors::ValueTooLarge)?; - if would_be_size > MAX_VALUE_SIZE { - Err(CheckErrors::ValueTooLarge) - } else { - Ok(result) - } - } -} - -impl TryFrom> for TupleTypeSignature { - type Error = CheckErrors; - fn try_from(type_map: HashMap) -> Result { - if type_map.is_empty() { - return Err(CheckErrors::EmptyTuplesNotAllowed); - } - for child_sig in type_map.values() { - if (1 + child_sig.depth()) > MAX_TYPE_DEPTH { - return Err(CheckErrors::TypeSignatureTooDeep); - } - } + let type_map = Arc::new(type_map.into_iter().collect()); let result = TupleTypeSignature { type_map }; let would_be_size = result .inner_size()? @@ -925,7 +933,7 @@ impl TupleTypeSignature { self.type_map.get(field) } - pub fn get_type_map(&self) -> &HashMap { + pub fn get_type_map(&self) -> &BTreeMap { &self.type_map } @@ -961,7 +969,7 @@ impl TupleTypeSignature { } pub fn shallow_merge(&mut self, update: &mut TupleTypeSignature) { - self.type_map.extend(update.type_map.drain()); + Arc::make_mut(&mut self.type_map).append(Arc::make_mut(&mut update.type_map)); } } diff --git a/stacks-signer/src/client/stacks_client.rs b/stacks-signer/src/client/stacks_client.rs index 2bd17cf22..145c83ddd 100644 --- a/stacks-signer/src/client/stacks_client.rs +++ b/stacks-signer/src/client/stacks_client.rs @@ -722,7 +722,7 @@ impl StacksClient { #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::collections::BTreeMap; use std::io::{BufWriter, Write}; use std::thread::spawn; @@ -1080,7 +1080,7 @@ mod tests { (ClarityName::from("signer"), TypeSignature::PrincipalType), ] .into_iter() - .collect::>() + .collect::>() .try_into() .unwrap();