Skip to content

Commit 9f4fb89

Browse files
committed
address review comments -
- make fork safety complete in the client - add shutdown mechanism to the integration - better test coverage - better docs on usage
1 parent 5357434 commit 9f4fb89

File tree

7 files changed

+576
-179
lines changed

7 files changed

+576
-179
lines changed

examples/celery_integration.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,35 @@
22
Celery integration example for the PostHog Python SDK.
33
44
Demonstrates how to use ``PosthogCeleryIntegration`` with:
5-
- producer-side instrumentation (publishing events and context propagation)
6-
- worker-side instrumentation via ``worker_process_init`` (prefork-safe)
5+
- producer-side and worker-side instrumentation (publishing events and context propagation)
76
- context propagation (distinct ID, session ID, tags) from producer to worker
87
- task lifecycle events (published, started, success, failure, retry)
98
- exception capture from failed tasks
109
- ``task_filter`` customization hook
1110
1211
Setup:
13-
1. Update POSTHOG_PROJECT_API_KEY and POSTHOG_HOST here with your credentials
14-
(environment variables won't work as it's better if Celery forks worker into
15-
separate process for the example to prove context propagation)
12+
1. Set ``POSTHOG_PROJECT_API_KEY`` and ``POSTHOG_HOST`` in your environment
1613
2. Install dependencies: pip install posthog celery redis
1714
3. Start Redis: redis-server
1815
4. Start the worker: celery -A examples.celery_integration worker --loglevel=info
1916
5. Run the producer: python -m examples.celery_integration
2017
"""
2118

19+
import os
2220
import time
2321
from typing import Any, Optional
2422

2523
from celery import Celery
2624
from celery.signals import worker_process_init, worker_process_shutdown
2725

2826
import posthog
29-
from posthog.client import Client
3027
from posthog.integrations.celery import PosthogCeleryIntegration
3128

3229

3330
# --- Configuration ---
3431

35-
POSTHOG_PROJECT_API_KEY = "phc_..."
36-
POSTHOG_HOST = "http://localhost:8000"
32+
POSTHOG_PROJECT_API_KEY = os.getenv("POSTHOG_PROJECT_API_KEY", "phc_...")
33+
POSTHOG_HOST = os.getenv("POSTHOG_HOST", "http://localhost:8000")
3734

3835
app = Celery(
3936
"examples.celery_integration",
@@ -43,11 +40,11 @@
4340

4441
# --- Integration wiring ---
4542

46-
def create_client() -> Client:
47-
return Client(
48-
project_api_key=POSTHOG_PROJECT_API_KEY,
49-
host=POSTHOG_HOST
50-
)
43+
def configure_posthog() -> None:
44+
posthog.api_key = POSTHOG_PROJECT_API_KEY
45+
posthog.host = POSTHOG_HOST
46+
posthog.enable_local_evaluation = False # to not require personal_api_key for this example
47+
posthog.setup()
5148

5249

5350
def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bool:
@@ -56,40 +53,42 @@ def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bo
5653
return True
5754

5855

59-
def create_integration(client: Client) -> PosthogCeleryIntegration:
56+
def create_integration() -> PosthogCeleryIntegration:
6057
return PosthogCeleryIntegration(
61-
client=client,
6258
capture_exceptions=True,
6359
capture_task_lifecycle_events=True,
6460
propagate_context=True,
6561
task_filter=task_filter,
6662
)
6763

68-
69-
# Worker process setup.
70-
# Celery's default prefork pool runs tasks in child processes, so initialize
71-
# PostHog per child using worker_process_init.
64+
configure_posthog()
65+
integration = create_integration()
66+
integration.instrument()
7267

7368

69+
# --- Worker process setup ---
70+
# Celery's default prefork pool runs tasks in child processes. This example
71+
# runs on a single host, so the inherited PostHog client and Celery
72+
# integration are fork-safe and do not need to be recreated in each child.
73+
# If workers run across multiple hosts, configure PostHog and instrument a
74+
# worker-local integration in worker_process_init.
7475
@worker_process_init.connect
7576
def on_worker_process_init(**kwargs) -> None:
76-
worker_posthog_client = create_client()
77-
worker_integration = create_integration(worker_posthog_client)
78-
worker_integration.instrument()
79-
80-
app._posthog_client = worker_posthog_client
81-
app._posthog_integration = worker_integration
77+
# global integration
78+
79+
# configure_posthog()
80+
# integration = create_integration()
81+
# integration.instrument()
82+
return
8283

8384

85+
# Use this signal to shutdown the integration and PostHog client
86+
# Calling shutdown() is important to flush any pending events
8487
@worker_process_shutdown.connect
8588
def on_worker_process_shutdown(**kwargs) -> None:
86-
worker_integration = getattr(app, "_posthog_integration", None)
87-
if worker_integration:
88-
worker_integration.uninstrument()
89+
integration.shutdown()
90+
posthog.shutdown()
8991

90-
worker_posthog_client = getattr(app, "_posthog_client", None)
91-
if worker_posthog_client:
92-
worker_posthog_client.shutdown()
9392

9493
# --- Example tasks ---
9594

@@ -98,8 +97,8 @@ def health_check() -> dict[str, str]:
9897
return {"status": "ok"}
9998

10099

101-
@app.task(bind=True, max_retries=3)
102-
def process_order(self, order_id: str) -> dict:
100+
@app.task(max_retries=3)
101+
def process_order(order_id: str) -> dict:
103102
"""A task that processes an order successfully."""
104103

105104
# simulate work
@@ -108,7 +107,7 @@ def process_order(self, order_id: str) -> dict:
108107
# Custom event inside the task - context tags propagated from the
109108
# producer (e.g. "source", "release") should appear on this event
110109
# and this should be attributed to the correct distinct ID and session.
111-
app._posthog_client.capture(
110+
posthog.capture(
112111
"celery example order processed",
113112
properties={"order_id": order_id, "amount": 99.99},
114113
)
@@ -136,17 +135,13 @@ def failing_task() -> None:
136135
# --- Producer code ---
137136

138137
if __name__ == "__main__":
139-
posthog_client = create_client()
140-
integration = create_integration(posthog_client)
141-
integration.instrument()
142-
143138
print("PostHog Celery Integration Example")
144139
print("=" * 40)
145140
print()
146141

147142
# Set up PostHog context before dispatching tasks.
148143
# The integration propagates this context to workers via task headers.
149-
with posthog.new_context(fresh=True, client=posthog_client):
144+
with posthog.new_context(fresh=True):
150145
posthog.identify_context("user-123")
151146
posthog.set_context_session("session-user-123-abc")
152147
posthog.tag("source", "celery_integration_example_script")
@@ -186,6 +181,5 @@ def failing_task() -> None:
186181
print("Tasks dispatched. Check your Celery worker logs and PostHog for events.")
187182
print()
188183

189-
posthog_client.flush()
190-
integration.uninstrument()
191-
posthog_client.shutdown()
184+
integration.shutdown()
185+
posthog.shutdown()

posthog/client.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
flags,
5858
get,
5959
remote_config,
60+
reset_sessions,
6061
)
6162
from posthog.types import (
6263
FeatureFlag,
@@ -245,6 +246,7 @@ def __init__(
245246
)
246247
self.poller = None
247248
self.distinct_ids_feature_flags_reported = SizeLimitedDict(MAX_DICT_SIZE, set)
249+
self.flag_fallback_cache_url = flag_fallback_cache_url
248250
self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url)
249251
self.flag_definition_version = 0
250252
self._flags_etag: Optional[str] = None
@@ -1098,42 +1100,56 @@ def _reinit_after_fork_weak(weak_self):
10981100
self._reinit_after_fork()
10991101

11001102
def _reinit_after_fork(self):
1101-
"""Reinitialize queue and consumer threads in a forked child process.
1103+
"""Reinitialize fork-unsafe client state in a forked child process.
11021104
11031105
Registered via os.register_at_fork(after_in_child=...) so it runs
11041106
exactly once in each child, before any user code, covering all code
11051107
paths (capture, flush, join, etc.).
11061108
11071109
Python threads do not survive fork() and queue.Queue internal locks
1108-
may be in an inconsistent state, so both are replaced.
1109-
Inherited queue items are intentionally discarded as they'll be
1110-
handled by the parent process's consumers.
1110+
may be in an inconsistent state, so the event queue, consumer threads
1111+
and other state are replaced. Inherited queue items are not retained
1112+
as they'll be handled by the parent process's consumers.
11111113
"""
1112-
if self.consumers is None:
1113-
return
1114+
if self.consumers:
1115+
self.queue = queue.Queue(self._max_queue_size)
1116+
1117+
new_consumers = []
1118+
for old in self.consumers:
1119+
consumer = Consumer(
1120+
self.queue,
1121+
old.api_key,
1122+
flush_at=old.flush_at,
1123+
host=old.host,
1124+
on_error=old.on_error,
1125+
flush_interval=old.flush_interval,
1126+
gzip=old.gzip,
1127+
retries=old.retries,
1128+
timeout=old.timeout,
1129+
historical_migration=old.historical_migration,
1130+
)
1131+
new_consumers.append(consumer)
1132+
1133+
if self.send:
1134+
consumer.start()
1135+
1136+
self.consumers = new_consumers
11141137

1115-
self.queue = queue.Queue(self._max_queue_size)
1116-
1117-
new_consumers = []
1118-
for old in self.consumers:
1119-
consumer = Consumer(
1120-
self.queue,
1121-
old.api_key,
1122-
flush_at=old.flush_at,
1123-
host=old.host,
1124-
on_error=old.on_error,
1125-
flush_interval=old.flush_interval,
1126-
gzip=old.gzip,
1127-
retries=old.retries,
1128-
timeout=old.timeout,
1129-
historical_migration=old.historical_migration,
1138+
if self.enable_local_evaluation:
1139+
self.poller = Poller(
1140+
interval=timedelta(seconds=self.poll_interval),
1141+
execute=self._load_feature_flags,
11301142
)
1131-
new_consumers.append(consumer)
1143+
self.poller.start()
1144+
else:
1145+
self.poller = None
11321146

1133-
if self.send:
1134-
consumer.start()
1147+
# If using Redis cache, we must reinitialize to get a fresh connection (fork-safe).
1148+
# If using Memory cache, we keep it as-is to benefit from the inherited warm cache.
1149+
if isinstance(self.flag_cache, RedisFlagCache):
1150+
self.flag_cache = self._initialize_flag_cache(self.flag_fallback_cache_url)
11351151

1136-
self.consumers = new_consumers
1152+
reset_sessions()
11371153

11381154
def _enqueue(self, msg, disable_geoip):
11391155
# type: (...) -> Optional[str]

0 commit comments

Comments
 (0)