|
| 1 | +import os |
| 2 | +import tempfile |
| 3 | +import time |
| 4 | +from typing import List, Optional, Tuple, Any |
| 5 | + |
| 6 | +import streamlit as st |
| 7 | +import requests |
| 8 | +import json |
| 9 | +import re |
| 10 | +from contextual import ContextualAI |
| 11 | + |
| 12 | + |
def init_session_state() -> None:
    """Seed every session-state key the app reads, leaving existing values untouched."""
    # Built on every call so the mutable chat-history default is a fresh list.
    defaults = {
        "api_key_submitted": False,
        "contextual_api_key": "",
        "base_url": "https://api.contextual.ai/v1",
        "agent_id": "",
        "datastore_id": "",
        "chat_history": [],
        "processed_file": False,
        "last_raw_response": None,
        "last_user_query": "",
    }
    for key, initial in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = initial
| 32 | + |
| 33 | + |
def sidebar_api_form() -> bool:
    """Render the sidebar credential form and report whether setup is complete.

    When credentials were verified on an earlier run, only a status line and a
    "Reset Setup" button are shown. Otherwise a form collects the API key, base
    URL and optional existing agent/datastore IDs; on submit the key is checked
    by listing agents, and on success the values are stored in session state
    and the script is rerun.

    Returns:
        True if credentials are already verified in session state, else False.
    """
    with st.sidebar:
        st.header("API & Resource Setup")

        if st.session_state.api_key_submitted:
            st.success("API verified")
            if st.button("Reset Setup"):
                st.session_state.clear()
                st.rerun()
            return True

        with st.form("contextual_api_form"):
            api_key = st.text_input("Contextual AI API Key", type="password")
            base_url = st.text_input(
                "Base URL",
                value=st.session_state.base_url,
                help="Include /v1 (e.g., https://api.contextual.ai/v1)",
            )
            existing_agent_id = st.text_input("Existing Agent ID (optional)")
            existing_datastore_id = st.text_input("Existing Datastore ID (optional)")

            if st.form_submit_button("Save & Verify"):
                # Pasted values often carry stray whitespace; normalise once
                # so the same cleaned values are verified and stored.
                api_key = api_key.strip()
                base_url = base_url.strip() or st.session_state.base_url

                if not api_key:
                    st.error("Please enter your Contextual AI API key.")
                    return False

                # Only the verification call is guarded, so the handler
                # reports credential problems and nothing else.
                try:
                    client = ContextualAI(api_key=api_key, base_url=base_url)
                    client.agents.list()
                except Exception as e:
                    st.error(f"Credential verification failed: {str(e)}")
                    return False

                st.session_state.contextual_api_key = api_key
                st.session_state.base_url = base_url
                st.session_state.agent_id = existing_agent_id.strip()
                st.session_state.datastore_id = existing_datastore_id.strip()
                st.session_state.api_key_submitted = True

                st.success("Credentials verified!")
                st.rerun()
        return False
| 71 | + |
| 72 | + |
def ensure_client():
    """Build a ContextualAI client from the credentials held in session state.

    Raises:
        ValueError: If no API key has been stored in session state.
    """
    stored_key = st.session_state.get("contextual_api_key")
    if stored_key:
        return ContextualAI(api_key=stored_key, base_url=st.session_state.base_url)
    raise ValueError("Contextual AI API key not provided")
| 77 | + |
| 78 | + |
def create_datastore(client, name: str) -> Optional[str]:
    """Create a datastore called ``name`` and return its ``id`` attribute.

    Returns None when the created object has no ``id`` or when creation
    raises; in the latter case the error is shown in the UI.
    """
    try:
        created = client.datastores.create(name=name)
        datastore_id = getattr(created, "id", None)
        return datastore_id
    except Exception as exc:
        st.error(f"Failed to create datastore: {exc}")
        return None
| 86 | + |
| 87 | + |
# File extensions upload_documents() will send for ingestion.
ALLOWED_EXTS = {".pdf", ".html", ".htm", ".mhtml", ".doc", ".docx", ".ppt", ".pptx"}


def upload_documents(client, datastore_id: str, files: List[bytes], filenames: List[str], metadata: Optional[dict]) -> List[str]:
    """Ingest each file into the datastore and collect the resulting document IDs.

    Every file is spooled to a temporary file carrying the original extension,
    handed to ``client.datastores.documents.ingest`` as an open binary handle,
    and the temporary file is removed afterwards whether or not ingestion
    succeeded. Files whose extension is not in ``ALLOWED_EXTS`` are reported
    and skipped; a failure on one file is reported and does not stop the rest.

    Args:
        client: Object exposing ``datastores.documents.ingest``.
        datastore_id: Target datastore.
        files: Raw file contents, paired positionally with ``filenames``.
        filenames: Original file names; only the extension is used.
        metadata: Optional metadata forwarded to every ingest call when truthy.

    Returns:
        The ``id`` attribute of each ingest result ("" when a result has none),
        in upload order, for the files that were ingested without error.
    """
    doc_ids: List[str] = []
    for content, fname in zip(files, filenames):
        # Validate before any temp file exists so there is nothing to clean up.
        ext = os.path.splitext(fname)[1].lower()
        if ext not in ALLOWED_EXTS:
            st.error(f"Unsupported file extension for {fname}. Allowed: {sorted(ALLOWED_EXTS)}")
            continue

        # Reset per iteration: cleanup must never see a path left over from a
        # previous file, nor an unbound name if temp-file creation fails.
        tmp_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(content)
            ingest_kwargs = {"metadata": metadata} if metadata else {}
            with open(tmp_path, "rb") as f:
                result = client.datastores.documents.ingest(datastore_id, file=f, **ingest_kwargs)
            doc_ids.append(getattr(result, "id", ""))
        except Exception as e:
            st.error(f"Failed to upload {fname}: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    # Best-effort cleanup; a leftover temp file is not fatal.
                    pass
    return doc_ids
| 115 | + |
| 116 | + |
def wait_until_documents_ready(api_key: str, datastore_id: str, base_url: str, max_checks: int = 30, interval_sec: float = 5.0) -> bool:
    """Poll the datastore's document list until nothing is still being processed.

    Issues a GET to ``{base_url}/datastores/{datastore_id}/documents`` up to
    ``max_checks`` times, waiting ``interval_sec`` seconds between attempts.
    Documents count as ready once none of them reports a ``status`` of
    "processing" or "pending". Non-200 responses and request/parsing errors
    are treated as "not ready yet" and retried.

    Args:
        api_key: Bearer token sent in the Authorization header.
        datastore_id: Datastore whose documents are checked.
        base_url: API base URL; a trailing slash is tolerated.
        max_checks: Maximum number of polling attempts.
        interval_sec: Delay between consecutive attempts.

    Returns:
        True once no document is processing or pending; False if every attempt
        was used up without reaching that state.
    """
    url = f"{base_url.rstrip('/')}/datastores/{datastore_id}/documents"
    headers = {"Authorization": f"Bearer {api_key}"}

    for attempt in range(max_checks):
        # Sleep between attempts only, so giving up is not delayed by one
        # extra interval after the final check.
        if attempt:
            time.sleep(interval_sec)
        try:
            resp = requests.get(url, headers=headers, timeout=30)
            if resp.status_code == 200:
                docs = resp.json().get("documents", [])
                if not any(d.get("status") in ("processing", "pending") for d in docs):
                    return True
        except Exception:
            # Polling is best-effort: a failed request or malformed body just
            # means this attempt is inconclusive.
            continue
    return False
| 131 | + |
| 132 | + |
def create_agent(client, name: str, description: str, datastore_id: str) -> Optional[str]:
    """Create an agent bound to one datastore and return its ``id`` attribute.

    Returns None when the created object has no ``id`` or when creation
    raises; in the latter case the error is shown in the UI.
    """
    try:
        created = client.agents.create(
            name=name,
            description=description,
            datastore_ids=[datastore_id],
        )
        agent_id = getattr(created, "id", None)
        return agent_id
    except Exception as exc:
        st.error(f"Failed to create agent: {exc}")
        return None
| 140 | + |
| 141 | + |
def query_agent(client, agent_id: str, query: str) -> Tuple[str, Any]:
    """Send a single user message to the agent and extract the reply text.

    The reply is taken from the first of these that exists on the response:
    ``resp.content``, ``resp.message.content``, the ``content`` of the last
    entry of a non-empty ``resp.messages`` (or that entry's ``str()``), and
    finally ``str(resp)``.

    Returns:
        ``(reply, response)`` on success, or ``(error message, None)`` when
        anything raises.
    """
    try:
        resp = client.agents.query.create(agent_id=agent_id, messages=[{"role": "user", "content": query}])
        absent = object()  # distinguishes "no such attribute" from a None value

        direct = getattr(resp, "content", absent)
        if direct is not absent:
            return direct, resp

        nested = getattr(getattr(resp, "message", absent), "content", absent)
        if nested is not absent:
            return nested, resp

        history = getattr(resp, "messages", absent)
        if history is not absent and history:
            tail = history[-1]
            return getattr(tail, "content", str(tail)), resp

        return str(resp), resp
    except Exception as exc:
        return f"Error querying agent: {exc}", None
| 155 | + |
| 156 | + |
def show_retrieval_info(client, raw_response, agent_id: str) -> None:
    """Show the page image attached to the first retrieval of the last response.

    Reads ``message_id`` and the first entry of ``retrieval_contents`` from
    ``raw_response``, requests retrieval info for that content ID, and renders
    the base64-encoded ``page_img`` of the first returned metadata entry. Each
    missing piece produces an informational message; any exception is reported
    as an error in the UI.
    """
    try:
        if not raw_response:
            st.info("No retrieval info available.")
            return

        msg_id = getattr(raw_response, "message_id", None)
        retrieved = getattr(raw_response, "retrieval_contents", [])
        if not (msg_id and retrieved):
            st.info("No retrieval metadata returned.")
            return

        top_content_id = getattr(retrieved[0], "content_id", None)
        if not top_content_id:
            st.info("Missing content_id in retrieval metadata.")
            return

        info = client.agents.query.retrieval_info(
            message_id=msg_id,
            agent_id=agent_id,
            content_ids=[top_content_id],
        )
        content_metas = getattr(info, "content_metadatas", [])
        if not content_metas:
            st.info("No content metadatas found.")
            return

        encoded_page = getattr(content_metas[0], "page_img", None)
        if not encoded_page:
            st.info("No page image provided in metadata.")
            return

        import base64

        st.image(base64.b64decode(encoded_page), caption="Top Attribution Page", use_container_width=True)
    except Exception as exc:
        st.error(f"Failed to load retrieval info: {exc}")
| 186 | + |
| 187 | + |
def update_agent_prompt(client, agent_id: str, system_prompt: str) -> bool:
    """Replace the agent's system prompt.

    Returns:
        True if the update call completed, False if it raised (the error is
        shown in the UI).
    """
    succeeded = False
    try:
        client.agents.update(agent_id=agent_id, system_prompt=system_prompt)
        succeeded = True
    except Exception as exc:
        st.error(f"Failed to update system prompt: {exc}")
    return succeeded
| 195 | + |
| 196 | + |
def evaluate_with_lmunit(client, query: str, response_text: str, unit_test: str):
    """Score a response against a unit test via ``client.lmunit.create`` and display the result.

    The result object is rendered with ``str()`` in a code block; any
    exception is reported as an error in the UI.
    """
    try:
        outcome = client.lmunit.create(
            query=query,
            response=response_text,
            unit_test=unit_test,
        )
        st.subheader("Evaluation Result")
        st.code(str(outcome), language="json")
    except Exception as exc:
        st.error(f"LMUnit evaluation failed: {exc}")
| 204 | + |
| 205 | + |
def post_process_answer(text: str) -> str:
    """Tidy an answer for Markdown display.

    Removes parentheses that contain only whitespace and turns each "• "
    bullet into a Markdown list item on a new line.
    """
    without_empty_parens = re.sub(r"\(\s*\)", "", text)
    return without_empty_parens.replace("• ", "\n- ")
| 210 | + |
| 211 | + |
# Script entry: seed state, draw the title, and gate everything else on
# verified credentials.
init_session_state()

st.title("Contextual AI RAG Agent")

credentials_ready = sidebar_api_form()
if not credentials_ready:
    st.info("Please enter your Contextual AI API key in the sidebar to continue.")
    st.stop()

client = ensure_client()
| 221 | + |
# Step 1: show the datastore in use, or offer to create one.
with st.expander("1) Create or Select Datastore", expanded=True):
    if st.session_state.datastore_id:
        st.success(f"Using Datastore: {st.session_state.datastore_id}")
    else:
        default_name = "contextualai_rag_datastore"
        ds_name = st.text_input("Datastore Name", value=default_name)
        if st.button("Create Datastore"):
            ds_id = create_datastore(client, ds_name)
            if ds_id:
                st.session_state.datastore_id = ds_id
                st.success(f"Created datastore: {ds_id}")
| 233 | + |
# Step 2: pick files, optionally attach metadata, and ingest them.
with st.expander("2) Upload Documents", expanded=True):
    # Offer exactly the extensions upload_documents() accepts; otherwise files
    # can be selected here only to be rejected at ingest time.
    uploadable_types = sorted(ext.lstrip(".") for ext in ALLOWED_EXTS)
    uploaded_files = st.file_uploader(
        "Upload documents (PDF, HTML, Word, PowerPoint)",
        type=uploadable_types,
        accept_multiple_files=True,
    )
    metadata_json = st.text_area("Custom Metadata (JSON)", value="", placeholder='{"custom_metadata": {"field1": "value1"}}')
    if uploaded_files and st.session_state.datastore_id:
        contents = [f.getvalue() for f in uploaded_files]
        names = [f.name for f in uploaded_files]
        if st.button("Ingest Documents"):
            parsed_metadata = None
            metadata_ok = True
            if metadata_json.strip():
                try:
                    parsed_metadata = json.loads(metadata_json)
                except json.JSONDecodeError as e:
                    st.error(f"Invalid metadata JSON: {e}")
                    metadata_ok = False
                else:
                    if not isinstance(parsed_metadata, dict):
                        st.error("Invalid metadata JSON: expected a JSON object.")
                        metadata_ok = False

            # Do not ingest without the metadata the user asked for; let them
            # fix the JSON and retry instead.
            if metadata_ok:
                ids = upload_documents(client, st.session_state.datastore_id, contents, names, parsed_metadata)
                if ids:
                    st.success(f"Uploaded {len(ids)} document(s)")
                    ready = wait_until_documents_ready(st.session_state.contextual_api_key, st.session_state.datastore_id, st.session_state.base_url)
                    # Only an explicit False means polling gave up; any other
                    # return value (including None) is treated as ready.
                    if ready is False:
                        st.warning("Documents are still processing; answers may be incomplete until ingestion finishes.")
                    else:
                        st.info("Documents are ready.")
| 253 | + |
# Step 3: show the agent in use, or offer to create one once a datastore exists.
with st.expander("3) Create or Select Agent", expanded=True):
    needs_agent = not st.session_state.agent_id
    if needs_agent and st.session_state.datastore_id:
        agent_name = st.text_input("Agent Name", value="ContextualAI RAG Agent")
        agent_desc = st.text_area("Agent Description", value="RAG agent over uploaded documents")
        if st.button("Create Agent"):
            a_id = create_agent(client, agent_name, agent_desc, st.session_state.datastore_id)
            if a_id:
                st.session_state.agent_id = a_id
                st.success(f"Created agent: {a_id}")
    elif not needs_agent:
        st.success(f"Using Agent: {st.session_state.agent_id}")
| 265 | + |
# Step 4: optionally replace the agent's system prompt.
with st.expander("4) Agent Settings (Optional)"):
    if st.session_state.agent_id:
        system_prompt_val = st.text_area("System Prompt", value="", placeholder="Paste a new system prompt to update your agent")
        update_clicked = st.button("Update System Prompt")
        new_prompt = system_prompt_val.strip()
        # A click with an empty prompt is ignored.
        if update_clicked and new_prompt:
            ok = update_agent_prompt(client, st.session_state.agent_id, new_prompt)
            if ok:
                st.success("System prompt updated.")
| 273 | + |
st.divider()

# Replay the stored conversation so it survives Streamlit reruns.
for past_turn in st.session_state.chat_history:
    with st.chat_message(past_turn["role"]):
        st.markdown(past_turn["content"])

query = st.chat_input("Ask a question about your documents")
if query:
    # Record and echo the user's turn before contacting the agent.
    st.session_state.last_user_query = query
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    if not st.session_state.agent_id:
        st.error("Please create or select an agent first.")
    else:
        with st.chat_message("assistant"):
            answer, raw = query_agent(client, st.session_state.agent_id, query)
            # Kept for the retrieval-info view in the debug panel.
            st.session_state.last_raw_response = raw
            processed = post_process_answer(answer)
            st.markdown(processed)
            st.session_state.chat_history.append({"role": "assistant", "content": processed})
| 296 | + |
# Inspection tools: retrieval attribution and LMUnit scoring of the last answer.
with st.expander("Debug & Evaluation", expanded=False):
    st.caption("Tools to inspect retrievals and evaluate answers")
    if st.session_state.agent_id:
        if st.checkbox("Show Retrieval Info", value=False):
            show_retrieval_info(client, st.session_state.last_raw_response, st.session_state.agent_id)
        st.markdown("")
        unit_test = st.text_area("LMUnit rubric / unit test", value="Does the response avoid unnecessary information?", height=80)
        if st.button("Evaluate Last Answer with LMUnit"):
            has_conversation = st.session_state.last_user_query and st.session_state.chat_history
            if not has_conversation:
                st.info("Ask a question first to run an evaluation.")
            else:
                assistant_turns = [turn for turn in st.session_state.chat_history if turn["role"] == "assistant"]
                if assistant_turns:
                    evaluate_with_lmunit(client, st.session_state.last_user_query, assistant_turns[-1]["content"], unit_test)
                else:
                    st.info("No assistant response to evaluate yet.")
| 313 | + |
# Sidebar footer: clear only the conversation, or wipe all session state.
with st.sidebar:
    st.divider()
    clear_col, reset_col = st.columns(2)
    with clear_col:
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.session_state.last_raw_response = None
            st.session_state.last_user_query = ""
            st.rerun()
    with reset_col:
        if st.button("Reset App"):
            st.session_state.clear()
            st.rerun()
| 327 | + |
| 328 | + |
0 commit comments