Skip to content

Commit fab6673

Browse files
authored
Merge pull request #6 from Axect/feat/multi-custom-provider
feat: multi-custom provider system with TUI preset picker
2 parents 44b2047 + 19a3eeb commit fab6673

14 files changed

Lines changed: 638 additions & 55 deletions

File tree

src/arxiv_explorer/cli/daily.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ def show(
239239
arxiv_id, paper.title, paper.abstract, detailed=detailed, force=force
240240
)
241241

242+
if (summary or detailed) and paper_summary is None:
243+
import sys
244+
245+
print("Failed to generate summary (check provider settings)", file=sys.stderr)
246+
raise typer.Exit(1)
247+
242248
paper_translation = None
243249
if translate:
244250
translator = TranslationService()
@@ -252,6 +258,12 @@ def show(
252258
arxiv_id, paper.title, paper.abstract, force=force
253259
)
254260

261+
if translate and paper_translation is None:
262+
import sys
263+
264+
print("Failed to generate translation (check provider settings)", file=sys.stderr)
265+
raise typer.Exit(1)
266+
255267
print_paper_detail(paper, paper_summary, paper_translation)
256268

257269

src/arxiv_explorer/core/database.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@
100100
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
101101
);
102102
103+
-- Custom AI providers
104+
CREATE TABLE IF NOT EXISTS custom_providers (
105+
name TEXT PRIMARY KEY NOT NULL,
106+
preset TEXT NOT NULL,
107+
command_template TEXT NOT NULL,
108+
default_model TEXT DEFAULT '',
109+
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
110+
);
111+
103112
-- Paper review sections (incremental cache)
104113
CREATE TABLE IF NOT EXISTS paper_review_sections (
105114
id INTEGER PRIMARY KEY AUTOINCREMENT,

src/arxiv_explorer/core/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ class Language(str, Enum):
3838
KO = "ko"
3939

4040

41+
@dataclass
42+
class CustomProviderConfig:
43+
name: str
44+
preset: str
45+
command_template: str
46+
default_model: str = ""
47+
48+
4149
class JobType(Enum):
4250
SUMMARIZE = "summarize"
4351
TRANSLATE = "translate"

src/arxiv_explorer/services/providers.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,29 @@ def build_command(self, prompt: str, model: str = "") -> list[str]:
161161
}
162162

163163

164-
def get_provider(provider_type: AIProviderType) -> AIProvider:
165-
"""Return a provider instance. If custom, load the template from settings."""
166-
provider = PROVIDERS[provider_type]
167-
if provider_type == AIProviderType.CUSTOM:
168-
from .settings_service import SettingsService
169-
170-
template = SettingsService().get("custom_command")
171-
provider.configure(template)
172-
return provider
164+
def get_provider(provider_name: str | AIProviderType) -> AIProvider:
165+
"""Return a provider instance. Checks built-in registry first, then custom_providers table."""
166+
# Normalize to string
167+
name = provider_name.value if isinstance(provider_name, AIProviderType) else provider_name
168+
169+
# Try built-in registry
170+
for ptype, prov in PROVIDERS.items():
171+
if ptype.value == name:
172+
if ptype == AIProviderType.CUSTOM:
173+
from .settings_service import SettingsService
174+
175+
template = SettingsService().get("custom_command")
176+
prov.configure(template)
177+
return prov
178+
179+
# Try custom_providers table
180+
from .settings_service import SettingsService
181+
182+
for cp in SettingsService().get_custom_providers():
183+
if cp.name == name:
184+
provider = CustomProvider()
185+
provider.configure(cp.command_template)
186+
return provider
187+
188+
# Fallback to gemini
189+
return PROVIDERS[AIProviderType.GEMINI]

src/arxiv_explorer/services/settings_service.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def get_all(self) -> dict[str, str]:
7070
settings[row["key"]] = row["value"]
7171
return settings
7272

73-
def get_provider(self) -> AIProviderType:
74-
"""Get the current AI provider."""
75-
return AIProviderType(self.get("ai_provider"))
73+
def get_provider(self) -> str:
74+
"""Return the active provider name as a string."""
75+
return self.get("ai_provider")
7676

7777
def get_model(self) -> str:
7878
"""Get the current AI model override."""
@@ -104,3 +104,47 @@ def set_weights(self, weights: dict[str, int]) -> None:
104104
def reset_weights(self) -> None:
105105
"""Reset recommendation weights to defaults."""
106106
self.set_weights(DEFAULT_WEIGHTS)
107+
108+
# Reserved names that cannot be used for custom providers
109+
RESERVED_PROVIDERS = {"gemini", "claude", "openai", "ollama", "opencode", "custom"}
110+
111+
def get_custom_providers(self) -> list:
112+
"""Return all custom providers as list of CustomProviderConfig."""
113+
from ..core.models import CustomProviderConfig
114+
115+
with get_connection() as conn:
116+
rows = conn.execute(
117+
"SELECT name, preset, command_template, default_model FROM custom_providers ORDER BY name"
118+
).fetchall()
119+
return [
120+
CustomProviderConfig(
121+
name=r["name"],
122+
preset=r["preset"],
123+
command_template=r["command_template"],
124+
default_model=r["default_model"] or "",
125+
)
126+
for r in rows
127+
]
128+
129+
def add_custom_provider(
130+
self, name: str, preset: str, command_template: str, default_model: str = ""
131+
) -> None:
132+
"""Register a custom provider. Raises ValueError if name is reserved or duplicate."""
133+
if name.lower() in self.RESERVED_PROVIDERS:
134+
raise ValueError(f"'{name}' is a reserved provider name")
135+
with get_connection() as conn:
136+
conn.execute(
137+
"INSERT OR REPLACE INTO custom_providers (name, preset, command_template, default_model) "
138+
"VALUES (?, ?, ?, ?)",
139+
(name, preset, command_template, default_model),
140+
)
141+
conn.commit()
142+
143+
def remove_custom_provider(self, name: str) -> None:
144+
"""Remove a custom provider. If it's the active provider, switch to gemini."""
145+
with get_connection() as conn:
146+
conn.execute("DELETE FROM custom_providers WHERE name = ?", (name,))
147+
conn.commit()
148+
# If active provider was deleted, reset to gemini
149+
if self.get("ai_provider") == name:
150+
self.set("ai_provider", "gemini")

src/arxiv_explorer/services/summarization.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,19 @@ def summarize(
6363
settings = SettingsService()
6464
provider = get_provider(settings.get_provider())
6565
if not provider.is_available():
66+
import sys
67+
68+
print("Summary generation failed: provider not available", file=sys.stderr)
6669
return None
6770
output = provider.invoke(
6871
prompt,
6972
model=settings.get_model(),
7073
timeout=settings.get_timeout(),
7174
)
7275
if output is None:
76+
import sys
77+
78+
print("Summary generation failed: provider returned no output", file=sys.stderr)
7379
return None
7480
# Extract JSON block (may be in ```json ... ``` format)
7581
if "```json" in output:
@@ -82,13 +88,9 @@ def summarize(
8288
try:
8389
data = json.loads(output)
8490
except json.JSONDecodeError as e:
85-
# JSON parse failure - print debug info and return None
8691
import sys
8792

88-
if "--verbose" in sys.argv or "-v" in sys.argv:
89-
print(f"\nSummary generation failed ({arxiv_id}): JSON parse error")
90-
print(f"Error: {e}")
91-
print(f"Output sample: {output[:300]}...")
93+
print(f"Summary generation failed: JSON parse error: {e}", file=sys.stderr)
9294
return None
9395

9496
summary = PaperSummary(
@@ -106,11 +108,9 @@ def summarize(
106108
return summary
107109

108110
except Exception as e:
109-
# Other error - fail silently
110111
import sys
111112

112-
if "--verbose" in sys.argv or "-v" in sys.argv:
113-
print(f"\nError during summary generation ({arxiv_id}): {e}")
113+
print(f"Summary generation failed: {e}", file=sys.stderr)
114114
return None
115115

116116
def _get_cached(self, arxiv_id: str) -> PaperSummary | None:

src/arxiv_explorer/services/translation.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,19 @@ def translate(
7272
settings = SettingsService()
7373
provider = get_provider(settings.get_provider())
7474
if not provider.is_available():
75+
import sys
76+
77+
print("Translation failed: provider not available", file=sys.stderr)
7578
return None
7679
output = provider.invoke(
7780
prompt,
7881
model=settings.get_model(),
7982
timeout=settings.get_timeout(),
8083
)
8184
if output is None:
85+
import sys
86+
87+
print("Translation failed: provider returned no output", file=sys.stderr)
8288
return None
8389

8490
# Extract JSON block
@@ -94,10 +100,7 @@ def translate(
94100
except json.JSONDecodeError as e:
95101
import sys
96102

97-
if "--verbose" in sys.argv or "-v" in sys.argv:
98-
print(f"\nTranslation failed ({arxiv_id}): JSON parse error")
99-
print(f"Error: {e}")
100-
print(f"Output sample: {output[:300]}...")
103+
print(f"Translation failed: JSON parse error: {e}", file=sys.stderr)
101104
return None
102105

103106
translation = PaperTranslation(
@@ -115,8 +118,7 @@ def translate(
115118
except Exception as e:
116119
import sys
117120

118-
if "--verbose" in sys.argv or "-v" in sys.argv:
119-
print(f"\nTranslation error ({arxiv_id}): {e}")
121+
print(f"Translation failed: {e}", file=sys.stderr)
120122
return None
121123

122124
def _get_cached(self, arxiv_id: str, target_language: Language) -> PaperTranslation | None:

tests/test_database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"paper_review_sections",
2222
"preferred_authors",
2323
"daily_fetch_cache",
24+
"custom_providers",
2425
}
2526

2627

tui-rs/src/app.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ pub struct PrefsState {
191191
pub weights: [i64; 4],
192192
pub provider: String,
193193
pub language: String,
194+
pub custom_providers: Vec<crate::db::models::CustomProviderEntry>,
195+
pub custom_provider_selected: usize,
194196
pub selected: usize, // cursor in weights section
195197
pub focus_section: usize, // 0=cats, 1=keywords, 2=authors, 3=weights, 4=config
196198
pub section_selected: [usize; 5], // cursor per section (section 4: 0=provider, 1=language)
@@ -205,6 +207,8 @@ impl Default for PrefsState {
205207
weights: [60, 20, 15, 5],
206208
provider: "gemini".to_string(),
207209
language: "en".to_string(),
210+
custom_providers: vec![],
211+
custom_provider_selected: 0,
208212
selected: 0,
209213
focus_section: 0,
210214
section_selected: [0; 5],
@@ -220,6 +224,7 @@ impl Default for PrefsState {
220224
pub enum ConfirmAction {
221225
RegenerateSummary,
222226
RegenerateTranslation,
227+
RemoveCustomProvider(String),
223228
}
224229

225230
// =============================================================================
@@ -251,6 +256,18 @@ pub enum OverlayMode {
251256
AuthorInput {
252257
text: String,
253258
},
259+
PresetPicker {
260+
selected: usize,
261+
},
262+
ProviderNameInput {
263+
preset: String,
264+
text: String,
265+
},
266+
CommandTemplateInput {
267+
preset: String,
268+
name: String,
269+
text: String,
270+
},
254271
}
255272

256273
// =============================================================================
@@ -353,6 +370,8 @@ impl App {
353370
weights,
354371
provider,
355372
language,
373+
custom_providers: vec![],
374+
custom_provider_selected: 0,
356375
selected: 0,
357376
focus_section: 0,
358377
section_selected: [0; 5],

tui-rs/src/db/mod.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ impl Database {
2424
}
2525
let conn = Connection::open(path)?;
2626
conn.execute_batch("PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL;")?;
27-
Ok(Self { conn })
27+
let db = Self { conn };
28+
db.ensure_custom_providers_table()?;
29+
Ok(db)
2830
}
2931

3032
/// Return the default database path.
@@ -475,6 +477,56 @@ impl Database {
475477
Ok(())
476478
}
477479

480+
// =========================================================================
481+
// Custom Providers
482+
// =========================================================================
483+
484+
pub fn ensure_custom_providers_table(&self) -> Result<()> {
485+
self.conn.execute_batch(
486+
"CREATE TABLE IF NOT EXISTS custom_providers (
487+
name TEXT PRIMARY KEY NOT NULL,
488+
preset TEXT NOT NULL,
489+
command_template TEXT NOT NULL,
490+
default_model TEXT DEFAULT '',
491+
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
492+
);"
493+
)?;
494+
Ok(())
495+
}
496+
497+
pub fn get_custom_providers(&self) -> Result<Vec<CustomProviderEntry>> {
498+
let mut stmt = self.conn.prepare(
499+
"SELECT name, preset, command_template, default_model FROM custom_providers ORDER BY name"
500+
)?;
501+
let rows = stmt.query_map([], |row| {
502+
Ok(CustomProviderEntry {
503+
name: row.get(0)?,
504+
preset: row.get(1)?,
505+
command_template: row.get(2)?,
506+
default_model: row.get::<_, Option<String>>(3)?.unwrap_or_default(),
507+
})
508+
})?;
509+
rows.collect()
510+
}
511+
512+
pub fn add_custom_provider(&self, entry: &CustomProviderEntry) -> Result<()> {
513+
self.conn.execute(
514+
"INSERT OR REPLACE INTO custom_providers (name, preset, command_template, default_model) VALUES (?1, ?2, ?3, ?4)",
515+
params![entry.name, entry.preset, entry.command_template, entry.default_model],
516+
)?;
517+
Ok(())
518+
}
519+
520+
pub fn remove_custom_provider(&self, name: &str) -> Result<()> {
521+
self.conn.execute("DELETE FROM custom_providers WHERE name = ?1", params![name])?;
522+
// If active provider was deleted, reset to gemini
523+
let current = self.get_setting("ai_provider", "gemini")?;
524+
if current == name {
525+
self.set_setting("ai_provider", "gemini")?;
526+
}
527+
Ok(())
528+
}
529+
478530
// =========================================================================
479531
// Summaries & Translations
480532
// =========================================================================

0 commit comments

Comments
 (0)