22Core model implementation for compartmental models.
33"""
44
5- from typing import Any , Callable
5+ from typing import Any , Callable , Dict
66
7- import numpy as np
87from scipy .integrate import odeint
98
109
@@ -21,8 +20,8 @@ def compute_rates(self, state: dict[str, float]) -> dict[str, float]:
2120 """Compute transition rates for each compartment."""
2221 rates = {}
2322 for transition in self .transitions :
24- source = transition ['source ' ]
25- target = transition ['target ' ]
23+ source = transition ['from ' ]
24+ target = transition ['to ' ]
2625 rate = transition ['rate' ]
2726 # Handle rate expressions
2827 if isinstance (rate , str ):
@@ -45,89 +44,93 @@ def __init__(self, base_model: CompartmentalModel, num_patches: int, network_mat
4544 self .network = network_matrix
4645 self .all_compartments = [f"{ c } _{ i } " for i in range (num_patches ) for c in base_model .compartments ]
4746
48- def compute_force_of_infection (self , full_state : dict [str , float ]) -> list [float ]:
49- """Compute force of infection for each patch."""
47+ def get_patch_state (self , full_state : Dict [str , float ], patch_idx : int ) -> Dict [str , float ]:
48+ """Get state for a specific patch."""
49+ return {c : full_state [f"{ c } _{ patch_idx } " ] for c in self .base_model .compartments }
50+
51+ def get_patch_population (self , state : Dict [str , float ]) -> float :
52+ """Get total population for a patch."""
53+ return sum (state [c ] for c in self .base_model .compartments )
54+
55+ def compute_force_of_infection (self , full_state : dict [str , float ], infected_compartment : str = "I" ) -> list [float ]:
56+ """Compute force of infection for each patch.
57+
58+ Args:
59+ full_state: Current state of all compartments
60+ infected_compartment: Name of the compartment representing infected individuals
61+ """
5062 lambdas = []
5163 for i in range (self .num_patches ):
52- force = 0
53- for j in range (self .num_patches ):
54- infected_j = full_state [f"I_{ j } " ]
55- pop_j = sum (full_state [f"{ c } _{ j } " ] for c in self .base_model .compartments )
56- force += self .network [i ][j ] * infected_j / pop_j
64+ if self .num_patches == 1 :
65+ # Single patch case: local force of infection
66+ patch_state = self .get_patch_state (full_state , 0 )
67+ infected = patch_state [infected_compartment ]
68+ total_pop = self .get_patch_population (patch_state )
69+ force = self .base_model .parameters ['beta' ] * infected / total_pop
70+ else :
71+ # Multi-patch case: network-based force of infection
72+ force = 0
73+ for j in range (self .num_patches ):
74+ patch_state_j = self .get_patch_state (full_state , j )
75+ infected_j = patch_state_j [infected_compartment ]
76+ pop_j = self .get_patch_population (patch_state_j )
77+ force += self .network [i ][j ] * infected_j / pop_j
5778 lambdas .append (force )
5879 return lambdas
5980
81+ def compute_derivatives (self , state : dict [str , float ]) -> dict [str , float ]:
82+ """Compute derivatives for all compartments based on transitions."""
83+ derivatives = {c : 0.0 for c in self .all_compartments }
84+
85+ # Process each patch
86+ for i in range (self .num_patches ):
87+ # Get state for this patch
88+ patch_state = self .get_patch_state (state , i )
89+
90+ # Get rates for this patch
91+ rates = self .base_model .compute_rates (patch_state )
92+
93+ # Update derivatives based on rates
94+ for transition in self .base_model .transitions :
95+ source = transition ['from' ]
96+ target = transition ['to' ]
97+ rate = rates [f"{ source } _to_{ target } " ]
98+
99+ # Decrease source compartment
100+ derivatives [f"{ source } _{ i } " ] -= rate
101+ # Increase target compartment
102+ derivatives [f"{ target } _{ i } " ] += rate
103+
104+ return derivatives
105+
60106 def simulate_discrete (self , y0_dict : dict [str , float ], t_range : list [float ]) -> dict [str , list [float ]]:
61107 """Run discrete-time simulation."""
62108 state = y0_dict .copy ()
63109 history = {c : [state [c ]] for c in self .all_compartments }
110+
64111 for _ in t_range [1 :]:
65- new_state = state .copy ()
66- lambdas = self .compute_force_of_infection (state )
67- for i in range (self .num_patches ):
68- for c in self .base_model .compartments :
69- comp_key = f"{ c } _{ i } "
70- if c == 'S' :
71- new_state [comp_key ] -= lambdas [i ] * state [comp_key ]
72- elif c == 'I' :
73- new_state [comp_key ] += (
74- lambdas [i ] * state [f"S_{ i } " ] -
75- self .base_model .parameters ['gamma' ] * state [comp_key ]
76- )
77- elif c == 'R' :
78- new_state [comp_key ] += self .base_model .parameters ['gamma' ] * state [f"I_{ i } " ]
112+ derivatives = self .compute_derivatives (state )
113+ new_state = {
114+ c : state [c ] + derivatives [c ]
115+ for c in self .all_compartments
116+ }
79117 state = new_state
80118 for c in self .all_compartments :
81119 history [c ].append (state [c ])
120+
82121 return history
83122
84123 def simulate_ode (
85124 self , y0_dict : dict [str , float ], t_range : list [float ], integrator : Callable = odeint
86125 ) -> tuple [list [float ], dict [str , list [float ]]]:
87126 """Run ODE simulation."""
88- if (
89- hasattr (self , 'all_compartments' ) and
90- len (self .all_compartments ) > 0 and
91- all ('_' in c for c in self .all_compartments )
92- ):
93- y0 = [y0_dict [c ] for c in self .all_compartments ]
94-
95- def rhs (y , t ):
96- state = {c : y [i ] for i , c in enumerate (self .all_compartments )}
97- lambdas = self .compute_force_of_infection (state )
98- dydt = np .zeros_like (y )
99- for i in range (self .num_patches ):
100- for c in self .base_model .compartments :
101- idx = self .all_compartments .index (f"{ c } _{ i } " )
102- if c == 'S' :
103- dydt [idx ] = - lambdas [i ] * state [f"S_{ i } " ]
104- elif c == 'I' :
105- dydt [idx ] = (
106- lambdas [i ] * state [f"S_{ i } " ] -
107- self .base_model .parameters ['gamma' ] * state [f"I_{ i } " ]
108- )
109- elif c == 'R' :
110- dydt [idx ] = self .base_model .parameters ['gamma' ] * state [f"I_{ i } " ]
111- return dydt
112-
113- sol = integrator (rhs , y0 , t_range )
114- out = {c : sol [:, i ] for i , c in enumerate (self .all_compartments )}
115- return t_range , out
116-
117- y0 = [y0_dict [c ] for c in self .base_model .compartments ]
127+ y0 = [y0_dict [c ] for c in self .all_compartments ]
118128
119129 def rhs (y , t ):
120- state = {c : y [i ] for i , c in enumerate (self .base_model .compartments )}
121- rates = self .base_model .compute_rates (state )
122- dydt = np .zeros_like (y )
123- for i , c in enumerate (self .base_model .compartments ):
124- for transition in self .base_model .transitions :
125- if transition ['source' ] == c :
126- dydt [i ] -= rates [f"{ c } _to_{ transition ['target' ]} " ]
127- if transition ['target' ] == c :
128- dydt [i ] += rates [f"{ transition ['source' ]} _to_{ c } " ]
129- return dydt
130+ state = {c : y [i ] for i , c in enumerate (self .all_compartments )}
131+ derivatives = self .compute_derivatives (state )
132+ return [derivatives [c ] for c in self .all_compartments ]
130133
131134 sol = integrator (rhs , y0 , t_range )
132- out = {c : sol [:, i ] for i , c in enumerate (self .base_model . compartments )}
135+ out = {c : sol [:, i ] for i , c in enumerate (self .all_compartments )}
133136 return t_range , out
0 commit comments