use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use bytes::Bytes;
use chrono::Utc;
use futures::{Stream, TryStreamExt};
use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::ApplicationConfiguration;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_util::io::StreamReader;
use url::Url;

use headless_lms_utils::url_encoding::url_decode;

use crate::llm_utils::{APIMessage, MessageRole, estimate_tokens, make_streaming_llm_request};
use crate::prelude::*;
use crate::search_filter::SearchFilter;

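/// Separator that the search index places between the `chunk_context` and `chunk`
/// content fields of a citation. Used below to strip the context prefix off citation
/// content before it is stored.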
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

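/// Maps a stored conversation message to an API message: chatbot messages become
/// `Assistant` messages, user messages become `User` messages.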
impl From<ChatbotConversationMessage> for APIMessage {
    fn from(message: ChatbotConversationMessage) -> Self {
        APIMessage {
            role: if message.is_from_chatbot {
                MessageRole::Assistant
            } else {
                MessageRole::User
            },
            content: message.message.unwrap_or_default(),
        }
    }
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

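/// Sampling and limit parameters for the request. With `#[serde(untagged)]` the selected
/// variant's fields are serialized without a variant tag, so combined with the
/// `#[serde(flatten)]` on `LLMRequest::params` they appear as top-level request fields.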
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
    None,
}

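/// Request body for the streaming chat completions call. The optional `data_sources`
/// field follows the format of the Azure OpenAI "On Your Data" chat extension, letting
/// the model ground its answers in an Azure AI Search index.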
#[derive(Serialize, Deserialize, Debug)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    pub stop: Option<String>,
}

impl LLMRequest {
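    /// Persists the incoming user message to the database and builds the LLM request from
    /// the conversation history, the configured system prompt, and (optionally) an Azure
    /// search data source. Returns the request, the inserted message, and the estimated
    /// token count of the serialized request.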
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("BASE_URL must have a host"))?
            .replace('.', "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                is_from_chatbot: false,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<APIMessage> =
            conversation_messages.into_iter().map(Into::into).collect();

        api_chat_messages.push(new_message.clone().into());

        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                content: configuration.prompt.clone(),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                stop: None,
            },
            new_message,
            request_estimated_tokens,
        ))
    }
}

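/// A single entry of the request's `data_sources` array ("On Your Data" style);
/// here always an `azure_search` source.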
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

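/// Wraps the response stream together with a [`RequestCancelledGuard`]. Because the guard
/// is owned by the stream, dropping the stream (e.g. the client closes the connection
/// mid-response) runs the guard's `Drop` implementation, which handles cleanup.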
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

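/// Cleanup guard for an in-flight chat request. If the stream completed normally, `done`
/// is set and `Drop` does nothing. Otherwise `Drop` spawns a task that either deletes the
/// placeholder response message (nothing was received) or saves the partial response text.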
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Request did not complete. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying that the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

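/// Sends the chat request to the LLM and returns a stream of newline-delimited JSON
/// [`ChatResponse`] chunks for the client. The user message and a placeholder response
/// message are persisted up front; the response message is filled in when the stream
/// finishes, or by the cancellation guard if it does not.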
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let response_order_number = new_message.order_number + 1;

    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let response =
        make_streaming_llm_request(chat_request, &model.deployment_name, app_config).await?;

    info!("Receiving chat response over {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

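    // The endpoint streams Server-Sent Events: each payload line starts with "data: "
    // and the stream is terminated by a literal "data: [DONE]" line.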
    let response_stream = async_stream::try_stream! {
        while let Some(line) = lines.next_line().await? {
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");

            let mut full_response_text = full_response_text.lock().await;
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
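                    // Citations arrive once, in a `context` delta. Each one is cleaned up,
                    // URL-decoded, and linked to a course chapter by parsing the page id
                    // out of the citation's filepath, then persisted for this message.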
                    if let Some(context) = &delta.context {
                        let citation_message_id = response_message.id;
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate long citation content on a char boundary; slicing
                            // at a fixed byte offset could panic on multi-byte UTF-8.
                            let content = if cit.content.len() < 255 {
                                cit.content.clone()
                            } else {
                                let mut end = 255;
                                while !cit.content.is_char_boundary(end) {
                                    end -= 1;
                                }
                                cit.content[..end].to_string()
                            };
                            let cleaned_content = match content.split_once(CONTENT_FIELD_SEPARATOR) {
                                Some((_, chunk)) => chunk.to_string(),
                                None => {
                                    error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                                    String::new()
                                }
                            };

                            let decoded_title = url_decode(&cit.title)?;
                            let decoded_url = url_decode(&cit.url)?;

                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id = page_path.file_name().and_then(|id_str| {
                                Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok()
                            });
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter =
                                    models::chapters::get_chapter_by_page_id(&mut conn, id)
                                        .await
                                        .ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: citation_message_id,
                                    conversation_id,
                                    course_material_chapter_number,
                                    title: decoded_title,
                                    content: cleaned_content,
                                    document_url: decoded_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            )
                            .await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    let guarded_stream = GuardedStream::new(guard, response_stream);

    Ok(Box::pin(guarded_stream))
}