-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstreamlit_onestopqa.py
More file actions
387 lines (329 loc) · 15.3 KB
/
streamlit_onestopqa.py
File metadata and controls
387 lines (329 loc) · 15.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
"""
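Streamlit explorer for the OneStopQA dataset: browse articles paragraph by paragraph and highlight critical and distractor spans at each reading level.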
"""
from __future__ import annotations
import html
import json
import re
from pathlib import Path
from typing import Iterable, List, Sequence, Tuple
import streamlit as st
# A (start, end) offset pair. Used both for word-index spans (as expanded from
# the dataset's span lists) and for character spans into a context string.
Span = Tuple[int, int]
@st.cache_data(show_spinner=False)
def load_onestop_qa() -> List[dict]:
    """Return the ``data`` list stored in ``onestop_qa.json`` beside this script.

    The result is memoised with ``st.cache_data`` (spinner hidden), so the
    JSON file is not re-parsed on every rerun. A payload without a ``data``
    key yields an empty list.
    """
    source = Path(__file__).with_name("onestop_qa.json")
    document = json.loads(source.read_text(encoding="utf-8"))
    return document.get("data", [])
def _as_pairs(spans: Sequence[Sequence[int]] | Sequence[int] | None) -> List[Span]:
"""Expand span definitions into (start, end) pairs.
Input can be either:
- list[list[int]] (usual case: one item per question)
- list[int] for a single question
Each item is [start, end] or [start1, end1, start2, end2].
Invalid or empty spans are ignored.
"""
pairs: List[Span] = []
if spans is None:
return pairs
# If spans is already a flat list of ints (single question), wrap it.
if spans and isinstance(spans[0], int): # type: ignore[index]
spans = [spans] # type: ignore[assignment]
for span in spans: # type: ignore[iteration-over-annotation]
if not isinstance(span, Sequence) or isinstance(span, (str, bytes)):
continue
if len(span) == 2:
pairs.append((int(span[0]), int(span[1])))
elif len(span) == 4:
pairs.append((int(span[0]), int(span[1])))
pairs.append((int(span[2]), int(span[3])))
return pairs
def _normalize_spans(spans: Iterable[Span], text_length: int, label: str, color: str):
"""Clamp spans to the text length and attach styling metadata."""
normalized = []
for start, end in spans:
lo, hi = sorted((max(0, start), min(text_length, end)))
if lo >= hi:
continue
normalized.append({
"start": lo,
"end": hi,
"label": label,
"color": color,
})
return sorted(normalized, key=lambda x: (x["start"], -(x["end"] - x["start"])))
def _get_span_colors_for_duplicates(
    a_spans_list: List[Sequence[Sequence[int]] | None], text: str
) -> dict[Tuple[int, int], str]:
    """Colour each critical span by whether another question shares any of it.

    ``a_spans_list`` has one entry per question: that question's word-index
    critical spans (anything ``_as_pairs`` accepts), or ``None``. Spans are
    converted to character spans in ``text`` with the same helpers
    ``highlight_text`` uses (``_as_pairs``, ``_word_spans_to_char_spans``,
    ``_normalize_spans``). This keeps the returned keys aligned with the spans
    it renders.

    A span gets ``#E65100`` (dark orange) when at least one of its characters
    is also covered by a critical span of a *different* question; this
    includes exact duplicates. Otherwise it gets ``#00838F`` (dark teal).
    Spans belonging to the same question never count as overlapping each
    other.

    Returns:
        Mapping ``(char_start, char_end) -> colour hex code``.
    """
    duplicate_color = "#E65100"  # dark orange
    unique_color = "#00838F"  # dark teal

    # Non-empty character spans for every question, in question order.
    spans_by_question: List[List[Span]] = []
    for a_spans in a_spans_list:
        if a_spans is None:
            spans_by_question.append([])
            continue
        char_spans = _word_spans_to_char_spans(_as_pairs(a_spans), text)
        normalized = _normalize_spans(char_spans, len(text), "critical", "#d4f5a3")
        spans_by_question.append([(s["start"], s["end"]) for s in normalized])

    span_colors: dict[Tuple[int, int], str] = {}
    for q_idx, spans in enumerate(spans_by_question):
        # Every span contributed by the other questions.
        others = [
            other
            for other_idx, other_spans in enumerate(spans_by_question)
            if other_idx != q_idx
            for other in other_spans
        ]
        for start, end in spans:
            # Non-empty half-open intervals share a character iff each one
            # starts before the other ends.
            overlaps = any(o_start < end and start < o_end for o_start, o_end in others)
            span_colors[(start, end)] = duplicate_color if overlaps else unique_color
    return span_colors
def _word_spans_to_char_spans(word_spans: List[Span], text: str) -> List[Span]:
"""Convert word-index spans (zero-based) to character spans."""
tokens = list(re.finditer(r"\S+", text))
if not tokens:
return []
char_spans: List[Span] = []
last_idx = len(tokens) - 1
for start_word, end_word in word_spans:
lo_idx = min(start_word, end_word)
hi_idx = max(start_word, end_word)
lo_idx = max(0, lo_idx)
hi_idx = min(last_idx, hi_idx)
if lo_idx > hi_idx:
continue
char_start = tokens[lo_idx].start()
char_end = tokens[hi_idx].end()
char_spans.append((char_start, char_end))
return char_spans
def highlight_text(
    text: str,
    a_spans: Sequence[Sequence[int]] | None,
    d_spans: Sequence[Sequence[int]] | None,
    span_colors: dict[Tuple[int, int], str] | None = None,
) -> str:
    """Return ``text`` as HTML with critical and distractor spans highlighted.

    Args:
        text: The text to highlight.
        a_spans: Critical spans as word indices (anything ``_as_pairs`` accepts).
        d_spans: Distractor spans as word indices (anything ``_as_pairs`` accepts).
        span_colors: Optional mapping ``(char_start, char_end) -> colour hex``.
            It overrides the default colour of any span with exactly those
            character offsets.

    Returns:
        The HTML-escaped text, with each span wrapped in a bold, coloured
        ``<span>`` whose ``title`` is ``critical`` or ``distractor``.
        - Default colours are ``#2E7D32`` (critical) and ``#D32F2F``
          (distractor).
        - Where spans overlap, later spans are trimmed so every character is
          painted at most once.
        - At the same start offset, a critical span is painted before a
          distractor.

    NOTE(review): ``main`` renders this string via ``st.markdown``, so
    Markdown syntax inside the passage (e.g. ``$`` pairs, ``*``) may still be
    interpreted. Only HTML special characters are escaped here; confirm
    whether that matters for the dataset.
    """
    text_length = len(text)
    crit = _normalize_spans(
        _word_spans_to_char_spans(_as_pairs(a_spans), text),
        text_length,
        "critical",
        "#2E7D32",
    )
    dist = _normalize_spans(
        _word_spans_to_char_spans(_as_pairs(d_spans), text),
        text_length,
        "distractor",
        "#D32F2F",
    )
    if span_colors:
        for span in crit + dist:
            key = (span["start"], span["end"])
            if key in span_colors:
                span["color"] = span_colors[key]
    # Stable sort: at equal start "critical" < "distractor". Within a label,
    # the order produced by _normalize_spans is kept.
    ordered = sorted(crit + dist, key=lambda x: (x["start"], x["label"]))

    parts: List[str] = []
    cursor = 0
    for span in ordered:
        # Trim anything already painted by an earlier span.
        start = max(span["start"], cursor)
        end = span["end"]
        if start >= end:
            continue
        if cursor < start:
            parts.append(html.escape(text[cursor:start]))
        span_text = html.escape(text[start:end])
        parts.append(
            f"<span style='color:{span['color']}; font-weight:bold;' title='{span['label']}'>"
            f"{span_text}</span>"
        )
        cursor = end
    if cursor < text_length:
        parts.append(html.escape(text[cursor:]))
    return "".join(parts)
def _clamp_index(index: int, size: int) -> int:
    """Clamp ``index`` into ``[0, size - 1]``; callers guarantee ``size > 0``."""
    return min(max(0, index), size - 1)


def main():
    """Render the OneStopQA explorer page.

    Layout, top to bottom:
    - the article picker and the view-mode toggle;
    - paragraph navigation, plus question navigation in "By Question" mode;
    - the paragraph text at the Advanced / Intermediate / Elementary levels,
      with highlighted spans;
    - the question(s) with their answer choices;
    - a colour legend.

    The navigation position is kept in ``st.session_state`` (``para_idx``,
    ``question_idx``, ``last_article_idx``) so it survives Streamlit reruns.
    """
    st.set_page_config(page_title="OneStopQA Explorer", layout="wide")
    st.title("OneStopQA Explorer")
    st.caption("View passages, questions, and highlight critical/distractor spans.")

    dataset = sorted(load_onestop_qa(), key=lambda x: x.get("article_id", ""))
    if not dataset:
        st.error("Could not load OneStopQA dataset.")
        return

    article_titles = [
        f"{item.get('title', '')} ({item.get('article_id', '')})" for item in dataset
    ]
    # Select by position rather than by label. This keeps articles with
    # identical labels distinguishable and avoids a linear list.index lookup.
    selected_article_idx = st.selectbox(
        "Article",
        list(range(len(dataset))),
        index=0,
        format_func=lambda i: article_titles[i],
    )
    article = dataset[selected_article_idx]

    # Reset navigation whenever a different article is picked (and on first run).
    if st.session_state.get("last_article_idx") != selected_article_idx:
        st.session_state.para_idx = 0
        st.session_state.question_idx = 0
        st.session_state.last_article_idx = selected_article_idx

    paragraphs = article.get("paragraphs", [])
    if not paragraphs:
        st.warning("No paragraphs found for this article.")
        return
    # Defensive: keep the stored position inside this article's bounds.
    st.session_state.para_idx = _clamp_index(
        st.session_state.get("para_idx", 0), len(paragraphs)
    )

    view_mode = st.radio(
        "View Mode",
        ["By Question", "All Critical Spans"],
        horizontal=True,
        key="view_mode_radio",
    )
    by_question = view_mode == "By Question"

    # Paragraph navigation. The label is written after both buttons so it
    # reflects a click handled in this same run.
    prev_col, label_col, next_col = st.columns([0.5, 0.5, 3])
    with prev_col:
        if st.button("← Prev", key="prev_para_btn", shortcut="A"):
            st.session_state.para_idx = max(0, st.session_state.para_idx - 1)
            st.session_state.question_idx = 0
    with next_col:
        if st.button("Next →", key="next_para_btn", shortcut="D"):
            st.session_state.para_idx = min(len(paragraphs) - 1, st.session_state.para_idx + 1)
            st.session_state.question_idx = 0
    with label_col:
        st.write(f"**Paragraph {st.session_state.para_idx + 1} / {len(paragraphs)}**")

    para = paragraphs[st.session_state.para_idx]
    qas = para.get("qas", [])
    if not qas:
        st.warning("No questions for this paragraph.")
        return
    st.session_state.question_idx = _clamp_index(
        st.session_state.get("question_idx", 0), len(qas)
    )

    # Question navigation (only in "By Question" mode).
    if by_question:
        prev_col, label_col, next_col = st.columns([0.5, 0.5, 3])
        with prev_col:
            if st.button("← Prev", key="prev_q_btn", shortcut="W"):
                st.session_state.question_idx = max(0, st.session_state.question_idx - 1)
        with next_col:
            if st.button("Next →", key="next_q_btn", shortcut="S"):
                st.session_state.question_idx = min(len(qas) - 1, st.session_state.question_idx + 1)
        with label_col:
            st.write(f"**Question {st.session_state.question_idx + 1} / {len(qas)}**")
    question_idx = st.session_state.question_idx
    selected_question = qas[question_idx]

    level_labels = {"Adv": "Advanced", "Int": "Intermediate", "Ele": "Elementary"}
    for level, level_label in level_labels.items():
        context_block = para.get(level) or {}
        context_text = context_block.get("context", "")
        a_list = context_block.get("a_spans") or []
        d_list = context_block.get("d_spans") or []
        st.subheader(level_label)
        if not context_text:
            st.info("No context available for this level.")
            continue
        if by_question:
            # Only the current question's critical and distractor spans.
            a_spans = a_list[question_idx] if question_idx < len(a_list) else []
            d_spans = d_list[question_idx] if question_idx < len(d_list) else []
            html_block = highlight_text(context_text, a_spans, d_spans)
        else:
            # Every question's critical spans, without distractors. Missing
            # entries become None so positions still match question order.
            all_a_spans = [
                a_list[q_idx] if q_idx < len(a_list) and a_list[q_idx] else None
                for q_idx in range(len(qas))
            ]
            span_colors = _get_span_colors_for_duplicates(all_a_spans, context_text)
            html_block = highlight_text(context_text, all_a_spans, None, span_colors)
        st.markdown(html_block, unsafe_allow_html=True)

    if by_question:
        st.write("**Question**: " + selected_question.get("question", ""))
        if selected_question.get("answers"):
            st.markdown("**Answer choices**")
            for idx, ans in enumerate(selected_question["answers"], start=1):
                st.write(f"{idx}. {ans}")
    else:
        st.subheader("Questions")
        for q_idx, qa in enumerate(qas):
            q_ind = qa.get('q_ind', q_idx)
            st.write(f"**Question {q_ind+1}**: {qa.get('question', '')}")
            if qa.get("answers"):
                for idx, ans in enumerate(qa["answers"], start=1):
                    st.write(f" {idx}. {ans}")

    # Legend matching the colours used by highlight_text for the current mode.
    if by_question:
        st.markdown(
            "<span style='color:#2E7D32; font-weight:bold;'>critical span</span> "
            "<span style='color:#D32F2F; font-weight:bold;'>distractor span</span>",
            unsafe_allow_html=True,
        )
    else:
        st.markdown(
            "<span style='color:#E65100; font-weight:bold;'>duplicate critical span</span> "
            "<span style='color:#00838F; font-weight:bold;'>unique critical span</span>",
            unsafe_allow_html=True,
        )

    # Compact typography: smaller body text and no paragraph margins.
    st.markdown(
        "<style>body { font-size: 12px; } p { margin: 0; }</style>",
        unsafe_allow_html=True,
    )
# Script entry point. NOTE(review): presumably launched with
# `streamlit run <this file>`, which executes the module as __main__ — confirm.
if __name__ == "__main__":
    main()