1use std::path::PathBuf;
2use std::pin::Pin;
3use std::sync::{
4 Arc,
5 atomic::{self, AtomicBool},
6};
7use std::task::{Context, Poll};
8
9use bytes::Bytes;
10use chrono::Utc;
11use futures::{Stream, TryStreamExt};
12use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
13use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
14use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
15use pin_project::pin_project;
16use serde::{Deserialize, Serialize};
17use sqlx::PgPool;
18use tokio::{io::AsyncBufReadExt, sync::Mutex};
19use tokio_util::io::StreamReader;
20use url::Url;
21
22use crate::llm_utils::{LLM_API_VERSION, build_llm_headers, estimate_tokens};
23use crate::prelude::*;
24use crate::search_filter::SearchFilter;
25
// Separator Azure inserts between the configured `content_fields`
// ("chunk_context" and "chunk" — see `FieldsMapping` below) when it joins
// them into a single citation content string.
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
27
/// Per-category content-filter verdicts attached to a streamed [`Choice`]
/// by the Azure OpenAI service.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}
35
/// Verdict for a single content-filter category.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    /// Whether the service filtered the content in this category.
    pub filtered: bool,
    /// Severity label as reported by the API (string form, not parsed here).
    pub severity: String,
}
41
/// One alternative completion inside a streamed [`ResponseChunk`].
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    /// Incremental payload for this choice; absent on some chunks.
    pub delta: Option<Delta>,
    /// Present on the final chunk for this choice (e.g. when generation stops).
    pub finish_reason: Option<String>,
    pub index: i32,
}
49
/// Incremental piece of a streamed answer: new text and/or citation context.
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    /// Next fragment of the generated answer text.
    pub content: Option<String>,
    /// Retrieval context (citations) when Azure search data sources are used.
    pub context: Option<DeltaContext>,
}
55
/// Retrieval context attached to a [`Delta`]: the documents the answer cites.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}
60
/// A cited source document returned by the Azure search data source.
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    /// Joined content fields, separated by [`CONTENT_FIELD_SEPARATOR`].
    pub content: String,
    pub title: String,
    pub url: String,
    /// Mapped from the index's "filepath" field; parsed for a page id below.
    pub filepath: String,
}
68
/// One parsed `data: {...}` payload from the chat completions event stream.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}
78
/// A single chat message in the wire format the completions API expects.
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiChatMessage {
    /// "system", "assistant" or "user" (see the `From` impl below).
    pub role: String,
    pub content: String,
}
85
86impl From<ChatbotConversationMessage> for ApiChatMessage {
87 fn from(message: ChatbotConversationMessage) -> Self {
88 ApiChatMessage {
89 role: if message.is_from_chatbot {
90 "assistant".to_string()
91 } else {
92 "user".to_string()
93 },
94 content: message.message.unwrap_or_default(),
95 }
96 }
97}
98
/// Request body for the Azure OpenAI chat completions endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatRequest {
    pub messages: Vec<ApiChatMessage>,
    /// Azure "On Your Data" search sources; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    pub temperature: f32,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    pub max_tokens: i32,
    pub stop: Option<String>,
    /// Always true here — responses are consumed as a stream.
    pub stream: bool,
}
112
113impl ChatRequest {
114 pub async fn build_and_insert_incoming_message_to_db(
115 conn: &mut PgConnection,
116 chatbot_configuration_id: Uuid,
117 conversation_id: Uuid,
118 message: &str,
119 app_config: &ApplicationConfiguration,
120 ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
121 let index_name = Url::parse(&app_config.base_url)?
122 .host_str()
123 .expect("BASE_URL must have a host")
124 .replace(".", "-");
125
126 let configuration =
127 models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
128
129 let conversation_messages =
130 models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
131 .await?;
132
133 let new_order_number = conversation_messages
134 .iter()
135 .map(|m| m.order_number)
136 .max()
137 .unwrap_or(0)
138 + 1;
139
140 let new_message = models::chatbot_conversation_messages::insert(
141 conn,
142 ChatbotConversationMessage {
143 id: Uuid::new_v4(),
144 created_at: Utc::now(),
145 updated_at: Utc::now(),
146 deleted_at: None,
147 conversation_id,
148 message: Some(message.to_string()),
149 is_from_chatbot: false,
150 message_is_complete: true,
151 used_tokens: estimate_tokens(message),
152 order_number: new_order_number,
153 },
154 )
155 .await?;
156
157 let mut api_chat_messages: Vec<ApiChatMessage> =
158 conversation_messages.into_iter().map(Into::into).collect();
159
160 api_chat_messages.push(new_message.clone().into());
161
162 api_chat_messages.insert(
163 0,
164 ApiChatMessage {
165 role: "system".to_string(),
166 content: configuration.prompt.clone(),
167 },
168 );
169
170 let data_sources = if configuration.use_azure_search {
171 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
172 anyhow::anyhow!("Azure configuration is missing from the application configuration")
173 })?;
174
175 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
176 anyhow::anyhow!(
177 "Azure search configuration is missing from the Azure configuration"
178 )
179 })?;
180
181 let query_type = if configuration.use_semantic_reranking {
182 "vector_semantic_hybrid"
183 } else {
184 "vector_simple_hybrid"
185 };
186
187 vec![DataSource {
188 data_type: "azure_search".to_string(),
189 parameters: DataSourceParameters {
190 endpoint: search_config.search_endpoint.to_string(),
191 authentication: DataSourceParametersAuthentication {
192 auth_type: "api_key".to_string(),
193 key: search_config.search_api_key.clone(),
194 },
195 index_name,
196 query_type: query_type.to_string(),
197 semantic_configuration: "default".to_string(),
198 embedding_dependency: EmbeddingDependency {
199 dep_type: "deployment_name".to_string(),
200 deployment_name: search_config.vectorizer_deployment_id.clone(),
201 },
202 in_scope: false,
203 top_n_documents: 15,
204 strictness: 3,
205 filter: Some(
206 SearchFilter::eq("course_id", configuration.course_id.to_string())
207 .to_odata()?,
208 ),
209 fields_mapping: FieldsMapping {
210 content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
211 content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
212 filepath_field: "filepath".to_string(),
213 title_field: "title".to_string(),
214 url_field: "url".to_string(),
215 vector_fields: vec!["text_vector".to_string()],
216 },
217 },
218 }]
219 } else {
220 Vec::new()
221 };
222
223 let serialized_messages = serde_json::to_string(&api_chat_messages)?;
224 let request_estimated_tokens = estimate_tokens(&serialized_messages);
225
226 Ok((
227 Self {
228 messages: api_chat_messages,
229 data_sources,
230 temperature: configuration.temperature,
231 top_p: configuration.top_p,
232 frequency_penalty: configuration.frequency_penalty,
233 presence_penalty: configuration.presence_penalty,
234 max_tokens: configuration.response_max_tokens,
235 stop: None,
236 stream: true,
237 },
238 new_message,
239 request_estimated_tokens,
240 ))
241 }
242}
243
/// An Azure "On Your Data" data source attached to a [`ChatRequest`].
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    /// Serialized as "type"; set to "azure_search" in this module.
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}
250
/// Parameters for an Azure search data source.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    /// E.g. "vector_semantic_hybrid" or "vector_simple_hybrid".
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    /// OData filter expression; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}
266
/// Authentication block for an Azure search data source.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    /// Serialized as "type"; set to "api_key" in this module.
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}
273
/// Points the data source at the embedding deployment used for vectorization.
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    /// Serialized as "type"; set to "deployment_name" in this module.
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}
280
/// Maps this project's search index fields to the names Azure expects.
#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    /// Separator used to join `content_fields`; see [`CONTENT_FIELD_SEPARATOR`].
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}
290
/// One chunk of chatbot text, serialized as a single JSON line of the
/// stream returned to the client.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}
295
/// Pairs the response stream with a [`RequestCancelledGuard`] so that
/// dropping the stream (e.g. on client disconnect) also drops the guard,
/// which triggers its cleanup logic.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}
303
304impl<S> GuardedStream<S> {
305 fn new(guard: RequestCancelledGuard, stream: S) -> Self {
306 Self { guard, stream }
307 }
308}
309
impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    /// Delegates polling to the inner stream. The guard field is never used
    /// here — it exists only so it is dropped together with the stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}
321
/// Cleanup guard for an in-flight chatbot response.
///
/// If dropped before `done` is set (i.e. the client cancelled or the stream
/// was dropped mid-response), its `Drop` impl persists or deletes the
/// partially received response message.
struct RequestCancelledGuard {
    /// Id of the placeholder response message row to clean up.
    response_message_id: Uuid,
    /// Text chunks received so far; shared with the response stream.
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    /// Set to true once the stream has completed and been saved normally.
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}
329
330impl Drop for RequestCancelledGuard {
331 fn drop(&mut self) {
332 if self.done.load(atomic::Ordering::Relaxed) {
333 return;
334 }
335 warn!("Request was not cancelled. Cleaning up.");
336 let response_message_id = self.response_message_id;
337 let received_string = self.received_string.clone();
338 let pool = self.pool.clone();
339 let request_estimated_tokens = self.request_estimated_tokens;
340 tokio::spawn(async move {
341 info!("Verifying the received message has been handled");
342 let mut conn = pool.acquire().await.expect("Could not acquire connection");
343 let full_response_text = received_string.lock().await;
344 if full_response_text.is_empty() {
345 info!("No response received. Deleting the response message");
346 models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
347 .await
348 .expect("Could not delete response message");
349 return;
350 }
351 info!("Response received but not completed. Saving the text received so far.");
352 let full_response_as_string = full_response_text.join("");
353 let estimated_cost = estimate_tokens(&full_response_as_string);
354 info!(
355 "End of chatbot response stream. Estimated cost: {}. Response: {}",
356 estimated_cost, full_response_as_string
357 );
358
359 models::chatbot_conversation_messages::update(
361 &mut conn,
362 response_message_id,
363 &full_response_as_string,
364 true,
365 request_estimated_tokens + estimated_cost,
366 )
367 .await
368 .expect("Could not update response message");
369 });
370 }
371}
372
373pub async fn send_chat_request_and_parse_stream(
374 conn: &mut PgConnection,
375 pool: PgPool,
376 app_config: &ApplicationConfiguration,
377 chatbot_configuration_id: Uuid,
378 conversation_id: Uuid,
379 message: &str,
380) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
381 let (chat_request, new_message, request_estimated_tokens) =
382 ChatRequest::build_and_insert_incoming_message_to_db(
383 conn,
384 chatbot_configuration_id,
385 conversation_id,
386 message,
387 app_config,
388 )
389 .await?;
390
391 let full_response_text = Arc::new(Mutex::new(Vec::new()));
392 let done = Arc::new(AtomicBool::new(false));
393
394 let azure_config = app_config
395 .azure_configuration
396 .as_ref()
397 .ok_or_else(|| anyhow::anyhow!("Azure configuration not found"))?;
398
399 let chatbot_config = azure_config
400 .chatbot_config
401 .as_ref()
402 .ok_or_else(|| anyhow::anyhow!("Chatbot configuration not found"))?;
403
404 let api_key = chatbot_config.api_key.clone();
405 let mut url = chatbot_config.api_endpoint.clone();
406
407 url.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
409
410 let headers = build_llm_headers(&api_key)?;
411
412 let response_order_number = new_message.order_number + 1;
413
414 let response_message = models::chatbot_conversation_messages::insert(
415 conn,
416 ChatbotConversationMessage {
417 id: Uuid::new_v4(),
418 created_at: Utc::now(),
419 updated_at: Utc::now(),
420 deleted_at: None,
421 conversation_id,
422 message: None,
423 is_from_chatbot: true,
424 message_is_complete: false,
425 used_tokens: request_estimated_tokens,
426 order_number: response_order_number,
427 },
428 )
429 .await?;
430
431 let guard = RequestCancelledGuard {
433 response_message_id: response_message.id,
434 received_string: full_response_text.clone(),
435 pool: pool.clone(),
436 done: done.clone(),
437 request_estimated_tokens,
438 };
439
440 let request = REQWEST_CLIENT
441 .post(url)
442 .headers(headers)
443 .json(&chat_request)
444 .send();
445
446 let response = request.await?;
447
448 info!("Receiving chat response with {:?}", response.version());
449
450 if !response.status().is_success() {
451 let status = response.status();
452 let error_message = response.text().await?;
453 return Err(anyhow::anyhow!(
454 "Failed to send chat request. Status: {}. Error: {}",
455 status,
456 error_message
457 ));
458 }
459
460 let stream = response.bytes_stream().map_err(std::io::Error::other);
461 let reader = StreamReader::new(stream);
462 let mut lines = reader.lines();
463
464 let response_stream = async_stream::try_stream! {
465 while let Some(line) = lines.next_line().await? {
466 if !line.starts_with("data: ") {
467 continue;
468 }
469 let json_str = line.trim_start_matches("data: ");
470
471 let mut full_response_text = full_response_text.lock().await;
472 if json_str.trim() == "[DONE]" {
473 let full_response_as_string = full_response_text.join("");
474 let estimated_cost = estimate_tokens(&full_response_as_string);
475 info!(
476 "End of chatbot response stream. Estimated cost: {}. Response: {}",
477 estimated_cost, full_response_as_string
478 );
479 done.store(true, atomic::Ordering::Relaxed);
480 let mut conn = pool.acquire().await?;
481 models::chatbot_conversation_messages::update(
482 &mut conn,
483 response_message.id,
484 &full_response_as_string,
485 true,
486 request_estimated_tokens + estimated_cost,
487 ).await?;
488 break;
489 }
490 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
491 anyhow::anyhow!("Failed to parse response chunk: {}", e)
492 })?;
493
494 for choice in &response_chunk.choices {
495 if let Some(delta) = &choice.delta {
496 if let Some(content) = &delta.content {
497 full_response_text.push(content.clone());
498 let response = ChatResponse { text: content.clone() };
499 let response_as_string = serde_json::to_string(&response)?;
500 yield Bytes::from(response_as_string);
501 yield Bytes::from("\n");
502 }
503 if let Some(context) = &delta.context {
504 let citation_message_id = response_message.id;
505 let mut conn = pool.acquire().await?;
506 for (idx, cit) in context.citations.iter().enumerate() {
507 let content = if cit.content.len() < 255 {cit.content.clone()} else {cit.content[0..255].to_string()};
508 let split = content.split_once(CONTENT_FIELD_SEPARATOR);
509 if split.is_none() {
510 error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
511 }
512 let cleaned_content: String = split.unwrap_or(("","")).1.to_string();
513
514 let document_url = cit.url.clone();
515 let mut page_path = PathBuf::from(&cit.filepath);
516 page_path.set_extension("");
517 let page_id_str = page_path.file_name();
518 let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
519 let course_material_chapter_number = if let Some(id) = page_id {
520 let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
521 chapter.map(|c| c.chapter_number)
522 } else {
523 None
524 };
525
526 models::chatbot_conversation_messages_citations::insert(
527 &mut conn, ChatbotConversationMessageCitation {
528 id: Uuid::new_v4(),
529 created_at: Utc::now(),
530 updated_at: Utc::now(),
531 deleted_at: None,
532 conversation_message_id: citation_message_id,
533 conversation_id,
534 course_material_chapter_number,
535 title: cit.title.clone(),
536 content: cleaned_content,
537 document_url,
538 citation_number: (idx+1) as i32,
539 }
540 ).await?;
541 }
542 }
543
544 }
545 }
546 }
547
548 if !done.load(atomic::Ordering::Relaxed) {
549 Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
550 }
551 };
552
553 let guarded_stream = GuardedStream::new(guard, response_stream);
556
557 Ok(Box::pin(guarded_stream))
559}