Commit 11eca7e
Enable MPS by default on macOS and add device override
1 parent 014b9bf commit 11eca7e

11 files changed

Lines changed: 224 additions & 42 deletions


.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ jobs:
           - os: macos-15-intel
             target: x86_64-apple-darwin
             allow_failure: false
-            run_real_smoke: false
+            run_real_smoke: true
           - os: macos-14
             target: aarch64-apple-darwin
             allow_failure: false

README.md

Lines changed: 18 additions & 5 deletions

@@ -17,28 +17,33 @@ The initial public model identifier is `bge-m3`.
 - Upstream model id: `BAAI/bge-m3`
 - Backend: `sentence-transformers`
 - Device:
-  - Apple Silicon macOS: `mps` when available, otherwise CPU
+  - macOS: `auto` by default, which selects `mps` when available, otherwise CPU
   - all other current targets: CPU
 - Provisioning: first-run download into a local cache directory
 
 The command and HTTP layers are written against an internal backend registry so additional models or inference backends can be added later without changing the user-facing contracts.
 
 ## Acceleration support
 
-Current hardware acceleration support is intentionally limited in `v0.1.2`:
+Current hardware acceleration support in `v0.1.3`:
 
 - `aarch64-apple-darwin`:
+  - defaults to `auto`
   - uses Apple Metal Performance Shaders (`mps`) automatically when available
   - falls back to CPU if MPS is unavailable
 - `x86_64-apple-darwin`:
-  - CPU only
+  - defaults to `auto`
+  - uses Apple Metal Performance Shaders (`mps`) automatically when available
+  - falls back to CPU if MPS is unavailable
 - `x86_64-unknown-linux-gnu`:
   - CPU only
 - `aarch64-unknown-linux-gnu`:
   - CPU only
 - `x86_64-pc-windows-msvc`:
   - CPU only
 
+An explicit device override is available on the `embed`, `serve`, and `daemon` commands via `--device auto|cpu|mps`.
+
 The current release does not expose CUDA, ROCm, DirectML, or Intel GPU acceleration paths yet.
 
 ## Requirements
@@ -86,7 +91,7 @@ Example response:
   "embeddings": [[0.123, -0.456, 0.789]],
   "runtime": {
     "name": "bitloops-embeddings",
-    "version": "0.1.2"
+    "version": "0.1.3"
   }
 }
 ```
@@ -100,6 +105,13 @@ bitloops-embeddings embed \
   --output ./embedding.json
 ```
 
+Force CPU or request MPS explicitly:
+
+```bash
+bitloops-embeddings embed --model bge-m3 --input "Hello World" --device cpu
+bitloops-embeddings serve --model bge-m3 --device mps
+```
+
 Inspect model metadata without loading the model:
 
 ```bash
@@ -160,7 +172,7 @@ Response shape:
   "embeddings": [[0.123, -0.456, 0.789]],
   "runtime": {
     "name": "bitloops-embeddings",
-    "version": "0.1.2"
+    "version": "0.1.3"
   }
 }
 ```
@@ -258,6 +270,7 @@ Run the real-model smoke test against an installed console script or packaged executable:
 
 ```bash
 python scripts/real_backend_smoke.py --binary bitloops-embeddings
+python scripts/real_backend_smoke.py --binary bitloops-embeddings --device mps
 ```
 
 ## GitHub Actions
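The new `--device` flag composes with the CLI's existing JSON contract. Below is a minimal sketch of driving the documented embed command from Python, assuming `bitloops-embeddings` is on PATH and, as the smoke script below relies on, that the JSON response is written to stdout:

```python
import json
import subprocess

# Invoke the CLI exactly as the README example does, forcing CPU inference.
completed = subprocess.run(
    [
        "bitloops-embeddings", "embed",
        "--model", "bge-m3",
        "--input", "Hello World",
        "--device", "cpu",
    ],
    check=True,          # raise if the CLI exits nonzero
    capture_output=True,
    text=True,
)

# The response documented above carries "embeddings" and "runtime" keys.
payload = json.loads(completed.stdout)
print(payload["runtime"]["version"])  # e.g. "0.1.3"
print(len(payload["embeddings"][0]))  # embedding dimensionality
```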

scripts/real_backend_smoke.py

Lines changed: 18 additions & 7 deletions

@@ -21,12 +21,19 @@ def main() -> None:
         required=True,
         help="Executable to invoke. This may be a console script name or an absolute path.",
     )
+    parser.add_argument(
+        "--device",
+        default="auto",
+        choices=("auto", "cpu", "mps"),
+        help="Inference device override to pass through to the runtime.",
+    )
     args = parser.parse_args()
 
     binary = args.binary
-    run_with_retries("embed smoke", lambda: run_embed_smoke(binary))
-    run_with_retries("server smoke", lambda: run_server_smoke(binary, reserve_free_port()))
-    run_with_retries("daemon smoke", lambda: run_daemon_smoke(binary))
+    device = args.device
+    run_with_retries("embed smoke", lambda: run_embed_smoke(binary, device))
+    run_with_retries("server smoke", lambda: run_server_smoke(binary, reserve_free_port(), device))
+    run_with_retries("daemon smoke", lambda: run_daemon_smoke(binary, device))
 
 
 def run_with_retries(name: str, operation) -> None:
@@ -48,9 +55,9 @@ def run_with_retries(name: str, operation) -> None:
             time.sleep(delay_seconds)
 
 
-def run_embed_smoke(binary: str) -> None:
+def run_embed_smoke(binary: str, device: str) -> None:
     completed = subprocess.run(
-        [binary, "embed", "--model", "bge-m3", "--input", "Hello World"],
+        [binary, "embed", "--model", "bge-m3", "--input", "Hello World", "--device", device],
         check=False,
         capture_output=True,
         text=True,
@@ -69,7 +76,7 @@ def run_embed_smoke(binary: str) -> None:
         raise RuntimeError("Embed smoke returned an empty embedding vector.")
 
 
-def run_server_smoke(binary: str, port: int) -> None:
+def run_server_smoke(binary: str, port: int, device: str) -> None:
     with tempfile.TemporaryDirectory(prefix="bitloops-embeddings-serve-logs-") as temp_dir:
         log_file = Path(temp_dir) / "serve.log"
         process = subprocess.Popen(
@@ -82,6 +89,8 @@ def run_server_smoke(binary: str, port: int) -> None:
                 "127.0.0.1",
                 "--port",
                 str(port),
+                "--device",
+                device,
                 "--log-file",
                 str(log_file),
             ],
@@ -109,7 +118,7 @@ def run_server_smoke(binary: str, port: int) -> None:
             process.wait(timeout=5)
 
 
-def run_daemon_smoke(binary: str) -> None:
+def run_daemon_smoke(binary: str, device: str) -> None:
    with tempfile.TemporaryDirectory(prefix="bitloops-embeddings-daemon-logs-") as temp_dir:
         log_file = Path(temp_dir) / "daemon.log"
         process = subprocess.Popen(
@@ -118,6 +127,8 @@ def run_daemon_smoke(binary: str) -> None:
                 "daemon",
                 "--model",
                 "bge-m3",
+                "--device",
+                device,
                 "--log-file",
                 str(log_file),
             ],
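`reserve_free_port()` is called in `main()` but defined outside this hunk. A plausible implementation, an assumption rather than the repository's actual code, binds port 0 so the OS hands back a currently free ephemeral port:

```python
import socket

def reserve_free_port() -> int:
    """Ask the OS for a currently free TCP port on the loopback interface."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 lets the kernel pick a free port
        return sock.getsockname()[1]
```

There is an inherent race between releasing the reserved port and the server binding it, which is presumably part of why the script wraps each smoke in `run_with_retries`.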

src/bitloops_embeddings/backend/sentence_transformers_backend.py

Lines changed: 42 additions & 11 deletions

@@ -7,7 +7,11 @@
 from threading import RLock
 from typing import Any
 
-from bitloops_embeddings.errors import BackendLoadError, InferenceError
+from bitloops_embeddings.errors import (
+    BackendLoadError,
+    InferenceError,
+    UnsupportedDeviceError,
+)
 from bitloops_embeddings.logging_utils import LOGGER_NAME, log_event
 
 
@@ -23,13 +27,15 @@ def __init__(
         upstream_model_id: str,
         cache_dir: Path,
         dimensions: int,
+        requested_device: str = "auto",
     ) -> None:
         self._model_id = model_id
         self._upstream_model_id = upstream_model_id
         self._cache_dir = cache_dir
         self._dimensions = dimensions
         self._model: Any = None
-        self._device = resolve_inference_device()
+        self._requested_device = requested_device
+        self._device = resolve_inference_device(requested_device=requested_device)
 
     @property
     def model_id(self) -> str:
@@ -188,28 +194,53 @@ def _configure_tqdm_lock_for_single_process() -> None:
     _TQDM_THREAD_LOCK_CONFIGURED = True
 
 
-def resolve_inference_device() -> str:
-    if platform.system() != "Darwin":
-        return "cpu"
+def resolve_inference_device_for_request(requested_device: str) -> str:
+    if requested_device == "auto":
+        return "mps" if _is_mps_available() else "cpu"
 
-    if platform.machine().lower() not in ("arm64", "aarch64"):
+    if requested_device == "cpu":
         return "cpu"
 
+    if requested_device == "mps":
+        unavailable_reason = _resolve_mps_unavailable_reason()
+        if unavailable_reason is None:
+            return "mps"
+        raise UnsupportedDeviceError(
+            f"MPS was requested but is unavailable: {unavailable_reason}"
+        )
+
+    raise UnsupportedDeviceError(
+        f"Unsupported device '{requested_device}'. Supported devices: auto, cpu, mps."
+    )
+
+
+def resolve_inference_device(requested_device: str = "auto") -> str:
+    return resolve_inference_device_for_request(requested_device)
+
+
+def _is_mps_available() -> bool:
+    return _resolve_mps_unavailable_reason() is None
+
+
+def _resolve_mps_unavailable_reason() -> str | None:
+    if platform.system() != "Darwin":
+        return "MPS is only available on macOS."
+
     try:
         import torch
     except ImportError:
-        return "cpu"
+        return "PyTorch is not installed."
 
     mps_backend = getattr(getattr(torch, "backends", None), "mps", None)
     if mps_backend is None:
-        return "cpu"
+        return "the installed PyTorch build does not expose torch.backends.mps."
 
     is_built = getattr(mps_backend, "is_built", None)
     if callable(is_built) and not is_built():
-        return "cpu"
+        return "the installed PyTorch build was not built with MPS support."
 
     is_available = getattr(mps_backend, "is_available", None)
     if callable(is_available) and is_available():
-        return "mps"
+        return None
 
-    return "cpu"
+    return "macOS 12.3 or later and an MPS-enabled GPU are required."

src/bitloops_embeddings/cli.py

Lines changed: 40 additions & 5 deletions

@@ -40,6 +40,12 @@ class LogLevel(str, Enum):
     ERROR = "error"
 
 
+class Device(str, Enum):
+    AUTO = "auto"
+    CPU = "cpu"
+    MPS = "mps"
+
+
 class Transport(str, Enum):
     STDIO = "stdio"
 
@@ -66,6 +72,14 @@ def embed(
             writable=True,
         ),
     ] = None,
+    device: Annotated[
+        Device,
+        typer.Option(
+            "--device",
+            help="Inference device. auto prefers MPS when available, otherwise CPU.",
+            case_sensitive=False,
+        ),
+    ] = Device.AUTO,
     output: Annotated[
         Optional[Path],
         typer.Option(
@@ -81,7 +95,7 @@ def embed(
         raise typer.BadParameter("Only JSON output is supported in v1.")
 
     try:
-        backend = _build_backend(model=model, cache_dir=cache_dir)
+        backend = _build_backend(model=model, cache_dir=cache_dir, device=device)
         response = EmbeddingResponse(
             model_id=backend.model_id,
             dimensions=backend.dimensions,
@@ -110,6 +124,14 @@ def serve(
             writable=True,
         ),
     ] = None,
+    device: Annotated[
+        Device,
+        typer.Option(
+            "--device",
+            help="Inference device. auto prefers MPS when available, otherwise CPU.",
+            case_sensitive=False,
+        ),
+    ] = Device.AUTO,
     log_level: Annotated[
         LogLevel,
         typer.Option("--log-level", help="Server log verbosity.", case_sensitive=False),
@@ -130,7 +152,7 @@ def serve(
 ) -> None:
     try:
         configure_logging(log_level.value, log_file=log_file, prefer_os_log=True)
-        backend = _build_backend(model=model, cache_dir=cache_dir)
+        backend = _build_backend(model=model, cache_dir=cache_dir, device=device)
         backend.load()
         app_instance = create_app(backend, max_batch_size=max_batch_size)
         log_event(
@@ -164,6 +186,14 @@ def daemon(
             writable=True,
         ),
     ] = None,
+    device: Annotated[
+        Device,
+        typer.Option(
+            "--device",
+            help="Inference device. auto prefers MPS when available, otherwise CPU.",
+            case_sensitive=False,
+        ),
+    ] = Device.AUTO,
     log_level: Annotated[
         LogLevel,
         typer.Option("--log-level", help="Daemon log verbosity.", case_sensitive=False),
@@ -182,7 +212,7 @@ def daemon(
         configure_logging(log_level.value, log_file=log_file, prefer_os_log=True)
         if transport is not Transport.STDIO:
             raise typer.BadParameter("Only stdio transport is supported in v1.")
-        backend = _build_backend(model=model, cache_dir=cache_dir)
+        backend = _build_backend(model=model, cache_dir=cache_dir, device=device)
         raise typer.Exit(code=run_daemon(backend))
     except typer.BadParameter:
         raise
@@ -226,10 +256,15 @@ def describe(
         _exit_with_error(BitloopsEmbeddingsError(f"Unexpected runtime error: {exc}"))
 
 
-def _build_backend(*, model: str, cache_dir: Optional[Path]) -> EmbeddingBackend:
+def _build_backend(
+    *,
+    model: str,
+    cache_dir: Optional[Path],
+    device: Device = Device.AUTO,
+) -> EmbeddingBackend:
     resolved_cache_dir = ensure_cache_dir(resolve_cache_dir(cache_dir))
     spec = get_model_spec(model)
-    return spec.create_backend(resolved_cache_dir)
+    return spec.create_backend(resolved_cache_dir, requested_device=device.value)
 
 
 def _emit_json(payload: str, *, output: Optional[Path]) -> None:
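The repeated `--device` option leans on Typer's enum-backed validation. Here is a self-contained sketch of the same pattern, assuming only that `typer` is installed; it mirrors the names in the diff but is not the repository module:

```python
from enum import Enum
from typing import Annotated

import typer

class Device(str, Enum):
    AUTO = "auto"
    CPU = "cpu"
    MPS = "mps"

app = typer.Typer()

@app.command()
def embed(
    device: Annotated[
        Device,
        typer.Option(
            "--device",
            help="Inference device. auto prefers MPS when available, otherwise CPU.",
            case_sensitive=False,
        ),
    ] = Device.AUTO,
) -> None:
    # Typer validates the value against the enum before the command body runs,
    # so "--device gpu" fails with a usage error instead of reaching the backend.
    typer.echo(f"requested device: {device.value}")

if __name__ == "__main__":
    app()
```

With `case_sensitive=False`, `--device MPS` and `--device mps` parse to the same `Device.MPS` member, which matches the help text's lowercase spellings.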

src/bitloops_embeddings/errors.py

Lines changed: 5 additions & 0 deletions

@@ -24,6 +24,11 @@ class UnsupportedModelError(BitloopsEmbeddingsError):
     default_status_code = 400
 
 
+class UnsupportedDeviceError(BitloopsEmbeddingsError):
+    default_code = "unsupported_device"
+    default_status_code = 400
+
+
 class BackendLoadError(BitloopsEmbeddingsError):
     default_code = "backend_load_error"
     default_status_code = 500
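The new class follows this file's convention of riding the error code and HTTP status on class attributes, so a generic handler needs no per-exception branching. A standalone illustration follows; the base-class body and the handler are re-created for the sketch, not taken from the repository:

```python
class BitloopsEmbeddingsError(Exception):
    # Re-created defaults for illustration; the real base class is not in this diff.
    default_code = "internal_error"
    default_status_code = 500

class UnsupportedDeviceError(BitloopsEmbeddingsError):
    default_code = "unsupported_device"
    default_status_code = 400

def to_error_payload(exc: BitloopsEmbeddingsError) -> tuple[int, dict]:
    # Any subclass maps to a status and JSON body without special-casing,
    # because the code and status ride on the class itself.
    return exc.default_status_code, {"code": exc.default_code, "message": str(exc)}

status, body = to_error_payload(
    UnsupportedDeviceError("MPS was requested but is unavailable: PyTorch is not installed.")
)
print(status, body)  # 400 {'code': 'unsupported_device', 'message': '...'}
```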
