1use std::path::PathBuf;
2use std::pin::Pin;
3use std::sync::{
4 Arc,
5 atomic::{self, AtomicBool},
6};
7use std::task::{Context, Poll};
8
9use bytes::Bytes;
10use chrono::Utc;
11use futures::{Stream, TryStreamExt};
12use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
13use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
14use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
15use pin_project::pin_project;
16use serde::{Deserialize, Serialize};
17use sqlx::PgPool;
18use tokio::{io::AsyncBufReadExt, sync::Mutex};
19use tokio_util::io::StreamReader;
20use url::Url;
21
22use crate::llm_utils::{LLM_API_VERSION, build_llm_headers, estimate_tokens};
23use crate::prelude::*;
24use crate::search_filter::SearchFilter;
25
/// Per-category content-moderation verdicts attached to a streamed choice
/// (presumably by the Azure OpenAI content filter — confirm against the API).
/// Categories missing from the payload deserialize to `None`.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    /// Verdict for hateful content, if reported.
    pub hate: Option<ContentFilter>,
    /// Verdict for self-harm content, if reported.
    pub self_harm: Option<ContentFilter>,
    /// Verdict for sexual content, if reported.
    pub sexual: Option<ContentFilter>,
    /// Verdict for violent content, if reported.
    pub violence: Option<ContentFilter>,
}
33
/// A single moderation verdict for one content category.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    /// Whether the content was filtered out for this category.
    pub filtered: bool,
    /// Severity label reported by the service (opaque string).
    pub severity: String,
}
39
/// One completion choice inside a streamed [`ResponseChunk`].
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    /// Moderation results for this chunk, when the service provides them.
    pub content_filter_results: Option<ContentFilterResults>,
    /// Incremental payload (text and/or citation context) for this chunk.
    pub delta: Option<Delta>,
    /// Reason generation stopped; `None` while the stream is still running.
    pub finish_reason: Option<String>,
    /// Index of this choice within the response.
    pub index: i32,
}
47
/// The incremental part of a streamed choice: a fragment of generated text
/// and/or retrieval context with citations.
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    /// Next fragment of the generated answer text, if any.
    pub content: Option<String>,
    /// Retrieval context (citations) attached to this delta, if any.
    pub context: Option<DeltaContext>,
}
53
/// Retrieval context carried in a [`Delta`]: the documents the model cited.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    /// Citations for the sources used in this part of the response.
    pub citations: Vec<Citation>,
}
58
/// A cited source document returned by the search-augmented completion.
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    /// Excerpt of the cited document (truncated before being persisted).
    pub content: String,
    /// Title of the cited document.
    pub title: String,
    /// URL of the cited document.
    pub url: String,
    /// Source file path; its file stem is parsed as a page UUID downstream
    /// (see `send_chat_request_and_parse_stream`).
    pub filepath: String,
}
66
/// One server-sent-event payload from the streaming chat completion API.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    /// Completion choices contained in this chunk.
    pub choices: Vec<Choice>,
    /// Creation timestamp of the chunk (Unix seconds, per the API).
    pub created: u64,
    /// Identifier of the completion this chunk belongs to.
    pub id: String,
    /// Name of the model that produced the chunk.
    pub model: String,
    /// Object type marker supplied by the API.
    pub object: String,
    /// Backend fingerprint, when the service reports one.
    pub system_fingerprint: Option<String>,
}
76
/// A single message in the format the chat completion API expects:
/// a role (`"system"`, `"user"`, or `"assistant"`) plus its text.
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiChatMessage {
    /// Message author role.
    pub role: String,
    /// Message text.
    pub content: String,
}
83
84impl From<ChatbotConversationMessage> for ApiChatMessage {
85 fn from(message: ChatbotConversationMessage) -> Self {
86 ApiChatMessage {
87 role: if message.is_from_chatbot {
88 "assistant".to_string()
89 } else {
90 "user".to_string()
91 },
92 content: message.message.unwrap_or_default(),
93 }
94 }
95}
96
/// Request body for the streaming chat completion endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatRequest {
    /// Full conversation: system prompt, history, and the new user message.
    pub messages: Vec<ApiChatMessage>,
    /// Azure search data sources; omitted from the JSON when empty so plain
    /// (non-search) completions don't send the field at all.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Nucleus-sampling probability mass.
    pub top_p: f32,
    /// Penalty for repeating tokens already present.
    pub frequency_penalty: f32,
    /// Penalty for introducing tokens already present.
    pub presence_penalty: f32,
    /// Upper bound on generated tokens.
    pub max_tokens: i32,
    /// Optional stop sequence.
    pub stop: Option<String>,
    /// Whether the API should stream the response (always `true` here).
    pub stream: bool,
}
110
111impl ChatRequest {
112 pub async fn build_and_insert_incoming_message_to_db(
113 conn: &mut PgConnection,
114 chatbot_configuration_id: Uuid,
115 conversation_id: Uuid,
116 message: &str,
117 app_config: &ApplicationConfiguration,
118 ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
119 let index_name = Url::parse(&app_config.base_url)?
120 .host_str()
121 .expect("BASE_URL must have a host")
122 .replace(".", "-");
123
124 let configuration =
125 models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
126
127 let conversation_messages =
128 models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
129 .await?;
130
131 let new_order_number = conversation_messages
132 .iter()
133 .map(|m| m.order_number)
134 .max()
135 .unwrap_or(0)
136 + 1;
137
138 let new_message = models::chatbot_conversation_messages::insert(
139 conn,
140 ChatbotConversationMessage {
141 id: Uuid::new_v4(),
142 created_at: Utc::now(),
143 updated_at: Utc::now(),
144 deleted_at: None,
145 conversation_id,
146 message: Some(message.to_string()),
147 is_from_chatbot: false,
148 message_is_complete: true,
149 used_tokens: estimate_tokens(message),
150 order_number: new_order_number,
151 },
152 )
153 .await?;
154
155 let mut api_chat_messages: Vec<ApiChatMessage> =
156 conversation_messages.into_iter().map(Into::into).collect();
157
158 api_chat_messages.push(new_message.clone().into());
159
160 api_chat_messages.insert(
161 0,
162 ApiChatMessage {
163 role: "system".to_string(),
164 content: configuration.prompt.clone(),
165 },
166 );
167
168 let data_sources = if configuration.use_azure_search {
169 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
170 anyhow::anyhow!("Azure configuration is missing from the application configuration")
171 })?;
172
173 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
174 anyhow::anyhow!(
175 "Azure search configuration is missing from the Azure configuration"
176 )
177 })?;
178
179 let query_type = if configuration.use_semantic_reranking {
180 "vector_semantic_hybrid"
181 } else {
182 "vector_simple_hybrid"
183 };
184
185 vec![DataSource {
186 data_type: "azure_search".to_string(),
187 parameters: DataSourceParameters {
188 endpoint: search_config.search_endpoint.to_string(),
189 authentication: DataSourceParametersAuthentication {
190 auth_type: "api_key".to_string(),
191 key: search_config.search_api_key.clone(),
192 },
193 index_name,
194 query_type: query_type.to_string(),
195 semantic_configuration: "default".to_string(),
196 embedding_dependency: EmbeddingDependency {
197 dep_type: "deployment_name".to_string(),
198 deployment_name: search_config.vectorizer_deployment_id.clone(),
199 },
200 in_scope: false,
201 top_n_documents: 5,
202 strictness: 3,
203 filter: Some(
204 SearchFilter::eq("course_id", configuration.course_id.to_string())
205 .to_odata()?,
206 ),
207 fields_mapping: FieldsMapping {
208 content_fields_separator: ",".to_string(),
209 content_fields: vec!["chunk".to_string()],
210 filepath_field: "filepath".to_string(),
211 title_field: "title".to_string(),
212 url_field: "url".to_string(),
213 vector_fields: vec!["text_vector".to_string()],
214 },
215 },
216 }]
217 } else {
218 Vec::new()
219 };
220
221 let serialized_messages = serde_json::to_string(&api_chat_messages)?;
222 let request_estimated_tokens = estimate_tokens(&serialized_messages);
223
224 Ok((
225 Self {
226 messages: api_chat_messages,
227 data_sources,
228 temperature: configuration.temperature,
229 top_p: configuration.top_p,
230 frequency_penalty: configuration.frequency_penalty,
231 presence_penalty: configuration.presence_penalty,
232 max_tokens: configuration.response_max_tokens,
233 stop: None,
234 stream: true,
235 },
236 new_message,
237 request_estimated_tokens,
238 ))
239 }
240}
241
/// An external data source attached to the chat request ("on your data").
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    /// Source kind; serialized as `type` (e.g. `"azure_search"`).
    #[serde(rename = "type")]
    pub data_type: String,
    /// Source-specific connection and retrieval parameters.
    pub parameters: DataSourceParameters,
}
248
/// Connection and retrieval settings for an Azure search data source.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    /// Search service endpoint URL.
    pub endpoint: String,
    /// Credentials used to access the search service.
    pub authentication: DataSourceParametersAuthentication,
    /// Name of the search index to query.
    pub index_name: String,
    /// Retrieval mode (e.g. `"vector_simple_hybrid"`).
    pub query_type: String,
    /// Embedding deployment used for vectorizing the query.
    pub embedding_dependency: EmbeddingDependency,
    /// Whether responses are limited to in-scope documents.
    pub in_scope: bool,
    /// Number of documents to retrieve.
    pub top_n_documents: i32,
    /// Retrieval strictness level.
    pub strictness: i32,
    /// Optional OData filter; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    /// Mapping from index fields to the roles the API expects.
    pub fields_mapping: FieldsMapping,
    /// Semantic ranking configuration name.
    pub semantic_configuration: String,
}
264
/// Authentication material for the search data source.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    /// Authentication scheme; serialized as `type` (e.g. `"api_key"`).
    #[serde(rename = "type")]
    pub auth_type: String,
    /// The secret itself (API key).
    pub key: String,
}
271
/// Reference to the embedding model used to vectorize search queries.
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    /// Reference kind; serialized as `type` (e.g. `"deployment_name"`).
    #[serde(rename = "type")]
    pub dep_type: String,
    /// Name of the embedding model deployment.
    pub deployment_name: String,
}
278
/// Maps fields of the search index to the roles the completion API expects
/// (content, title, url, file path, and vector fields).
#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    /// Separator used when joining multiple content fields.
    pub content_fields_separator: String,
    /// Index fields that hold document content.
    pub content_fields: Vec<String>,
    /// Index field that holds the document's file path.
    pub filepath_field: String,
    /// Index field that holds the document title.
    pub title_field: String,
    /// Index field that holds the document URL.
    pub url_field: String,
    /// Index fields that hold embedding vectors.
    pub vector_fields: Vec<String>,
}
288
/// One newline-delimited JSON fragment forwarded to the caller of
/// `send_chat_request_and_parse_stream`: a piece of the chatbot's answer.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    /// The text fragment.
    pub text: String,
}
293
/// Wraps a response stream together with a [`RequestCancelledGuard`] so the
/// guard is dropped exactly when the caller drops the stream — triggering
/// cleanup if the stream never completed.
#[pin_project]
struct GuardedStream<S> {
    // Held only for its `Drop` behavior; never accessed while streaming.
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}
301
302impl<S> GuardedStream<S> {
303 fn new(guard: RequestCancelledGuard, stream: S) -> Self {
304 Self { guard, stream }
305 }
306}
307
308impl<S> Stream for GuardedStream<S>
309where
310 S: Stream<Item = anyhow::Result<Bytes>> + Send,
311{
312 type Item = S::Item;
313
314 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
315 let this = self.project();
316 this.stream.poll_next(cx)
317 }
318}
319
/// Cleanup guard for an in-flight chat request. When dropped before `done`
/// is set (i.e. the caller abandoned the stream), it persists whatever text
/// was received so far or deletes the placeholder response message.
struct RequestCancelledGuard {
    // Row id of the placeholder chatbot response message.
    response_message_id: Uuid,
    // Text fragments received so far; shared with the response stream.
    received_string: Arc<Mutex<Vec<String>>>,
    // Pool for acquiring a connection inside the spawned cleanup task.
    pool: PgPool,
    // Set to true by the stream once the final message has been persisted.
    done: Arc<AtomicBool>,
    // Token estimate of the request, added to the response's cost on save.
    request_estimated_tokens: i32,
}
327
328impl Drop for RequestCancelledGuard {
329 fn drop(&mut self) {
330 if self.done.load(atomic::Ordering::Relaxed) {
331 return;
332 }
333 warn!("Request was not cancelled. Cleaning up.");
334 let response_message_id = self.response_message_id;
335 let received_string = self.received_string.clone();
336 let pool = self.pool.clone();
337 let request_estimated_tokens = self.request_estimated_tokens;
338 tokio::spawn(async move {
339 info!("Verifying the received message has been handled");
340 let mut conn = pool.acquire().await.expect("Could not acquire connection");
341 let full_response_text = received_string.lock().await;
342 if full_response_text.is_empty() {
343 info!("No response received. Deleting the response message");
344 models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
345 .await
346 .expect("Could not delete response message");
347 return;
348 }
349 info!("Response received but not completed. Saving the text received so far.");
350 let full_response_as_string = full_response_text.join("");
351 let estimated_cost = estimate_tokens(&full_response_as_string);
352 info!(
353 "End of chatbot response stream. Estimated cost: {}. Response: {}",
354 estimated_cost, full_response_as_string
355 );
356
357 models::chatbot_conversation_messages::update(
359 &mut conn,
360 response_message_id,
361 &full_response_as_string,
362 true,
363 request_estimated_tokens + estimated_cost,
364 )
365 .await
366 .expect("Could not update response message");
367 });
368 }
369}
370
371pub async fn send_chat_request_and_parse_stream(
372 conn: &mut PgConnection,
373 pool: PgPool,
374 app_config: &ApplicationConfiguration,
375 chatbot_configuration_id: Uuid,
376 conversation_id: Uuid,
377 message: &str,
378) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
379 let (chat_request, new_message, request_estimated_tokens) =
380 ChatRequest::build_and_insert_incoming_message_to_db(
381 conn,
382 chatbot_configuration_id,
383 conversation_id,
384 message,
385 app_config,
386 )
387 .await?;
388
389 let full_response_text = Arc::new(Mutex::new(Vec::new()));
390 let done = Arc::new(AtomicBool::new(false));
391
392 let azure_config = app_config
393 .azure_configuration
394 .as_ref()
395 .ok_or_else(|| anyhow::anyhow!("Azure configuration not found"))?;
396
397 let chatbot_config = azure_config
398 .chatbot_config
399 .as_ref()
400 .ok_or_else(|| anyhow::anyhow!("Chatbot configuration not found"))?;
401
402 let api_key = chatbot_config.api_key.clone();
403 let mut url = chatbot_config.api_endpoint.clone();
404
405 url.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
407
408 let headers = build_llm_headers(&api_key)?;
409
410 let response_order_number = new_message.order_number + 1;
411
412 let response_message = models::chatbot_conversation_messages::insert(
413 conn,
414 ChatbotConversationMessage {
415 id: Uuid::new_v4(),
416 created_at: Utc::now(),
417 updated_at: Utc::now(),
418 deleted_at: None,
419 conversation_id,
420 message: None,
421 is_from_chatbot: true,
422 message_is_complete: false,
423 used_tokens: request_estimated_tokens,
424 order_number: response_order_number,
425 },
426 )
427 .await?;
428
429 let guard = RequestCancelledGuard {
431 response_message_id: response_message.id,
432 received_string: full_response_text.clone(),
433 pool: pool.clone(),
434 done: done.clone(),
435 request_estimated_tokens,
436 };
437
438 let request = REQWEST_CLIENT
439 .post(url)
440 .headers(headers)
441 .json(&chat_request)
442 .send();
443
444 let response = request.await?;
445
446 info!("Receiving chat response with {:?}", response.version());
447
448 if !response.status().is_success() {
449 let status = response.status();
450 let error_message = response.text().await?;
451 return Err(anyhow::anyhow!(
452 "Failed to send chat request. Status: {}. Error: {}",
453 status,
454 error_message
455 ));
456 }
457
458 let stream = response.bytes_stream().map_err(std::io::Error::other);
459 let reader = StreamReader::new(stream);
460 let mut lines = reader.lines();
461
462 let response_stream = async_stream::try_stream! {
463 while let Some(line) = lines.next_line().await? {
464 if !line.starts_with("data: ") {
465 continue;
466 }
467 let json_str = line.trim_start_matches("data: ");
468
469 let mut full_response_text = full_response_text.lock().await;
470 if json_str.trim() == "[DONE]" {
471 let full_response_as_string = full_response_text.join("");
472 let estimated_cost = estimate_tokens(&full_response_as_string);
473 info!(
474 "End of chatbot response stream. Estimated cost: {}. Response: {}",
475 estimated_cost, full_response_as_string
476 );
477 done.store(true, atomic::Ordering::Relaxed);
478 let mut conn = pool.acquire().await?;
479 models::chatbot_conversation_messages::update(
480 &mut conn,
481 response_message.id,
482 &full_response_as_string,
483 true,
484 request_estimated_tokens + estimated_cost,
485 ).await?;
486 break;
487 }
488 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
489 anyhow::anyhow!("Failed to parse response chunk: {}", e)
490 })?;
491
492 for choice in &response_chunk.choices {
493 if let Some(delta) = &choice.delta {
494 if let Some(content) = &delta.content {
495 full_response_text.push(content.clone());
496 let response = ChatResponse { text: content.clone() };
497 let response_as_string = serde_json::to_string(&response)?;
498 yield Bytes::from(response_as_string);
499 yield Bytes::from("\n");
500 }
501 if let Some(context) = &delta.context {
502 let citation_message_id = response_message.id;
503 let mut conn = pool.acquire().await?;
504 for (idx, cit) in context.citations.iter().enumerate() {
505 let content = if cit.content.len() < 255 {cit.content.clone()} else {cit.content[0..255].to_string()};
506 let document_url = cit.url.clone();
507 let mut page_path = PathBuf::from(&cit.filepath);
508 page_path.set_extension("");
509 let page_id_str = page_path.file_name();
510 let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
511 let course_material_chapter_number = if let Some(id) = page_id {
512 let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
513 chapter.map(|c| c.chapter_number)
514 } else {
515 None
516 };
517
518 models::chatbot_conversation_messages_citations::insert(
519 &mut conn, ChatbotConversationMessageCitation {
520 id: Uuid::new_v4(),
521 created_at: Utc::now(),
522 updated_at: Utc::now(),
523 deleted_at: None,
524 conversation_message_id: citation_message_id,
525 conversation_id,
526 course_material_chapter_number,
527 title: cit.title.clone(),
528 content,
529 document_url,
530 citation_number: (idx+1) as i32,
531 }
532 ).await?;
533 }
534 }
535
536 }
537 }
538 }
539
540 if !done.load(atomic::Ordering::Relaxed) {
541 Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
542 }
543 };
544
545 let guarded_stream = GuardedStream::new(guard, response_stream);
548
549 Ok(Box::pin(guarded_stream))
551}