Skip to content

Commit 6a2bb83

Browse files
committed
modified cli inputs and config
1 parent c771da0 commit 6a2bb83

5 files changed

Lines changed: 83 additions & 218 deletions

File tree

configs/sample-sir-ode.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@ NetworkFile: data/networks/sample-network-static.csv
44
SeedFile: data/seeds/sample-sir-ode-patchA-2.csv
55
Logging: False
66

7+
# Model configuration
8+
ModelName: sample-sir-ode
9+
710
# Simulation parameters
8-
Beta: 0.3
9-
Gamma: 0.1
1011
TMax: 50
1112
Tolerance: 1e-8
1213
MaxIter: 10000
1314
StartDate: 2020-01-01
1415
EndDate: 2022-12-31
1516
OutputDir: output/sample-sir-ode
17+
compartments: ["S", "I", "R"]
18+
Parameters: {"beta": 0.5, "gamma": 0.1}
19+
Transitions: [{"from": "S", "to": "I", "rate": "beta * S * I / (S + I + R)"}, {"from": "I", "to": "R", "rate": "gamma * I"}]
1620

src/patchsim/cli.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import argparse
22
import logging
3-
import os
4-
from datetime import datetime
53

6-
from patchsim.models.sample_sir_ode import run_simulation
4+
from patchsim.core.simulation import load_config, run_simulation, setup_simulation
75

86

97
def main():
@@ -18,9 +16,6 @@ def main():
1816
Examples:
1917
# Run simulation with sample SIR model
2018
patchsim --config configs/sample-sir-ode.yaml
21-
22-
# Run simulation with custom model
23-
patchsim --config path/to/your/config.yaml
2419
"""
2520
)
2621
parser.add_argument(
@@ -31,25 +26,13 @@ def main():
3126
)
3227
args = parser.parse_args()
3328

34-
# Create output directories
35-
os.makedirs("output/logs", exist_ok=True)
36-
37-
# Set up logging with timestamp
38-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
39-
log_file = f"output/logs/cli_{timestamp}.log"
40-
41-
logging.basicConfig(
42-
level=logging.INFO,
43-
format="%(asctime)s - %(levelname)s - %(message)s",
44-
handlers=[
45-
logging.FileHandler(log_file),
46-
logging.StreamHandler(),
47-
],
48-
)
49-
5029
try:
5130
logging.info("Starting PatchSim simulation...")
52-
run_simulation(args.config)
31+
config = load_config(args.config)
32+
33+
# Set up simulation
34+
net, y0, patches, num_patches = setup_simulation(config)
35+
run_simulation(config, config['ModelName'], net, y0, patches, num_patches)
5336
logging.info("Simulation completed successfully.")
5437
except Exception as e:
5538
logging.error(f"Simulation failed: {e}")

src/patchsim/core/model.py

Lines changed: 69 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
Core model implementation for compartmental models.
33
"""
44

5-
from typing import Any, Callable
5+
from typing import Any, Callable, Dict
66

7-
import numpy as np
87
from scipy.integrate import odeint
98

109

@@ -21,8 +20,8 @@ def compute_rates(self, state: dict[str, float]) -> dict[str, float]:
2120
"""Compute transition rates for each compartment."""
2221
rates = {}
2322
for transition in self.transitions:
24-
source = transition['source']
25-
target = transition['target']
23+
source = transition['from']
24+
target = transition['to']
2625
rate = transition['rate']
2726
# Handle rate expressions
2827
if isinstance(rate, str):
@@ -45,89 +44,93 @@ def __init__(self, base_model: CompartmentalModel, num_patches: int, network_mat
4544
self.network = network_matrix
4645
self.all_compartments = [f"{c}_{i}" for i in range(num_patches) for c in base_model.compartments]
4746

48-
def compute_force_of_infection(self, full_state: dict[str, float]) -> list[float]:
49-
"""Compute force of infection for each patch."""
47+
def get_patch_state(self, full_state: Dict[str, float], patch_idx: int) -> Dict[str, float]:
48+
"""Get state for a specific patch."""
49+
return {c: full_state[f"{c}_{patch_idx}"] for c in self.base_model.compartments}
50+
51+
def get_patch_population(self, state: Dict[str, float]) -> float:
52+
"""Get total population for a patch."""
53+
return sum(state[c] for c in self.base_model.compartments)
54+
55+
def compute_force_of_infection(self, full_state: dict[str, float], infected_compartment: str = "I") -> list[float]:
56+
"""Compute force of infection for each patch.
57+
58+
Args:
59+
full_state: Current state of all compartments
60+
infected_compartment: Name of the compartment representing infected individuals
61+
"""
5062
lambdas = []
5163
for i in range(self.num_patches):
52-
force = 0
53-
for j in range(self.num_patches):
54-
infected_j = full_state[f"I_{j}"]
55-
pop_j = sum(full_state[f"{c}_{j}"] for c in self.base_model.compartments)
56-
force += self.network[i][j] * infected_j / pop_j
64+
if self.num_patches == 1:
65+
# Single patch case: local force of infection
66+
patch_state = self.get_patch_state(full_state, 0)
67+
infected = patch_state[infected_compartment]
68+
total_pop = self.get_patch_population(patch_state)
69+
force = self.base_model.parameters['beta'] * infected / total_pop
70+
else:
71+
# Multi-patch case: network-based force of infection
72+
force = 0
73+
for j in range(self.num_patches):
74+
patch_state_j = self.get_patch_state(full_state, j)
75+
infected_j = patch_state_j[infected_compartment]
76+
pop_j = self.get_patch_population(patch_state_j)
77+
force += self.network[i][j] * infected_j / pop_j
5778
lambdas.append(force)
5879
return lambdas
5980

81+
def compute_derivatives(self, state: dict[str, float]) -> dict[str, float]:
82+
"""Compute derivatives for all compartments based on transitions."""
83+
derivatives = {c: 0.0 for c in self.all_compartments}
84+
85+
# Process each patch
86+
for i in range(self.num_patches):
87+
# Get state for this patch
88+
patch_state = self.get_patch_state(state, i)
89+
90+
# Get rates for this patch
91+
rates = self.base_model.compute_rates(patch_state)
92+
93+
# Update derivatives based on rates
94+
for transition in self.base_model.transitions:
95+
source = transition['from']
96+
target = transition['to']
97+
rate = rates[f"{source}_to_{target}"]
98+
99+
# Decrease source compartment
100+
derivatives[f"{source}_{i}"] -= rate
101+
# Increase target compartment
102+
derivatives[f"{target}_{i}"] += rate
103+
104+
return derivatives
105+
60106
def simulate_discrete(self, y0_dict: dict[str, float], t_range: list[float]) -> dict[str, list[float]]:
61107
"""Run discrete-time simulation."""
62108
state = y0_dict.copy()
63109
history = {c: [state[c]] for c in self.all_compartments}
110+
64111
for _ in t_range[1:]:
65-
new_state = state.copy()
66-
lambdas = self.compute_force_of_infection(state)
67-
for i in range(self.num_patches):
68-
for c in self.base_model.compartments:
69-
comp_key = f"{c}_{i}"
70-
if c == 'S':
71-
new_state[comp_key] -= lambdas[i] * state[comp_key]
72-
elif c == 'I':
73-
new_state[comp_key] += (
74-
lambdas[i] * state[f"S_{i}"] -
75-
self.base_model.parameters['gamma'] * state[comp_key]
76-
)
77-
elif c == 'R':
78-
new_state[comp_key] += self.base_model.parameters['gamma'] * state[f"I_{i}"]
112+
derivatives = self.compute_derivatives(state)
113+
new_state = {
114+
c: state[c] + derivatives[c]
115+
for c in self.all_compartments
116+
}
79117
state = new_state
80118
for c in self.all_compartments:
81119
history[c].append(state[c])
120+
82121
return history
83122

84123
def simulate_ode(
85124
self, y0_dict: dict[str, float], t_range: list[float], integrator: Callable = odeint
86125
) -> tuple[list[float], dict[str, list[float]]]:
87126
"""Run ODE simulation."""
88-
if (
89-
hasattr(self, 'all_compartments') and
90-
len(self.all_compartments) > 0 and
91-
all('_' in c for c in self.all_compartments)
92-
):
93-
y0 = [y0_dict[c] for c in self.all_compartments]
94-
95-
def rhs(y, t):
96-
state = {c: y[i] for i, c in enumerate(self.all_compartments)}
97-
lambdas = self.compute_force_of_infection(state)
98-
dydt = np.zeros_like(y)
99-
for i in range(self.num_patches):
100-
for c in self.base_model.compartments:
101-
idx = self.all_compartments.index(f"{c}_{i}")
102-
if c == 'S':
103-
dydt[idx] = -lambdas[i] * state[f"S_{i}"]
104-
elif c == 'I':
105-
dydt[idx] = (
106-
lambdas[i] * state[f"S_{i}"] -
107-
self.base_model.parameters['gamma'] * state[f"I_{i}"]
108-
)
109-
elif c == 'R':
110-
dydt[idx] = self.base_model.parameters['gamma'] * state[f"I_{i}"]
111-
return dydt
112-
113-
sol = integrator(rhs, y0, t_range)
114-
out = {c: sol[:, i] for i, c in enumerate(self.all_compartments)}
115-
return t_range, out
116-
117-
y0 = [y0_dict[c] for c in self.base_model.compartments]
127+
y0 = [y0_dict[c] for c in self.all_compartments]
118128

119129
def rhs(y, t):
120-
state = {c: y[i] for i, c in enumerate(self.base_model.compartments)}
121-
rates = self.base_model.compute_rates(state)
122-
dydt = np.zeros_like(y)
123-
for i, c in enumerate(self.base_model.compartments):
124-
for transition in self.base_model.transitions:
125-
if transition['source'] == c:
126-
dydt[i] -= rates[f"{c}_to_{transition['target']}"]
127-
if transition['target'] == c:
128-
dydt[i] += rates[f"{transition['source']}_to_{c}"]
129-
return dydt
130+
state = {c: y[i] for i, c in enumerate(self.all_compartments)}
131+
derivatives = self.compute_derivatives(state)
132+
return [derivatives[c] for c in self.all_compartments]
130133

131134
sol = integrator(rhs, y0, t_range)
132-
out = {c: sol[:, i] for i, c in enumerate(self.base_model.compartments)}
135+
out = {c: sol[:, i] for i, c in enumerate(self.all_compartments)}
133136
return t_range, out

src/patchsim/models/sample_sir_ode.py

Lines changed: 0 additions & 125 deletions
This file was deleted.

src/patchsim/utils/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
def read_config(config_path: str) -> dict[str, str]:
1818
"""Read and parse a YAML configuration file.
19-
19+
2020
Args:
2121
config_path: Path to the YAML configuration file
22-
22+
2323
Returns:
2424
Dictionary containing configuration parameters
2525
"""

0 commit comments

Comments
 (0)