headless_lms_chatbot/azure_chatbot.rs

use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use bytes::Bytes;
use chrono::Utc;
use futures::{Stream, TryStreamExt};
use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::ApplicationConfiguration;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_util::io::StreamReader;
use url::Url;

use crate::llm_utils::{APIMessage, MessageRole, estimate_tokens, make_streaming_llm_request};
use crate::prelude::*;
use crate::search_filter::SearchFilter;

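/// Separator Azure Search uses to join the mapped content fields (here
/// `chunk_context` and `chunk`) into a single citation content string; the
/// citation parsing below splits on it.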
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

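/// Per-category content moderation verdicts that Azure attaches to each
/// streamed choice.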
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

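/// Converts a stored conversation message into the role/content pair the
/// chat-completions API expects.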
impl From<ChatbotConversationMessage> for APIMessage {
    fn from(message: ChatbotConversationMessage) -> Self {
        APIMessage {
            role: if message.is_from_chatbot {
                MessageRole::Assistant
            } else {
                MessageRole::User
            },
            content: message.message.unwrap_or_default(),
        }
    }
}

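/// Request parameters accepted by reasoning ("thinking") models, which take
/// `max_completion_tokens`, `verbosity`, and `reasoning_effort` in place of
/// the classic sampling parameters in `NonThinkingParams`.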
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

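/// The parameter set sent in the request body. `#[serde(untagged)]` means the
/// chosen variant's fields are serialized directly, with no variant-name
/// wrapper. A sketch with illustrative values:
///
/// ```ignore
/// let params = LLMRequestParams::NonThinking(NonThinkingParams {
///     max_tokens: Some(500),
///     temperature: Some(0.7),
///     top_p: Some(1.0),
///     frequency_penalty: Some(0.0),
///     presence_penalty: Some(0.0),
/// });
/// // serde_json::to_string(&params) yields the fields directly, e.g.
/// // {"max_tokens":500,"temperature":0.7,"top_p":1.0,
/// //  "frequency_penalty":0.0,"presence_penalty":0.0}
/// ```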
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
    None,
}

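/// Request body for the chat-completions endpoint. `data_sources` enables
/// Azure Search retrieval and is omitted when empty; `params` is flattened
/// into the top level of the serialized body.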
#[derive(Serialize, Deserialize, Debug)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    pub stop: Option<String>,
}

impl LLMRequest {
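    /// Builds the outgoing LLM request for a new user message and persists that
    /// message to the database. Returns the request, the inserted message row,
    /// and a token estimate for the serialized message history.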
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
        // Derive the search index name from the host of the base URL, e.g.
        // "courses.example.com" -> "courses-example-com". Returning an error
        // (rather than panicking) keeps a misconfigured BASE_URL from taking
        // the process down.
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("BASE_URL must have a host"))?
            .replace('.', "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                is_from_chatbot: false,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<APIMessage> =
            conversation_messages.into_iter().map(Into::into).collect();

        api_chat_messages.push(new_message.clone().into());

        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                content: configuration.prompt.clone(),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    // Only retrieve documents belonging to this course.
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                stop: None,
            },
            new_message,
            request_estimated_tokens,
        ))
    }
}

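/// One entry of the request's `data_sources` array; `type` is always
/// `"azure_search"` here.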
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

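/// A single chunk of chatbot output as sent to our client; the response stream
/// emits one JSON-encoded `ChatResponse` per line.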
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

/// Custom stream that encapsulates both the response stream and the
/// cancellation guard, ensuring the guard is always dropped together with the
/// stream.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegate to the inner stream; the guard is only held for its Drop impl.
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

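/// Cleanup handler for a dropped response stream. If the stream is dropped
/// before `done` is set, the placeholder response message is either deleted
/// (nothing received) or updated with the partial text received so far.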
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Request was cancelled before the response completed. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        // Drop cannot be async, so the cleanup work is spawned onto the runtime.
        tokio::spawn(async move {
            info!("Verifying the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            // Store the partial response; used tokens cover both the request and
            // the response received so far.
            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

397
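/// Sends the chat request for `message` and returns a stream of
/// newline-delimited, JSON-encoded `ChatResponse` chunks. Along the way it
/// persists the user message, an (initially incomplete) response message, and
/// any citations the model returns.
///
/// A minimal usage sketch; the connection, pool, configuration, and ids are
/// assumed to come from the caller:
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let mut stream = send_chat_request_and_parse_stream(
///     &mut conn,
///     pool.clone(),
///     &app_config,
///     chatbot_configuration_id,
///     conversation_id,
///     "Hello!",
/// )
/// .await?;
/// // Each item is one JSON-encoded `ChatResponse` followed by a newline.
/// while let Some(bytes) = stream.try_next().await? {
///     print!("{}", String::from_utf8_lossy(&bytes));
/// }
/// ```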
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let response_order_number = new_message.order_number + 1;

    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    // Instantiate the guard before creating the stream.
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let response =
        make_streaming_llm_request(chat_request, &model.deployment_name, app_config).await?;

    info!("Receiving chat response, HTTP version: {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

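    // The upstream response is a server-sent-events stream: every payload line
    // starts with "data: ", and a final "data: [DONE]" line marks the end.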
    let response_stream = async_stream::try_stream! {
        while let Some(line) = lines.next_line().await? {
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");

            let mut full_response_text = full_response_text.lock().await;
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let citation_message_id = response_message.id;
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate on a character boundary: byte slicing like
                            // `&cit.content[0..255]` can panic inside a multi-byte
                            // UTF-8 sequence.
                            let content: String = cit.content.chars().take(255).collect();
                            // The content field is "<chunk_context>,|||,<chunk>"; keep
                            // only the part after the separator.
                            let cleaned_content = match content.split_once(CONTENT_FIELD_SEPARATOR) {
                                Some((_, chunk)) => chunk.to_string(),
                                None => {
                                    error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                                    String::new()
                                }
                            };

                            let document_url = cit.url.clone();
                            // The filepath is expected to look like "<page_id>.<extension>".
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id = page_path
                                .file_name()
                                .and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: citation_message_id,
                                    conversation_id,
                                    course_material_chapter_number,
                                    title: cit.title.clone(),
                                    content: cleaned_content,
                                    document_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            ).await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Encapsulate the stream and the guard within GuardedStream. This moves the
    // request guard into the stream and ensures it is dropped when the stream is
    // dropped, so cleanup runs when the stream is dropped rather than when this
    // function returns.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    // Box and pin the stream to match the `Pin<Box<dyn Stream>>` return type.
    Ok(Box::pin(guarded_stream))
}