mirror of
				https://github.com/YosysHQ/yosys
				synced 2025-11-03 21:09:12 +00:00 
			
		
		
		
	Emit valid SMT for stateful designs, fix some cells
This commit is contained in:
		
							parent
							
								
									f0f436cbe7
								
							
						
					
					
						commit
						5780357cd9
					
				
					 4 changed files with 306 additions and 183 deletions
				
			
		| 
						 | 
				
			
			@ -23,7 +23,7 @@
 | 
			
		|||
USING_YOSYS_NAMESPACE
 | 
			
		||||
PRIVATE_NAMESPACE_BEGIN
 | 
			
		||||
 | 
			
		||||
const char illegal_characters[] = "$\\";
 | 
			
		||||
const char illegal_characters[] = "#:\\";
 | 
			
		||||
const char *reserved_keywords[] = {};
 | 
			
		||||
 | 
			
		||||
struct SmtScope {
 | 
			
		||||
| 
						 | 
				
			
			@ -36,7 +36,8 @@ struct SmtScope {
 | 
			
		|||
 | 
			
		||||
  std::string insert(IdString id)
 | 
			
		||||
  {
 | 
			
		||||
		std::string name = RTLIL::unescape_id(id);
 | 
			
		||||
 | 
			
		||||
    std::string name = scope(id);
 | 
			
		||||
    if (used_names.count(name) == 0) {
 | 
			
		||||
      used_names.insert(name);
 | 
			
		||||
      name_map[id] = name;
 | 
			
		||||
| 
						 | 
				
			
			@ -106,7 +107,7 @@ template <class NodeNames> struct SmtPrintVisitor {
 | 
			
		|||
 | 
			
		||||
  std::string slice(Node, Node a, int, int offset, int out_width)
 | 
			
		||||
  {
 | 
			
		||||
		return format("(_ extract %1 %2 %0)", np(a), offset, offset + out_width - 1);
 | 
			
		||||
    return format("((_ extract %2 %1) %0)", np(a), offset, offset + out_width - 1);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string zero_extend(Node, Node a, int, int out_width) { return format("((_ zero_extend %1) %0)", np(a), out_width - a.width()); }
 | 
			
		||||
| 
						 | 
				
			
			@ -135,7 +136,12 @@ template <class NodeNames> struct SmtPrintVisitor {
 | 
			
		|||
 | 
			
		||||
  std::string unary_minus(Node, Node a, int) { return format("(bvneg %0)", np(a)); }
 | 
			
		||||
 | 
			
		||||
	std::string reduce_and(Node, Node a, int) { return format("(= %0 #b1)", np(a)); }
 | 
			
		||||
  std::string reduce_and(Node, Node a, int) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    // We use ite to set the result to bit vector, to ensure appropriate type
 | 
			
		||||
    ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '1') << ") #b1 #b0)";
 | 
			
		||||
    return ss.str();
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
  std::string reduce_or(Node, Node a, int)
 | 
			
		||||
  {
 | 
			
		||||
| 
						 | 
				
			
			@ -145,27 +151,104 @@ template <class NodeNames> struct SmtPrintVisitor {
 | 
			
		|||
    return ss.str();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string reduce_xor(Node, Node a, int) { return format("(bvxor_reduce %0)", np(a)); }
 | 
			
		||||
  std::string reduce_xor(Node, Node a, int) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "(bvxor ";
 | 
			
		||||
    for (int i = 0; i < a.width(); ++i) {
 | 
			
		||||
      if (i > 0) ss << " ";
 | 
			
		||||
      ss << "((_ extract " << i << " " << i << ") " << np(a) << ")";
 | 
			
		||||
    }
 | 
			
		||||
    ss << ")";
 | 
			
		||||
    return ss.str();
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
	std::string equal(Node, Node a, Node b, int) { return format("(= %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string equal(Node, Node a, Node b, int) {
 | 
			
		||||
    return format("(ite (= %0 %1) #b1 #b0)", np(a), np(b));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string not_equal(Node, Node a, Node b, int) { return format("(distinct %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string not_equal(Node, Node a, Node b, int) {
 | 
			
		||||
    return format("(ite (distinct %0 %1) #b1 #b0)", np(a), np(b));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string signed_greater_than(Node, Node a, Node b, int) { return format("(bvsgt %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string signed_greater_than(Node, Node a, Node b, int) { 
 | 
			
		||||
    return format("(ite (bvsgt %0 %1) #b1 #b0)", np(a), np(b)); 
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string signed_greater_equal(Node, Node a, Node b, int) { return format("(bvsge %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string signed_greater_equal(Node, Node a, Node b, int) {
 | 
			
		||||
    return format("(ite (bvsge %0 %1) #b1 #b0)", np(a), np(b));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string unsigned_greater_than(Node, Node a, Node b, int) { return format("(bvugt %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string unsigned_greater_than(Node, Node a, Node b, int) { 
 | 
			
		||||
    return format("(ite (bvugt %0 %1) #b1 #b0)", np(a), np(b)); 
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string unsigned_greater_equal(Node, Node a, Node b, int) { return format("(bvuge %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string unsigned_greater_equal(Node, Node a, Node b, int) { 
 | 
			
		||||
    return format("(ite (bvuge %0 %1) #b1 #b0)", np(a), np(b)); 
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string logical_shift_left(Node, Node a, Node b, int, int) { return format("(bvshl %0 %1)", np(a), np(b)); }
 | 
			
		||||
  std::string logical_shift_left(Node, Node a, Node b, int, int) {
 | 
			
		||||
    // Get the bit-widths of a and b
 | 
			
		||||
    int bit_width_a = a.width();
 | 
			
		||||
    int bit_width_b = b.width();
 | 
			
		||||
 | 
			
		||||
	std::string logical_shift_right(Node, Node a, Node b, int, int) { return format("(bvlshr %0 %1)", np(a), np(b)); }
 | 
			
		||||
    // Extend b to match the bit-width of a if necessary
 | 
			
		||||
    std::ostringstream oss;
 | 
			
		||||
    if (bit_width_a > bit_width_b) {
 | 
			
		||||
      oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
 | 
			
		||||
    } else {
 | 
			
		||||
      oss << np(b);  // No extension needed if b's width is already sufficient
 | 
			
		||||
    }
 | 
			
		||||
    std::string b_extended = oss.str();
 | 
			
		||||
 | 
			
		||||
	std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("(bvashr %0 %1)", np(a), np(b)); }
 | 
			
		||||
    // Format the bvshl operation with the extended b
 | 
			
		||||
    oss.str(""); // Clear the stringstream
 | 
			
		||||
    oss << "(bvshl " << np(a) << " " << b_extended << ")";
 | 
			
		||||
    return oss.str();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string mux(Node, Node a, Node b, Node s, int) { return format("(ite %2 %0 %1)", np(a), np(b), np(s)); }
 | 
			
		||||
  std::string logical_shift_right(Node, Node a, Node b, int, int) {
 | 
			
		||||
    // Get the bit-widths of a and b
 | 
			
		||||
    int bit_width_a = a.width();
 | 
			
		||||
    int bit_width_b = b.width();
 | 
			
		||||
 | 
			
		||||
    // Extend b to match the bit-width of a if necessary
 | 
			
		||||
    std::ostringstream oss;
 | 
			
		||||
    if (bit_width_a > bit_width_b) {
 | 
			
		||||
      oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
 | 
			
		||||
    } else {
 | 
			
		||||
      oss << np(b);  // No extension needed if b's width is already sufficient
 | 
			
		||||
    }
 | 
			
		||||
    std::string b_extended = oss.str();
 | 
			
		||||
 | 
			
		||||
    // Format the bvlshr operation with the extended b
 | 
			
		||||
    oss.str(""); // Clear the stringstream
 | 
			
		||||
    oss << "(bvlshr " << np(a) << " " << b_extended << ")";
 | 
			
		||||
    return oss.str();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string arithmetic_shift_right(Node, Node a, Node b, int, int) {
 | 
			
		||||
    // Get the bit-widths of a and b
 | 
			
		||||
    int bit_width_a = a.width();
 | 
			
		||||
    int bit_width_b = b.width();
 | 
			
		||||
 | 
			
		||||
    // Extend b to match the bit-width of a if necessary
 | 
			
		||||
    std::ostringstream oss;
 | 
			
		||||
    if (bit_width_a > bit_width_b) {
 | 
			
		||||
      oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
 | 
			
		||||
    } else {
 | 
			
		||||
      oss << np(b);  // No extension needed if b's width is already sufficient
 | 
			
		||||
    }
 | 
			
		||||
    std::string b_extended = oss.str();
 | 
			
		||||
 | 
			
		||||
    // Format the bvashr operation with the extended b
 | 
			
		||||
    oss.str(""); // Clear the stringstream
 | 
			
		||||
    oss << "(bvashr " << np(a) << " " << b_extended << ")";
 | 
			
		||||
    return oss.str();
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
  std::string mux(Node, Node a, Node b, Node s, int) {
 | 
			
		||||
    return format("(ite (= %2 #b1) %0 %1)", np(a), np(b), np(s));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string pmux(Node, Node a, Node b, Node s, int, int)
 | 
			
		||||
  {
 | 
			
		||||
| 
						 | 
				
			
			@ -173,11 +256,11 @@ template <class NodeNames> struct SmtPrintVisitor {
 | 
			
		|||
    return format("(pmux %0 %1 %2)", np(a), np(b), np(s));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
	std::string constant(Node, RTLIL::Const value) { return format("#b%1", value.as_string()); }
 | 
			
		||||
  std::string constant(Node, RTLIL::Const value) { return format("#b%0", value.as_string()); }
 | 
			
		||||
 | 
			
		||||
  std::string input(Node, IdString name) { return format("%0", scope[name]); }
 | 
			
		||||
 | 
			
		||||
	std::string state(Node, IdString name) { return format("%0", scope[name]); }
 | 
			
		||||
  std::string state(Node, IdString name) { return format("(%0 current_state)", scope[name]); }
 | 
			
		||||
 | 
			
		||||
  std::string memory_read(Node, Node mem, Node addr, int, int) { return format("(select %0 %1)", np(mem), np(addr)); }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -195,15 +278,17 @@ struct SmtModule {
 | 
			
		|||
 | 
			
		||||
  void write(std::ostream &out)
 | 
			
		||||
  {
 | 
			
		||||
    const bool stateful = ir.state().size() != 0;
 | 
			
		||||
    SmtWriter writer(out);
 | 
			
		||||
 | 
			
		||||
		writer.print("(declare-fun %s () Bool)\n", name.c_str());
 | 
			
		||||
    writer.print("(declare-fun %s () Bool)\n\n", name.c_str());
 | 
			
		||||
 | 
			
		||||
    writer.print("(declare-datatypes ()  ((Inputs    (mk_inputs");
 | 
			
		||||
    for (const auto &input : ir.inputs()) {
 | 
			
		||||
      std::string input_name = scope[input.first];
 | 
			
		||||
      writer.print(" (%s (_ BitVec %d))", input_name.c_str(), input.second.width());
 | 
			
		||||
    }
 | 
			
		||||
		writer.print("))))\n");
 | 
			
		||||
    writer.print("))))\n\n");
 | 
			
		||||
 | 
			
		||||
    writer.print("(declare-datatypes () ((Outputs    (mk_outputs");
 | 
			
		||||
    for (const auto &output : ir.outputs()) {
 | 
			
		||||
| 
						 | 
				
			
			@ -212,9 +297,22 @@ struct SmtModule {
 | 
			
		|||
    }
 | 
			
		||||
    writer.print("))))\n");
 | 
			
		||||
 | 
			
		||||
		writer.print("(declare-fun state () (_ BitVec 1))\n");
 | 
			
		||||
    if (stateful) {
 | 
			
		||||
      writer.print("(declare-datatypes ()  ((State   (mk_state");
 | 
			
		||||
      for (const auto &state : ir.state()) {
 | 
			
		||||
	std::string state_name = scope[state.first];
 | 
			
		||||
	writer.print(" (%s (_ BitVec %d))", state_name.c_str(), state.second.width());
 | 
			
		||||
      }
 | 
			
		||||
      writer.print("))))\n");
 | 
			
		||||
 | 
			
		||||
      writer.print("(declare-datatypes ()  ((Pair   (mk-pair (outputs Outputs) (next_state State)))))\n");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (stateful)
 | 
			
		||||
      writer.print("(define-fun %s_step ((current_state State) (inputs Inputs)) Pair", name.c_str());
 | 
			
		||||
    else
 | 
			
		||||
      writer.print("(define-fun %s_step ((inputs Inputs)) Outputs", name.c_str());
 | 
			
		||||
 | 
			
		||||
		writer.print("(define-fun %s_step ((state (_ BitVec 1)) (inputs Inputs)) Outputs", name.c_str());
 | 
			
		||||
    writer.print(" (let (");
 | 
			
		||||
    for (const auto &input : ir.inputs()) {
 | 
			
		||||
      std::string input_name = scope[input.first];
 | 
			
		||||
| 
						 | 
				
			
			@ -237,28 +335,45 @@ struct SmtModule {
 | 
			
		|||
      writer.print(" (let ( (%s %s))", node_name.c_str(), node_expr.c_str());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
		writer.print("    (let (");
 | 
			
		||||
		for (const auto &output : ir.outputs()) {
 | 
			
		||||
			std::string output_name = scope[output.first];
 | 
			
		||||
			const std::string output_assignment = ir.get_output_node(output.first).name().c_str();
 | 
			
		||||
			writer.print("      (%s %s)", output_name.c_str(), output_assignment.substr(1).c_str());
 | 
			
		||||
    if (stateful) {
 | 
			
		||||
      writer.print(" (let ( (next_state (mk_state ");
 | 
			
		||||
      for (const auto &state : ir.state()) {
 | 
			
		||||
	std::string state_name = scope[state.first];
 | 
			
		||||
	const std::string state_assignment = ir.get_state_next_node(state.first).name().c_str();
 | 
			
		||||
	writer.print(" %s", state_assignment.substr(1).c_str());
 | 
			
		||||
      }
 | 
			
		||||
		writer.print("    )");
 | 
			
		||||
		writer.print("      (mk_outputs");
 | 
			
		||||
      writer.print(" )))");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (stateful) {
 | 
			
		||||
      writer.print(" (let ( (outputs (mk_outputs ");
 | 
			
		||||
      for (const auto &output : ir.outputs()) {
 | 
			
		||||
	std::string output_name = scope[output.first];
 | 
			
		||||
	writer.print(" %s", output_name.c_str());
 | 
			
		||||
      }
 | 
			
		||||
		writer.print("      )");
 | 
			
		||||
		writer.print("    )");
 | 
			
		||||
      writer.print(" )))");
 | 
			
		||||
 | 
			
		||||
      writer.print("(mk-pair outputs next_state)");
 | 
			
		||||
    }
 | 
			
		||||
    else {
 | 
			
		||||
      writer.print(" (mk_outputs ");
 | 
			
		||||
      for (const auto &output : ir.outputs()) {
 | 
			
		||||
	std::string output_name = scope[output.first];
 | 
			
		||||
	writer.print(" %s", output_name.c_str());
 | 
			
		||||
      }
 | 
			
		||||
      writer.print(" )"); // Closing mk_outputs 
 | 
			
		||||
    }
 | 
			
		||||
    if (stateful) {
 | 
			
		||||
      writer.print(" )"); // Closing outputs let statement
 | 
			
		||||
      writer.print(" )"); // Closing next_state let statement
 | 
			
		||||
    }
 | 
			
		||||
    // Close the nested lets
 | 
			
		||||
    for (size_t i = 0; i < ir.size() - ir.inputs().size(); ++i) {
 | 
			
		||||
			writer.print("  )");
 | 
			
		||||
      writer.print(" )"); // Closing each node
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
		writer.print("  )");
 | 
			
		||||
		writer.print(")\n");
 | 
			
		||||
    writer.print(" )"); // Closing inputs let statement
 | 
			
		||||
    writer.print(")\n"); // Closing step function
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,6 +87,10 @@ run_smt_test() {
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
run_all_tests() {
 | 
			
		||||
    declare -A cxx_failing_files
 | 
			
		||||
    declare -A smt_failing_files
 | 
			
		||||
    declare -A cxx_successful_files
 | 
			
		||||
    declare -A smt_successful_files
 | 
			
		||||
    return_code=0
 | 
			
		||||
    for rtlil_file in rtlil/*.il; do
 | 
			
		||||
        run_cxx_test "$rtlil_file"
 | 
			
		||||
| 
						 | 
				
			
			@ -134,6 +138,8 @@ run_all_tests() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
run_smt_tests() {
 | 
			
		||||
    declare -A smt_failing_files
 | 
			
		||||
    declare -A smt_successful_files
 | 
			
		||||
    return_code=0
 | 
			
		||||
    for rtlil_file in rtlil/*.il; do
 | 
			
		||||
        run_smt_test "$rtlil_file"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -76,6 +76,8 @@ for lst in parsed_results:
 | 
			
		|||
            declarations = datatype_group[1][1:]  # Skip the first item (e.g., 'mk_inputs')
 | 
			
		||||
            if datatype_name == 'Inputs':
 | 
			
		||||
                for declaration in declarations:
 | 
			
		||||
                    print("My declaration")
 | 
			
		||||
                    print(declaration)
 | 
			
		||||
                    input_name = declaration[0]
 | 
			
		||||
                    bitvec_size = declaration[1][2]
 | 
			
		||||
                    inputs[input_name] = int(bitvec_size)
 | 
			
		||||
| 
						 | 
				
			
			@ -103,7 +105,7 @@ def set_step(inputs, step):
 | 
			
		|||
    define_inputs = f"(define-const test_inputs_step_n{step} Inputs ({mk_inputs_call}))\n"
 | 
			
		||||
 | 
			
		||||
    # Create the output definition by calling the gold_step function
 | 
			
		||||
    define_output = f"(define-const test_outputs_step_n{step} Outputs (gold_step #b0 test_inputs_step_n{step}))\n"
 | 
			
		||||
    define_output = f"(define-const test_outputs_step_n{step} Outputs (gold_step test_inputs_step_n{step}))\n"
 | 
			
		||||
    smt_commands = []
 | 
			
		||||
    smt_commands.append(define_inputs)
 | 
			
		||||
    smt_commands.append(define_output)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue