try out z3 for 21

This commit is contained in:
Eric Mertens 2023-04-06 20:58:54 -07:00
parent d137b29608
commit 509672bc9e
3 changed files with 43 additions and 26 deletions

View File

@ -11,6 +11,8 @@
#include <doctest.h> #include <doctest.h>
#include <z3++.h>
#include <aocpp/Overloaded.hpp> #include <aocpp/Overloaded.hpp>
#include <aocpp/Startup.hpp> #include <aocpp/Startup.hpp>
@ -122,35 +124,49 @@ auto Part1(Input const& input) -> std::int64_t
auto Part2(Input const& input) -> std::int64_t auto Part2(Input const& input) -> std::int64_t
{ {
std::unordered_map<std::string, double> values; auto ctx = z3::context{};
std::unordered_map<std::string, Expr> exprs; auto solve_eqs = z3::tactic{ctx, "solve-eqs"};
auto smt = z3::tactic{ctx, "smt"};
auto solver = (solve_eqs & smt).mk_solver();
std::unordered_map<std::string, z3::expr> constants;
for (auto const& entry : input) { for (auto const& entry : input) {
std::visit(overloaded { constants.emplace(entry.lvalue, ctx.int_const(entry.lvalue.c_str()));
[&](Expr const& expr) { exprs.emplace(entry.lvalue, expr); },
[&](std::int64_t val) { values.emplace(entry.lvalue, val); }
}, entry.rvalue);
} }
exprs.at("root").op = Op::Sub; for (auto const& entry : input) {
if (entry.lvalue == "root") {
auto eval = [&](double humn) -> double { if (auto * const expr = std::get_if<Expr>(&entry.rvalue)) {
auto values_ = values; solver.add(constants.at(expr->lhs) == constants.at(expr->rhs));
values_.at("humn") = humn; } else {
return Eval(values_, exprs, "root"); throw std::runtime_error{"malformed root"};
}; }
} else if (entry.lvalue != "humn") {
auto x0 = 10.0; auto const rhs =
auto x1 = 20.0; std::visit(overloaded {
auto fx0 = eval(x0); [&ctx](std::int64_t literal) {
while (std::abs(x0 - x1) > 0.01) { return ctx.int_val(literal);
auto const fx1 = eval(x1); },
auto const x2 = x1 - fx1 * (x1 - x0) / (fx1 - fx0); [&ctx, &constants, &solver](Expr const& expr) {
x0 = x1; auto const l = constants.at(expr.lhs);
x1 = x2; auto const r = constants.at(expr.rhs);
fx0 = fx1; switch (expr.op) {
case Op::Add: return l + r;
case Op::Sub: return l - r;
case Op::Mul: return l * r;
case Op::Div:
solver.add(l % r == 0);
return l / r;
}
}
}, entry.rvalue);
solver.add(constants.at(entry.lvalue) == rhs);
}
} }
if (solver.check() != z3::sat) {
return std::round(x1); throw std::runtime_error{"no solution to part 2"};
}
return solver.get_model().eval(constants.at("humn")).as_int64();
} }
} // namespace } // namespace

View File

@ -47,7 +47,7 @@ add_executable(2022_20 20.cpp)
target_link_libraries(2022_20 aocpp Boost::headers) target_link_libraries(2022_20 aocpp Boost::headers)
add_executable(2022_21 21.cpp) add_executable(2022_21 21.cpp)
target_link_libraries(2022_21 aocpp Boost::headers) target_link_libraries(2022_21 aocpp Boost::headers PkgConfig::Z3)
add_executable(2022_25 25.cpp) add_executable(2022_25 25.cpp)
target_link_libraries(2022_25 aocpp) target_link_libraries(2022_25 aocpp)

View File

@ -18,6 +18,7 @@ endif()
find_package(PkgConfig) find_package(PkgConfig)
pkg_check_modules(GMP REQUIRED IMPORTED_TARGET gmpxx) pkg_check_modules(GMP REQUIRED IMPORTED_TARGET gmpxx)
pkg_check_modules(Z3 REQUIRED IMPORTED_TARGET z3)
find_package(Boost REQUIRED) find_package(Boost REQUIRED)
add_subdirectory(lib) add_subdirectory(lib)