Loading...
Searching...
No Matches
VariableBranching.h
1//
2// Created by henri on 16.10.23.
3//
4
5#ifndef IDOL_VARIABLEBRANCHING_BRANCHINGRULE_H
6#define IDOL_VARIABLEBRANCHING_BRANCHINGRULE_H
7
#include <cmath>
#include <limits>
#include <list>
#include <optional>
#include <utility>
#include "BranchingRule.h"
#include "idol/mixed-integer/modeling/variables/Var.h"
#include "idol/general/utils/Point.h"
12
namespace idol::BranchingRules {
    // Forward declaration: the class is defined below with a qualified name,
    // which requires it to have been declared inside its namespace first.
    template<class NodeInfoT>
    class VariableBranching;
}
17
18template<class NodeInfoT>
19class idol::BranchingRules::VariableBranching : public BranchingRule<NodeInfoT> {
20 std::list<Var> m_branching_candidates;
21public:
22
23 virtual bool is_valid(const Node<NodeInfoT> &t_node) {
24
25 const auto& primal_solution = t_node.info().primal_solution();
26 const double tol_integer = this->parent().get_tol_integer();
27
28 for (const auto& var : m_branching_candidates) {
29 if (const double value = primal_solution.get(var) ; !is_integer(value, tol_integer)) {
30 return false;
31 }
32 }
33
34 return true;
35 }
36
37 virtual std::list<std::pair<Var, double>> scoring_function(const std::list<Var>& t_variables, const Node<NodeInfoT> &t_node) = 0;
38
39 virtual std::list<NodeInfoT *> create_child_nodes_for_selected_variable(const Node<NodeInfoT> &t_node, const Var& t_var) {
40
41 const auto& primal_solution = t_node.info().primal_solution();
42 const double value = primal_solution.get(t_var);
43 const double lb = std::ceil(value);
44 const double ub = std::floor(value);
45
46 auto* n1 = t_node.info().create_child();
47 n1->add_branching_variable(t_var, GreaterOrEqual, lb);
48
49 auto* n2 = t_node.info().create_child();
50 n2->add_branching_variable(t_var, LessOrEqual, ub);
51
52 return { n1, n2 };
53 }
54
55 virtual std::list<NodeInfoT *> create_child_nodes(const Node<NodeInfoT> &t_node) {
56
57 const auto& primal_solution = t_node.info().primal_solution();
58
59 auto invalid_variables = get_invalid_variables(primal_solution);
60
61 if (invalid_variables.empty()) {
62 return {};
63 }
64
65 Var selected_variable = invalid_variables.front();
66
67 if (invalid_variables.size() > 1) {
68 auto scores = scoring_function(invalid_variables, t_node);
69 selected_variable = get_argmax_score(scores);
70 }
71
72 return create_child_nodes_for_selected_variable(t_node, selected_variable);
73 }
74
75 [[nodiscard]] const std::list<Var>& branching_candidates() const { return m_branching_candidates; }
76
77 VariableBranching(const Optimizers::BranchAndBound<NodeInfoT>& t_parent, std::list<Var> t_branching_candidates)
78 : BranchingRule<NodeInfoT>(t_parent),
79 m_branching_candidates(std::move(t_branching_candidates)) {
80
81 }
82
83protected:
84 std::list<Var> get_invalid_variables(const PrimalPoint& t_primal_solution) {
85
86 std::list<Var> result;
87 const double tol_integer = this->parent().get_tol_integer();
88
89 for (const auto& var : m_branching_candidates) {
90 if (const double value = t_primal_solution.get(var) ; !is_integer(value, tol_integer)) {
91 result.emplace_back(var);
92 }
93 }
94
95 return result;
96 }
97
98 Var get_argmax_score(const std::list<std::pair<Var, double>>& t_scores) {
99
100 if (t_scores.empty()) {
101 throw Exception("VariableScoringFunction returned an empty list.");
102 }
103
104 double max = std::numeric_limits<double>::lowest();
105 std::optional<Var> argmax;
106
107 for (const auto& [var, score] : t_scores) {
108 if (max < score) {
109 max = score;
110 argmax = var;
111 }
112 }
113
114 if (!argmax.has_value()) {
115 throw Exception("Could not compute argmax of scores while searching for branching variable.");
116 }
117
118 return argmax.value();
119 }
120
121};
122
123#endif //IDOL_VARIABLEBRANCHING_BRANCHINGRULE_H