|
1 | 1 | import os |
2 | 2 | import json |
3 | | -from urllib.request import urlopen |
| 3 | +import aiohttp |
| 4 | +import asyncio |
| 5 | +import logging |
4 | 6 |
|
5 | 7 | """ |
6 | 8 | Prompt (aka context) tokens are based on number of words + other chars (eg spaces and punctuation) in input. |
|
20 | 22 | # Each completion token costs __ USD per token. |
21 | 23 | # Max prompt limit of each model is __ tokens. |
22 | 24 |
|
23 | | -# Fetch the latest prices using urllib.request |
24 | 25 | PRICES_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" |
25 | 26 |
|
| 27 | + |
| 28 | +async def fetch_costs(): |
| 29 | + """Fetch the latest token costs from the LiteLLM cost tracker asynchronously. |
| 30 | + Returns: |
| 31 | + dict: The token costs for each model. |
| 32 | + Raises: |
| 33 | + Exception: If the request fails. |
| 34 | + """ |
| 35 | + async with aiohttp.ClientSession() as session: |
| 36 | + async with session.get(PRICES_URL) as response: |
| 37 | + if response.status == 200: |
| 38 | + return await response.json(content_type=None) |
| 39 | + else: |
| 40 | + raise Exception(f"Failed to fetch token costs, status code: {response.status}") |
| 41 | + |
| 42 | + |
| 43 | +async def update_token_costs(): |
| 44 | + """Update the TOKEN_COSTS dictionary with the latest costs from the LiteLLM cost tracker asynchronously.""" |
| 45 | + global TOKEN_COSTS |
| 46 | + try: |
| 47 | + TOKEN_COSTS = await fetch_costs() |
| 48 | + print("TOKEN_COSTS updated successfully.") |
| 49 | + except Exception as e: |
| 50 | + logging.error(f"Failed to update TOKEN_COSTS: {e}") |
| 51 | + |
| 52 | +with open(os.path.join(os.path.dirname(__file__), "model_prices.json"), "r") as f: |
| 53 | + TOKEN_COSTS_STATIC = json.load(f) |
| 54 | + |
| 55 | + |
| 56 | +# Ensure TOKEN_COSTS is up to date when the module is loaded |
26 | 57 | try: |
27 | | - with urlopen(PRICES_URL) as response: |
28 | | - if response.status == 200: |
29 | | - TOKEN_COSTS = json.loads(response.read()) |
30 | | - else: |
31 | | - raise Exception("Failed to fetch token costs, status code: {}".format(response.status)) |
| 58 | + asyncio.run(update_token_costs()) |
32 | 59 | except Exception: |
33 | | - # If fetching fails, use the local model_prices.json as a fallback |
34 | | - with open(os.path.join(os.path.dirname(__file__), "model_prices.json"), "r") as f: |
35 | | - TOKEN_COSTS = json.load(f) |
| 60 | + logging.error('Failed to update token costs. Using static costs.') |
| 61 | + TOKEN_COSTS = TOKEN_COSTS_STATIC |
0 commit comments