headless_lms_chatbot/
azure_chatbot.rs

use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use bytes::Bytes;
use chrono::Utc;
use futures::{Stream, TryStreamExt};
use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_util::io::StreamReader;
use url::Url;

use crate::llm_utils::{LLM_API_VERSION, build_llm_headers, estimate_tokens};
use crate::prelude::*;
use crate::search_filter::SearchFilter;

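/// Separator the search index is asked to place between the configured content fields
/// ("chunk_context" and "chunk"); used below to strip the context part out of citations.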
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

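/// A citation attached to a streamed delta when an Azure search data source is in use.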
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

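/// One parsed chunk of the streaming chat completions response.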
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

/// The format accepted by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiChatMessage {
    pub role: String,
    pub content: String,
}

impl From<ChatbotConversationMessage> for ApiChatMessage {
    fn from(message: ChatbotConversationMessage) -> Self {
        ApiChatMessage {
            role: if message.is_from_chatbot {
                "assistant".to_string()
            } else {
                "user".to_string()
            },
            content: message.message.unwrap_or_default(),
        }
    }
}

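/// The request body sent to the Azure OpenAI chat completions API.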
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatRequest {
    pub messages: Vec<ApiChatMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    pub temperature: f32,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    pub max_tokens: i32,
    pub stop: Option<String>,
    pub stream: bool,
}

impl ChatRequest {
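    /// Builds the chat request for a conversation: persists the incoming user message, collects
    /// the prior messages (prefixed with the configured system prompt), and attaches an Azure
    /// search data source when the chatbot configuration enables it. Returns the request, the
    /// inserted message, and an estimated token count for the serialized request.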
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("BASE_URL must have a host"))?
            .replace(".", "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                is_from_chatbot: false,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<ApiChatMessage> =
            conversation_messages.into_iter().map(Into::into).collect();

        api_chat_messages.push(new_message.clone().into());

        api_chat_messages.insert(
            0,
            ApiChatMessage {
                role: "system".to_string(),
                content: configuration.prompt.clone(),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                temperature: configuration.temperature,
                top_p: configuration.top_p,
                frequency_penalty: configuration.frequency_penalty,
                presence_penalty: configuration.presence_penalty,
                max_tokens: configuration.response_max_tokens,
                stop: None,
                stream: true,
            },
            new_message,
            request_estimated_tokens,
        ))
    }
}

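/// An Azure search data source attached to the chat request so answers can be grounded in
/// indexed course material.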
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

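/// A single chunk of chatbot output, streamed to the client as newline-delimited JSON.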
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

/// Custom stream that encapsulates both the response stream and the cancellation guard.
/// Makes sure that the guard is always dropped when the stream is dropped.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

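/// Dropped when the stream ends or the client disconnects. If the response was not marked as
/// done, spawns a cleanup task that either deletes the empty response message or saves the
/// text received so far.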
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Request did not complete. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying that the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message.");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            // Save the partial response with the combined token estimate for the request and
            // the response received so far.
            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

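/// Sends a chat completion request to the Azure chatbot endpoint and returns a stream of
/// newline-delimited JSON [`ChatResponse`] chunks. The user's message and the (initially
/// empty) response message are persisted, and partial responses are cleaned up if the client
/// disconnects mid-stream.
///
/// A minimal usage sketch (assuming an open connection, a pool, and a loaded
/// [`ApplicationConfiguration`]; not compiled here):
///
/// ```ignore
/// let stream = send_chat_request_and_parse_stream(
///     &mut conn,
///     pool.clone(),
///     &app_config,
///     chatbot_configuration_id,
///     conversation_id,
///     "What is covered in chapter 1?",
/// )
/// .await?;
/// // Each item is a JSON-encoded `ChatResponse` followed by a newline.
/// ```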
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        ChatRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let azure_config = app_config
        .azure_configuration
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("Azure configuration not found"))?;

    let chatbot_config = azure_config
        .chatbot_config
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("Chatbot configuration not found"))?;

    let api_key = chatbot_config.api_key.clone();
    let mut url = chatbot_config.api_endpoint.clone();

    // Always set the API version so that we actually use the API that the code is written for
    url.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));

    let headers = build_llm_headers(&api_key)?;

    let response_order_number = new_message.order_number + 1;

    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    // Instantiate the guard before creating the stream.
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let request = REQWEST_CLIENT
        .post(url)
        .headers(headers)
        .json(&chat_request)
        .send();

    let response = request.await?;

    info!("Receiving chat response with {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

    let response_stream = async_stream::try_stream! {
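        // The upstream response is a server-sent-events stream: payload lines start with
        // "data: " and the stream ends with a "data: [DONE]" sentinel.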
        while let Some(line) = lines.next_line().await? {
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");

            let mut full_response_text = full_response_text.lock().await;
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let citation_message_id = response_message.id;
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate to at most 255 bytes without splitting a UTF-8 character,
                            // which would make the byte slice panic.
                            let content = if cit.content.len() <= 255 {
                                cit.content.clone()
                            } else {
                                let mut end = 255;
                                while !cit.content.is_char_boundary(end) {
                                    end -= 1;
                                }
                                cit.content[..end].to_string()
                            };
                            let split = content.split_once(CONTENT_FIELD_SEPARATOR);
                            if split.is_none() {
                                error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                            }
                            // Keep only the part after the separator, i.e. the "chunk" field.
                            let cleaned_content: String = split.unwrap_or(("", "")).1.to_string();

                            let document_url = cit.url.clone();
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id_str = page_path.file_name();
                            let page_id = page_id_str
                                .and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter =
                                    models::chapters::get_chapter_by_page_id(&mut conn, id)
                                        .await
                                        .ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: citation_message_id,
                                    conversation_id,
                                    course_material_chapter_number,
                                    title: cit.title.clone(),
                                    content: cleaned_content,
                                    document_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            ).await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard
    // into the stream and ensures that it is dropped when the stream is dropped. This way we do
    // the cleanup only when the stream is dropped, not when this function returns.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    // Box and pin the GuardedStream so it can be returned as a pinned trait object.
    Ok(Box::pin(guarded_stream))
}