Skip to content

Commit 2ef70b0

Browse files
committed
Fix simplex implementation: maxiter error + pivot optimization
1 parent 00bfa53 commit 2ef70b0

1 file changed

Lines changed: 51 additions & 42 deletions

File tree

linear_programming/simplex.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
1-
from typing import Any
1+
from typing import Dict
22
import numpy as np
33

44

55
class Tableau:
6-
"""Operate on simplex tableaus"""
6+
"""
7+
Simplex algorithm implementation using tableau form.
8+
9+
Solves linear programming problems of the form:
10+
maximize c^T x subject to Ax <= b and x >= 0
11+
"""
712

813
maxiter = 100
914

1015
def __init__(
1116
self, tableau: np.ndarray, n_vars: int, n_artificial_vars: int
1217
) -> None:
18+
tableau = tableau.astype("float64")
19+
1320
if tableau.dtype != "float64":
1421
raise TypeError("Tableau must have type float64")
1522

1623
if not (tableau[:, -1] >= 0).all():
1724
raise ValueError("RHS must be > 0")
1825

1926
if n_vars < 2 or n_artificial_vars < 0:
20-
raise ValueError(
21-
"number of (artificial) variables must be a natural number"
22-
)
27+
raise ValueError("Invalid number of variables")
2328

2429
self.tableau = tableau
2530
self.n_rows, n_cols = tableau.shape
@@ -40,20 +45,22 @@ def __init__(
4045

4146
def generate_col_titles(self) -> list[str]:
4247
string_starts = ["x", "s"]
43-
titles = []
48+
sizes = [self.n_vars, self.n_slack]
49+
50+
titles: list[str] = []
4451

4552
for i in range(2):
46-
for j in range((self.n_vars, self.n_slack)[i]):
53+
for j in range(sizes[i]):
4754
titles.append(string_starts[i] + str(j + 1))
4855

4956
titles.append("RHS")
5057
return titles
5158

52-
def find_pivot(self) -> tuple[Any, Any]:
59+
def find_pivot(self) -> tuple[int, int]:
5360
objective = self.objectives[-1]
5461

5562
sign = (objective == "min") - (objective == "max")
56-
col_idx = np.argmax(sign * self.tableau[0, :-1])
63+
col_idx = int(np.argmax(sign * self.tableau[0, :-1]))
5764

5865
if sign * self.tableau[0, col_idx] <= 0:
5966
self.stop_iter = True
@@ -65,27 +72,24 @@ def find_pivot(self) -> tuple[Any, Any]:
6572
divisor = self.tableau[s, col_idx]
6673

6774
nans = np.full(self.n_rows - self.n_stages, np.nan)
68-
quotients = np.divide(dividend, divisor, out=nans, where=divisor > 0)
6975

70-
row_idx = np.nanargmin(quotients) + self.n_stages
76+
quotients = np.divide(
77+
dividend, divisor, out=nans, where=divisor > 0
78+
)
79+
80+
row_idx = int(np.nanargmin(quotients) + self.n_stages)
7181
return row_idx, col_idx
7282

73-
# 🔥 OPTIMIZED PIVOT (major speed improvement)
7483
def pivot(self, row_idx: int, col_idx: int) -> np.ndarray:
7584
tableau = self.tableau
7685

7786
piv_row = tableau[row_idx].copy()
78-
piv_val = piv_row[col_idx]
87+
piv_row /= piv_row[col_idx]
7988

80-
# normalize pivot row
81-
piv_row /= piv_val
82-
83-
# vectorized elimination (FAST)
8489
tableau -= tableau[:, col_idx][:, None] * piv_row
85-
8690
tableau[row_idx] = piv_row
87-
self.tableau = tableau
8891

92+
self.tableau = tableau
8993
return tableau
9094

9195
def change_stage(self) -> np.ndarray:
@@ -106,12 +110,21 @@ def change_stage(self) -> np.ndarray:
106110

107111
return self.tableau
108112

109-
def run_simplex(self) -> dict[Any, Any]:
110-
"""Run simplex algorithm until optimal solution is found."""
113+
def run_simplex(self) -> Dict[str, float]:
114+
"""
115+
Run simplex algorithm until optimal solution is found.
111116
112-
maxiter = Tableau.maxiter
117+
>>> t = Tableau(np.array([
118+
... [-1, -1, 0, 0, 0],
119+
... [1, 3, 1, 0, 4],
120+
... [3, 1, 0, 1, 4]
121+
... ], dtype="float64"), 2, 0)
122+
>>> result = t.run_simplex()
123+
>>> result["P"]
124+
2.0
125+
"""
113126

114-
for iteration in range(maxiter):
127+
for iteration in range(Tableau.maxiter):
115128

116129
if not self.objectives:
117130
return self.interpret_tableau()
@@ -123,32 +136,28 @@ def run_simplex(self) -> dict[Any, Any]:
123136
else:
124137
self.tableau = self.pivot(row_idx, col_idx)
125138

126-
# FIXED: no silent failure anymore
127139
raise ValueError(
128-
"Simplex algorithm failed to converge.\n"
129-
f"- Iterations performed: {maxiter}\n"
130-
f"- Remaining objectives: {self.objectives}\n"
131-
"Possible causes:\n"
132-
"- Cycling (degeneracy)\n"
133-
"- Unbounded solution\n"
134-
"- Ill-conditioned input"
140+
"Simplex did not converge.\n"
141+
f"- Iterations: {Tableau.maxiter}\n"
142+
f"- Remaining objectives: {self.objectives}"
135143
)
136144

137-
def interpret_tableau(self) -> dict[str, float]:
138-
output_dict = {"P": abs(self.tableau[0, -1])}
145+
def interpret_tableau(self) -> Dict[str, float]:
146+
output: Dict[str, float] = {
147+
"P": float(abs(self.tableau[0, -1]))
148+
}
139149

140150
for i in range(self.n_vars):
141-
nonzero = np.nonzero(self.tableau[:, i])
142-
if len(nonzero[0]) == 0:
143-
continue
144-
145-
r = nonzero[0][0]
146-
val = self.tableau[r, i]
151+
nonzero = np.nonzero(self.tableau[:, i])[0]
147152

148-
if len(nonzero[0]) == 1 and val == 1:
149-
output_dict[self.col_titles[i]] = self.tableau[r, -1]
153+
if len(nonzero) == 1:
154+
row = nonzero[0]
155+
if self.tableau[row, i] == 1:
156+
output[self.col_titles[i]] = float(
157+
self.tableau[row, -1]
158+
)
150159

151-
return output_dict
160+
return output
152161

153162

154163
if __name__ == "__main__":

0 commit comments

Comments
 (0)