5#ifndef IDOL_IMPL_STRONGBRANCHING_H
6#define IDOL_IMPL_STRONGBRANCHING_H
8#include "VariableBranching.h"
9#include "NodeScoreFunction.h"
10#include "idol/mixed-integer/optimizers/branch-and-bound/branching-rules/impls/strong-branching/StrongBranchingPhase.h"
11#include "MostInfeasible.h"
12#include "idol/mixed-integer/optimizers/branch-and-bound/branching-rules/factories/StrongBranching.h"
14namespace idol::BranchingRules {
// Forward declaration of the StrongBranching class template, parameterised
// on the node-info type NodeInfoT.
15 template<
class NodeInfoT>
class StrongBranching;
// Member declarations of StrongBranching<NodeInfoT>.
// NOTE(review): the embedded line-number residue jumps from 18 to 20, so the
// class head (name and base class) is not visible here. scoring_function
// below is marked `override`, so a base class must declare it virtual.
18template<
class NodeInfoT>
// Owned inner branching rule. In this file it is used to pre-score the
// candidate variables and to create the child node infos for a candidate.
20 std::unique_ptr<VariableBranching<NodeInfoT>> m_inner_variable_branching_rule;
// Owned score function. It is invoked with the left and right child
// objective differences (child objective minus parent objective).
21 std::unique_ptr<NodeScoreFunction> m_score_function;
// Strong branching phases; searched by node level when picking a phase.
22 std::list<StrongBranchingPhase> m_phases;
// Copies the (variable, score) list into a vector sorted by decreasing score.
24 std::vector<std::pair<Var, double>> sort_variables_by_score(
const std::list<std::pair<Var, double>>& t_scores);
// Wraps each node info pointer into a Node built from the info, the parent
// node's id and the parent node.
26 std::list<Node<NodeInfoT>> make_nodes(
const std::list<NodeInfoT*>& t_node_infos,
const Node<NodeInfoT>& t_parent_node);
// Scores a pair of solved child nodes against the parent objective value.
// Throws Exception when t_nodes does not contain exactly two nodes.
32 double compute_score(
double t_parent_objective, std::list<
Node<NodeInfoT>>& t_nodes);
// Constructor parameter list fragment.
// NOTE(review): residue numbers 34 and 37 are absent, so the constructor
// name and at least one parameter are not visible in this view.
35 std::list<Var> t_branching_candidates,
36 unsigned int t_max_n_variables,
38 const std::list<StrongBranchingPhase>& t_phases);
// Returns a (variable, score) pair for each candidate that was evaluated by
// solving its child nodes.
40 std::list<std::pair<Var, double>> scoring_function(
const std::list<Var> &t_var,
const Node<NodeInfoT> &t_node)
override;
// Constructor definition (fragment).
// NOTE(review): the residue numbering has gaps (44-45, 48, 50-51, 54-57,
// 59-60), so the qualified constructor name, some parameters (including the
// declarations of t_parent and t_score_function), other initialisers and the
// condition guarding the emplace_back below are not visible in this view.
43template<
class NodeInfoT>
46 std::list<Var> t_branching_candidates,
47 unsigned int t_max_n_variables,
49 const std::list<StrongBranchingPhase>& t_phases
// The inner rule is a MostInfeasible rule built from t_parent and an empty
// braced argument.
52 m_inner_variable_branching_rule(new BranchingRules::
MostInfeasible<NodeInfoT>(t_parent, {})),
// The score function is stored as a clone of the one passed in.
53 m_score_function(t_score_function->clone()),
// Appends a phase to m_phases; one of its arguments is the largest
// unsigned int value.
// NOTE(review): presumably this is a catch-all final phase whose max depth
// is unbounded. Confirm against the missing lines.
58 m_phases.emplace_back(
61 std::numeric_limits<unsigned int>::max()
// scoring_function definition (fragment): scores candidates by actually
// solving the child nodes each would create.
// NOTE(review): residue line 67 (function name and the t_variables
// parameter) and lines 89-94 (loop close, presumably `return result;`, and
// the function close) are not visible in this view.
65template<
class NodeInfoT>
66std::list<std::pair<idol::Var, double>>
68 const Node<NodeInfoT> &t_node) {
70 std::list<std::pair<Var, double>> result;
// Pick the phase that applies to this node.
72 auto& phase = current_phase(t_node);
// Pre-score every candidate with the inner rule, then order the candidates
// by decreasing inner score.
73 const auto scores = m_inner_variable_branching_rule->scoring_function(t_variables, t_node);
74 const auto sorted_scores = sort_variables_by_score(scores);
// Evaluate at most phase.max_n_variables() candidates, and never more than
// there are.
75 const unsigned int n_nodes_to_solve = std::min<unsigned int>(phase.max_n_variables(), sorted_scores.size());
77 const double objective_value_parent_node = t_node.info().objective_value();
79 for (
unsigned int k = 0 ; k < n_nodes_to_solve ; ++k) {
81 const auto branching_candidate = sorted_scores[k].first;
// Create the child nodes for this candidate and solve them under the phase.
83 auto node_infos = m_inner_variable_branching_rule->create_child_nodes_for_selected_variable(t_node, branching_candidate);
84 auto nodes = make_nodes(node_infos, t_node);
86 solve_nodes(phase, nodes);
// Record the candidate with the score computed from its solved children.
88 result.emplace_back(branching_candidate, compute_score(objective_value_parent_node, nodes));
// Copies the (variable, score) list into a vector and sorts it by
// decreasing score.
// std::sort is not stable, so the relative order of equal scores is
// unspecified.
// NOTE(review): residue line 96 (return type and function name) and lines
// 106-109 (lambda close, presumably `return result;`, function close) are
// not visible in this view.
95template<
class NodeInfoT>
97 const std::list<std::pair<Var, double>> &t_scores) {
99 std::vector<std::pair<Var, double>> result;
// Reserve once so the copy below does not reallocate.
100 result.reserve(t_scores.size());
102 std::copy(t_scores.begin(), t_scores.end(), std::back_inserter(result));
// Comparator orders pairs by their score (.second), largest first.
104 std::sort(result.begin(), result.end(), [](
const auto& t_a,
const auto& t_b) {
105 return t_a.second > t_b.second;
// Builds one Node per node-info pointer, each constructed from the info, the
// parent's id and the parent node.
// NOTE(review): residue line 112 (function name and the t_node_infos
// parameter) and lines 121-125 (loop close, presumably `return result;`,
// function close) are not visible in this view.
110template<
class NodeInfoT>
111std::list<idol::Node<NodeInfoT>>
113 const Node<NodeInfoT>& t_parent_node) {
115 std::list<idol::Node<NodeInfoT>> result;
// Every node created here is passed the parent's id.
// NOTE(review): presumably intentional for temporary strong-branching nodes.
// Confirm against the Node constructor.
117 const unsigned int id = t_parent_node.id();
119 for (
auto* info : t_node_infos) {
120 result.emplace_back(info,
id, t_parent_node);
// Solves each given node through the parent branch-and-bound, with the
// phase's type applied to the relaxation optimizer around each solve.
// NOTE(review): residue lines 127-128 (return type, function name and the
// t_phase / t_nodes parameters) and the closing braces are not visible in
// this view.
126template<
class NodeInfoT>
129 auto& branch_and_bound = this->parent();
// The cast strips constness from the relaxation's optimizer so that
// build/clean can be called with it.
130 auto& optimizer =
const_cast<Optimizer&
>(branch_and_bound.relaxation().optimizer());
132 for (
auto& node : t_nodes) {
// build, solve and clean run once per node, in that order.
134 t_phase.type().build(optimizer);
136 branch_and_bound.solve(node, 0);
// NOTE(review): no exception handling is visible here. If solve throws,
// clean is skipped for that node. Confirm whether that is acceptable.
138 t_phase.type().clean(optimizer);
// Computes the score of a branching candidate from its two solved child
// nodes. Throws Exception when t_nodes does not hold exactly two nodes.
// NOTE(review): residue line 144 (return type, function name and the
// t_parent_objective parameter) and the closing braces are not visible in
// this view.
143template<
class NodeInfoT>
145 std::list<Node<NodeInfoT>>& t_nodes) {
147 if (t_nodes.size() != 2) {
148 throw Exception(
"Strong branching expected two nodes, got " + std::to_string(t_nodes.size()) +
".");
// The first node is treated as the left child, the last as the right child.
151 const auto& left_node_info = t_nodes.front().info();
152 const auto& right_node_info = t_nodes.back().info();
// A child that reports no objective value is given the objective Inf.
154 const double left_objective_value = left_node_info.has_objective_value() ? left_node_info.objective_value() : Inf;
155 const double right_objective_value = right_node_info.has_objective_value() ? right_node_info.objective_value() : Inf;
// The score function receives each child's objective minus the parent's.
157 return m_score_function->operator()(
158 left_objective_value - t_parent_objective,
159 right_objective_value - t_parent_objective);
// Selects the strong branching phase that applies to a node, based on the
// node's level. Throws Exception when no phase in m_phases matches.
// NOTE(review): residue lines 164-166 (return type, function name,
// parameter) and 171-174 (the body of the `if`, presumably returning the
// matching phase, plus closing braces) are not visible in this view.
163template<
class NodeInfoT>
167 const unsigned int level = t_node.level();
// Phases are scanned in list order; the test is level <= phase.max_depth().
169 for (
auto& phase : m_phases) {
170 if (level <= phase.max_depth()) {
175 throw Exception(
"Could not infer strong branching phase.");