//! headless_lms_chatbot/azure_chatbot.rs
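//!
//! Streams chat completions from an Azure OpenAI deployment, optionally grounding the
//! responses with Azure AI Search data sources, and persists the conversation messages
//! and their citations as the response stream is consumed.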

use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use bytes::Bytes;
use chrono::Utc;
use futures::{Stream, TryStreamExt};
use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::ApplicationConfiguration;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_util::io::StreamReader;
use url::Url;

use headless_lms_utils::url_encoding::url_decode;

use crate::llm_utils::{APIMessage, MessageRole, estimate_tokens, make_streaming_llm_request};
use crate::prelude::*;
use crate::search_filter::SearchFilter;

/// Separator placed between the `chunk_context` and `chunk` content fields of Azure AI Search
/// results; citation contents are split on it before being stored.
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

/// Azure content filter annotations attached to a streamed choice.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

/// Incremental payload of a streamed choice: new response text and, when Azure Search is in
/// use, the citations backing it.
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

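/// One parsed chunk of the chat completions server-sent-event stream. On the wire each chunk
/// arrives on its own `data: `-prefixed line, roughly like this (illustrative example, most
/// field values abbreviated):
///
/// ```text
/// data: {"id":"...","object":"chat.completion.chunk","created":0,"model":"...","choices":[{"index":0,"delta":{"content":"Hello"}}]}
/// ```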
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

impl From<ChatbotConversationMessage> for APIMessage {
    fn from(message: ChatbotConversationMessage) -> Self {
        APIMessage {
            role: if message.is_from_chatbot {
                MessageRole::Assistant
            } else {
                MessageRole::User
            },
            content: message.message.unwrap_or_default(),
        }
    }
}

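// Reasoning ("thinking") deployments accept a different parameter set than classic chat
// models: `max_completion_tokens`, `verbosity`, and `reasoning_effort` instead of the
// sampling knobs (`temperature`, `top_p`, and the penalty parameters) below.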
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

/// Serialized untagged and flattened into [`LLMRequest`], so the active variant's fields
/// appear directly in the top-level request body.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
    None,
}

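/// Body of the chat completions request. `params` is flattened, so the serialized JSON looks
/// roughly like this (illustrative, non-thinking variant; `data_sources` is omitted when
/// empty, and the exact role strings depend on how `MessageRole` serializes):
///
/// ```json
/// {
///   "messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}],
///   "max_tokens": 500,
///   "temperature": 0.7,
///   "top_p": 1.0,
///   "frequency_penalty": 0.0,
///   "presence_penalty": 0.0,
///   "stop": null
/// }
/// ```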
#[derive(Serialize, Deserialize, Debug)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    pub stop: Option<String>,
}

impl LLMRequest {
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
        // The search index name is derived from the host of BASE_URL, e.g. "example.com"
        // becomes "example-com".
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("BASE_URL must have a host"))?
            .replace('.', "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                is_from_chatbot: false,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<APIMessage> =
            conversation_messages.into_iter().map(Into::into).collect();

        api_chat_messages.push(new_message.clone().into());

        // The system prompt always goes first.
        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                content: configuration.prompt.clone(),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    // Restrict retrieval to material from this chatbot's course.
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                stop: None,
            },
            new_message,
            request_estimated_tokens,
        ))
    }
}

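// The structs below mirror the `data_sources` payload of Azure OpenAI's "On Your Data"
// chat completions extension (Azure AI Search variant).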
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

/// A single piece of response text, relayed to the client as one newline-terminated JSON object.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

/// Wraps the response stream together with its cancellation guard, ensuring the guard is
/// dropped exactly when the stream is dropped rather than when the creating function returns.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegate to the inner stream; the guard is only held for its Drop side effect.
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

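// Dropped when the response stream is dropped. If the stream did not run to completion
// (`done` is still false), the Drop impl spawns a task that either deletes the placeholder
// response message (nothing was received) or persists the partial text received so far.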
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Request did not complete; the stream was dropped early. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying that the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message.");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            // Store the partial response along with the total estimated token usage.
            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

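/// Sends the user's `message` to the configured LLM deployment and returns a stream of
/// newline-delimited JSON chunks (serialized [`ChatResponse`] values) for relaying to the
/// client. The returned stream owns a [`RequestCancelledGuard`], so dropping it early
/// persists whatever partial response had been received.
///
/// A minimal usage sketch (hypothetical caller context; `pool`, `app_config`, the ids, and
/// `forward_to_client` are assumed to come from the surrounding application):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let mut conn = pool.acquire().await?;
/// let mut stream = send_chat_request_and_parse_stream(
///     &mut conn,
///     pool.clone(),
///     &app_config,
///     chatbot_configuration_id,
///     conversation_id,
///     "How do I reset my exercise?",
/// )
/// .await?;
/// while let Some(chunk) = stream.try_next().await? {
///     // Each chunk is one JSON-serialized `ChatResponse` followed by a newline.
///     forward_to_client(chunk).await?;
/// }
/// ```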
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let response_order_number = new_message.order_number + 1;

    // Insert a placeholder message for the chatbot's response; it is filled in (or deleted)
    // as the stream progresses.
    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    // Instantiate the guard before creating the stream so cleanup runs even if an error
    // occurs below.
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let response =
        make_streaming_llm_request(chat_request, &model.deployment_name, app_config).await?;

    info!("Receiving chat response over {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    // Read the server-sent-event stream line by line.
    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

    let response_stream = async_stream::try_stream! {
        while let Some(line) = lines.next_line().await? {
            // SSE payload lines look like `data: {...}`; everything else is ignored.
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");

            let mut full_response_text = full_response_text.lock().await;
            // `data: [DONE]` marks the end of the stream; persist the complete response.
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let citation_message_id = response_message.id;
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate overly long citation contents, taking care not to cut
                            // in the middle of a multi-byte UTF-8 character.
                            let content = if cit.content.len() <= 255 {
                                cit.content.clone()
                            } else {
                                let mut end = 255;
                                while !cit.content.is_char_boundary(end) {
                                    end -= 1;
                                }
                                cit.content[..end].to_string()
                            };
                            // The content fields are `chunk_context` and `chunk`, joined by the
                            // separator; only the chunk itself is stored.
                            let cleaned_content = match content.split_once(CONTENT_FIELD_SEPARATOR) {
                                Some((_chunk_context, chunk)) => chunk.to_string(),
                                None => {
                                    error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                                    String::new()
                                }
                            };

                            // The title and URL come from Azure Blob Storage metadata, which was
                            // URL-encoded (percent-encoded) because Azure Blob Storage metadata
                            // values must be ASCII-only. Decode them back to their original UTF-8
                            // strings before storing them in the database.
                            let decoded_title = url_decode(&cit.title)?;
                            let decoded_url = url_decode(&cit.url)?;

                            // The file name of the citation's filepath (minus its extension) is a
                            // page id; use it to look up the chapter number, if any.
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id_str = page_path.file_name();
                            let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: citation_message_id,
                                    conversation_id,
                                    course_material_chapter_number,
                                    title: decoded_title,
                                    content: cleaned_content,
                                    document_url: decoded_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            ).await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard
    // into the stream and ensures it is dropped when the stream is dropped, so cleanup happens
    // only when the stream is dropped and not when this function returns.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    // Box and pin the GuardedStream to satisfy the Unpin requirement.
    Ok(Box::pin(guarded_stream))
}