headless_lms_chatbot/
azure_chatbot.rs

use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use bytes::Bytes;
use chrono::Utc;
use futures::{Stream, TryStreamExt};
use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_util::io::StreamReader;
use url::Url;

use crate::llm_utils::{LLM_API_VERSION, build_llm_headers, estimate_tokens};
use crate::prelude::*;
use crate::search_filter::SearchFilter;

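/// Content moderation results for a choice, one entry per category.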
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

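/// Whether a single moderation category was triggered, and at what severity.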
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

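/// One choice in a streamed completion chunk; the text itself arrives in `delta`.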
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

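/// An incremental piece of the response: new content and/or citation context.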
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
}

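/// Extra context attached to a delta, such as citations from the search data source.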
#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

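/// A source document citation returned when the request uses an Azure Search data source.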
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

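/// One parsed chunk of the streaming chat completion response.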
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

/// A chat message in the format accepted by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiChatMessage {
    pub role: String,
    pub content: String,
}

impl From<ChatbotConversationMessage> for ApiChatMessage {
    fn from(message: ChatbotConversationMessage) -> Self {
        ApiChatMessage {
            role: if message.is_from_chatbot {
                "assistant".to_string()
            } else {
                "user".to_string()
            },
            content: message.message.unwrap_or_default(),
        }
    }
}

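/// The request body sent to the chat completions API.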
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatRequest {
    pub messages: Vec<ApiChatMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    pub temperature: f32,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    pub max_tokens: i32,
    pub stop: Option<String>,
    pub stream: bool,
}

impl ChatRequest {
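    /// Builds a chat request from the chatbot configuration and the conversation so far, and
    /// inserts the incoming user message into the database. Returns the request, the inserted
    /// message, and the estimated token count of the serialized request.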
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
        // The search index is named after the host of the base URL, with dots replaced by dashes.
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("BASE_URL must have a host"))?
            .replace('.', "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                is_from_chatbot: false,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<ApiChatMessage> =
            conversation_messages.into_iter().map(Into::into).collect();

        api_chat_messages.push(new_message.clone().into());

        // The system prompt always goes first.
        api_chat_messages.insert(
            0,
            ApiChatMessage {
                role: "system".to_string(),
                content: configuration.prompt.clone(),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 5,
                    strictness: 3,
                    // Only return results from the course this chatbot belongs to.
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: ",".to_string(),
                        content_fields: vec!["chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                temperature: configuration.temperature,
                top_p: configuration.top_p,
                frequency_penalty: configuration.frequency_penalty,
                presence_penalty: configuration.presence_penalty,
                max_tokens: configuration.response_max_tokens,
                stop: None,
                stream: true,
            },
            new_message,
            request_estimated_tokens,
        ))
    }
}

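/// A data source the model can ground its answers in; here always an Azure Search index.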
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

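/// Endpoint, authentication, index and query settings for an Azure Search data source.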
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

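/// How to authenticate to the data source, e.g. with an API key.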
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

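/// The embedding deployment used to vectorize queries for hybrid search.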
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

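/// Maps fields in the search index to the roles the API expects (content, title, filepath, url, vectors).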
#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

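/// A chunk of chatbot output in the form it is streamed to the client, one JSON object per line.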
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

/// Custom stream that encapsulates both the response stream and the cancellation guard.
/// Makes sure that the guard is always dropped when the stream is dropped.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

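/// Cleans up if the chat request ends before completion. When dropped without `done` being
/// set, it deletes the response message if nothing was received, or saves the partial text
/// received so far.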
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Request did not complete. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            // Charge for both the request and the partial response received so far.
            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

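/// Sends a chat completion request for `message` in the given conversation and returns the
/// response as a stream of newline-delimited JSON chunks ([`ChatResponse`]). The incoming
/// message and a placeholder response message are stored in the database; the response
/// message is completed (or cleaned up) when the stream finishes or is dropped.
///
/// A minimal consumption sketch (hypothetical caller, assuming a connection, a pool, and the
/// application configuration are already available):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let mut stream = send_chat_request_and_parse_stream(
///     &mut conn,
///     pool.clone(),
///     &app_config,
///     chatbot_configuration_id,
///     conversation_id,
///     "Hello!",
/// )
/// .await?;
///
/// // Each item is a JSON-serialized `ChatResponse` followed by a newline.
/// while let Some(bytes) = stream.try_next().await? {
///     print!("{}", String::from_utf8_lossy(&bytes));
/// }
/// ```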
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        ChatRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let azure_config = app_config
        .azure_configuration
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("Azure configuration not found"))?;
    let chatbot_config = azure_config
        .chatbot_config
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("Chatbot configuration not found"))?;

    let api_key = chatbot_config.api_key.clone();
    let mut url = chatbot_config.api_endpoint.clone();

    // Always set the API version so that we use the API version this code was written against.
    url.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));

    let headers = build_llm_headers(&api_key)?;

    let response_order_number = new_message.order_number + 1;

    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    // Instantiate the guard before creating the stream.
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let request = REQWEST_CLIENT
        .post(url)
        .headers(headers)
        .json(&chat_request)
        .send();

    let response = request.await?;

    info!("Receiving chat response with {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

    let response_stream = async_stream::try_stream! {
        while let Some(line) = lines.next_line().await? {
            // The response is a server-sent event stream; payload lines start with "data: ".
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");

            let mut full_response_text = full_response_text.lock().await;
            // "[DONE]" marks the end of the stream; persist the completed response.
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let citation_message_id = response_message.id;
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate to at most 255 characters. Slicing the string by bytes
                            // here could panic in the middle of a multi-byte UTF-8 character.
                            let content: String = cit.content.chars().take(255).collect();
                            let document_url = cit.url.clone();
                            // The file stem of the citation filepath is expected to be a page id.
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id_str = page_path.file_name();
                            let page_id = page_id_str.and_then(|id_str| {
                                Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok()
                            });
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter =
                                    models::chapters::get_chapter_by_page_id(&mut conn, id)
                                        .await
                                        .ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: citation_message_id,
                                    conversation_id,
                                    course_material_chapter_number,
                                    title: cit.title.clone(),
                                    content,
                                    document_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            )
                            .await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard
    // into the stream and ensures that it is dropped when the stream is dropped. This way we do
    // cleanup only when the stream is dropped and not when this function returns.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    // Box and pin the GuardedStream to satisfy the Unpin requirement.
    Ok(Box::pin(guarded_stream))
}