use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use bytes::Bytes;
use chrono::Utc;
use futures::{Stream, TryStreamExt};
use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::ApplicationConfiguration;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_util::io::StreamReader;
use url::Url;

use crate::llm_utils::{APIMessage, MessageRole, estimate_tokens, make_streaming_llm_request};
use crate::prelude::*;
use crate::search_filter::SearchFilter;

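/// Separator placed between the configured content fields (`chunk_context` and
/// `chunk`) in citation content; used to split the context from the chunk when
/// parsing citations.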
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

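/// Per-category content filter results attached to a streamed response chunk.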
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

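/// One choice in a streamed response chunk; `delta` carries the incremental
/// content and `finish_reason` is set on the choice's final chunk.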
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

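/// A single server-sent event chunk of the streaming chat completion response.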
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

impl From<ChatbotConversationMessage> for APIMessage {
    fn from(message: ChatbotConversationMessage) -> Self {
        APIMessage {
            role: if message.is_from_chatbot {
                MessageRole::Assistant
            } else {
                MessageRole::User
            },
            content: message.message.unwrap_or_default(),
        }
    }
}

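/// Request parameters accepted by reasoning ("thinking") models.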
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
}

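/// Request parameters accepted by conventional (non-reasoning) models.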
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

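/// Model tuning parameters for an LLM request. Marked `#[serde(untagged)]`, so
/// the selected variant's fields are serialized directly into the request body.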
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
    None,
}

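/// The request body sent to the streaming chat completions endpoint.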
#[derive(Serialize, Deserialize, Debug)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    pub stop: Option<String>,
}

impl LLMRequest {
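    /// Builds the outgoing LLM request for a new user message and persists the
    /// message to the database.
    ///
    /// Collects the conversation history, prepends the configured system
    /// prompt, and, when Azure Search is enabled, attaches a data source for
    /// retrieval. Returns the request, the inserted message, and the estimated
    /// token count of the serialized message history.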
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("BASE_URL must have a host"))?
            .replace('.', "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                is_from_chatbot: false,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<APIMessage> =
            conversation_messages.into_iter().map(Into::into).collect();

        api_chat_messages.push(new_message.clone().into());

        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                content: configuration.prompt.clone(),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                stop: None,
            },
            new_message,
            request_estimated_tokens,
        ))
    }
}

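/// A retrieval data source attached to the request (here, Azure AI Search).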
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

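/// Wraps the response stream together with a [`RequestCancelledGuard`] so the
/// guard is dropped exactly when the stream is dropped.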
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

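/// Cleanup handler for requests that end before the response is complete.
/// `done` is set when the stream finishes normally; otherwise `Drop` deletes
/// the empty response message or saves the partial text received so far.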
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

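// `Drop` cannot be async, so the database cleanup is spawned onto a separate
// tokio task.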
impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Chat request did not complete normally. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying that the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message.");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

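/// Sends a chat request to the LLM and returns the response as a stream of
/// newline-delimited JSON [`ChatResponse`] chunks.
///
/// Persists the incoming user message and an initially incomplete response
/// message, stores any citations from the stream, and finalizes the response
/// message when the stream reaches `[DONE]`. If the caller drops the stream
/// early, the embedded [`RequestCancelledGuard`] cleans up.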
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let response_order_number = new_message.order_number + 1;

    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let response =
        make_streaming_llm_request(chat_request, &model.deployment_name, app_config).await?;

    info!("Receiving chat response over {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

    let response_stream = async_stream::try_stream! {
        while let Some(line) = lines.next_line().await? {
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");

            let mut full_response_text = full_response_text.lock().await;
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                )
                .await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
                .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse {
                            text: content.clone(),
                        };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let citation_message_id = response_message.id;
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate on a char boundary so the slice below cannot
                            // panic on multi-byte UTF-8 content.
                            let content = if cit.content.len() <= 255 {
                                cit.content.clone()
                            } else {
                                let mut end = 255;
                                while !cit.content.is_char_boundary(end) {
                                    end -= 1;
                                }
                                cit.content[..end].to_string()
                            };
                            let cleaned_content = match content.split_once(CONTENT_FIELD_SEPARATOR) {
                                Some((_, chunk)) => chunk.to_string(),
                                None => {
                                    error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                                    String::new()
                                }
                            };

                            let document_url = cit.url.clone();
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id = page_path.file_name().and_then(|id_str| {
                                Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok()
                            });
                            let course_material_chapter_number = if let Some(id) = page_id {
                                models::chapters::get_chapter_by_page_id(&mut conn, id)
                                    .await
                                    .ok()
                                    .map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: citation_message_id,
                                    conversation_id,
                                    course_material_chapter_number,
                                    title: cit.title.clone(),
                                    content: cleaned_content,
                                    document_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            )
                            .await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    let guarded_stream = GuardedStream::new(guard, response_stream);

    Ok(Box::pin(guarded_stream))
}