try out z3 for 21

This commit is contained in:
Eric Mertens 2023-04-06 20:58:54 -07:00
parent d137b29608
commit 509672bc9e
3 changed files with 43 additions and 26 deletions

View File

@ -11,6 +11,8 @@
#include <doctest.h>
#include <z3++.h>
#include <aocpp/Overloaded.hpp>
#include <aocpp/Startup.hpp>
@ -122,35 +124,49 @@ auto Part1(Input const& input) -> std::int64_t
auto Part2(Input const& input) -> std::int64_t
{
std::unordered_map<std::string, double> values;
std::unordered_map<std::string, Expr> exprs;
auto ctx = z3::context{};
auto solve_eqs = z3::tactic{ctx, "solve-eqs"};
auto smt = z3::tactic{ctx, "smt"};
auto solver = (solve_eqs & smt).mk_solver();
std::unordered_map<std::string, z3::expr> constants;
for (auto const& entry : input) {
std::visit(overloaded {
[&](Expr const& expr) { exprs.emplace(entry.lvalue, expr); },
[&](std::int64_t val) { values.emplace(entry.lvalue, val); }
}, entry.rvalue);
constants.emplace(entry.lvalue, ctx.int_const(entry.lvalue.c_str()));
}
exprs.at("root").op = Op::Sub;
auto eval = [&](double humn) -> double {
auto values_ = values;
values_.at("humn") = humn;
return Eval(values_, exprs, "root");
};
auto x0 = 10.0;
auto x1 = 20.0;
auto fx0 = eval(x0);
while (std::abs(x0 - x1) > 0.01) {
auto const fx1 = eval(x1);
auto const x2 = x1 - fx1 * (x1 - x0) / (fx1 - fx0);
x0 = x1;
x1 = x2;
fx0 = fx1;
for (auto const& entry : input) {
if (entry.lvalue == "root") {
if (auto * const expr = std::get_if<Expr>(&entry.rvalue)) {
solver.add(constants.at(expr->lhs) == constants.at(expr->rhs));
} else {
throw std::runtime_error{"malformed root"};
}
} else if (entry.lvalue != "humn") {
auto const rhs =
std::visit(overloaded {
[&ctx](std::int64_t literal) {
return ctx.int_val(literal);
},
[&ctx, &constants, &solver](Expr const& expr) {
auto const l = constants.at(expr.lhs);
auto const r = constants.at(expr.rhs);
switch (expr.op) {
case Op::Add: return l + r;
case Op::Sub: return l - r;
case Op::Mul: return l * r;
case Op::Div:
solver.add(l % r == 0);
return l / r;
}
}
}, entry.rvalue);
solver.add(constants.at(entry.lvalue) == rhs);
}
}
return std::round(x1);
if (solver.check() != z3::sat) {
throw std::runtime_error{"no solution to part 2"};
}
return solver.get_model().eval(constants.at("humn")).as_int64();
}
} // namespace

View File

@ -47,7 +47,7 @@ add_executable(2022_20 20.cpp)
target_link_libraries(2022_20 aocpp Boost::headers)
add_executable(2022_21 21.cpp)
target_link_libraries(2022_21 aocpp Boost::headers)
target_link_libraries(2022_21 aocpp Boost::headers PkgConfig::Z3)
add_executable(2022_25 25.cpp)
target_link_libraries(2022_25 aocpp)

View File

@ -18,6 +18,7 @@ endif()
find_package(PkgConfig)
pkg_check_modules(GMP REQUIRED IMPORTED_TARGET gmpxx)
pkg_check_modules(Z3 REQUIRED IMPORTED_TARGET z3)
find_package(Boost REQUIRED)
add_subdirectory(lib)