// headless_lms_chatbot/azure_chatbot.rs

1use std::pin::Pin;
2use std::sync::{
3    Arc,
4    atomic::{self, AtomicBool},
5};
6use std::task::{Context, Poll};
7
8use bytes::Bytes;
9use chrono::Utc;
10use futures::{Stream, TryStreamExt};
11use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
12use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
13use pin_project::pin_project;
14use serde::{Deserialize, Serialize};
15use sqlx::PgPool;
16use tokio::{io::AsyncBufReadExt, sync::Mutex};
17use tokio_util::io::StreamReader;
18use url::Url;
19
20use crate::llm_utils::{LLM_API_VERSION, build_llm_headers, estimate_tokens};
21use crate::prelude::*;
22use crate::search_filter::SearchFilter;
23
/// Per-category content-filter verdicts that Azure attaches to a streamed choice.
/// Each category is optional because not every chunk carries filter results.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}
31
/// A single content-filter verdict: whether the content was filtered and the
/// severity label reported by the API.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}
37
/// One choice inside a streamed response chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    // Incremental piece of the generated message; absent on bookkeeping chunks.
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}
45
/// Incremental content fragment carried by a streamed choice.
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
}
50
/// One parsed chunk (`data: {...}` line) of the streaming chat completion response.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}
60
/// The format accepted by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiChatMessage {
    // "system", "assistant", or "user" (see the From impl and prompt insertion below).
    pub role: String,
    pub content: String,
}
67
68impl From<ChatbotConversationMessage> for ApiChatMessage {
69    fn from(message: ChatbotConversationMessage) -> Self {
70        ApiChatMessage {
71            role: if message.is_from_chatbot {
72                "assistant".to_string()
73            } else {
74                "user".to_string()
75            },
76            content: message.message.unwrap_or_default(),
77        }
78    }
79}
80
/// The request body sent to the chat completions endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatRequest {
    pub messages: Vec<ApiChatMessage>,
    // Azure "on your data" sources; omitted from the serialized JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    pub temperature: f32,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    pub max_tokens: i32,
    pub stop: Option<String>,
    pub stream: bool,
}
94
95impl ChatRequest {
96    pub async fn build_and_insert_incoming_message_to_db(
97        conn: &mut PgConnection,
98        chatbot_configuration_id: Uuid,
99        conversation_id: Uuid,
100        message: &str,
101        app_config: &ApplicationConfiguration,
102    ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
103        let index_name = Url::parse(&app_config.base_url)?
104            .host_str()
105            .expect("BASE_URL must have a host")
106            .replace(".", "-");
107
108        let configuration =
109            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
110
111        let conversation_messages =
112            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
113                .await?;
114
115        let new_order_number = conversation_messages
116            .iter()
117            .map(|m| m.order_number)
118            .max()
119            .unwrap_or(0)
120            + 1;
121
122        let new_message = models::chatbot_conversation_messages::insert(
123            conn,
124            ChatbotConversationMessage {
125                id: Uuid::new_v4(),
126                created_at: Utc::now(),
127                updated_at: Utc::now(),
128                deleted_at: None,
129                conversation_id,
130                message: Some(message.to_string()),
131                is_from_chatbot: false,
132                message_is_complete: true,
133                used_tokens: estimate_tokens(message),
134                order_number: new_order_number,
135            },
136        )
137        .await?;
138
139        let mut api_chat_messages: Vec<ApiChatMessage> =
140            conversation_messages.into_iter().map(Into::into).collect();
141
142        api_chat_messages.push(new_message.clone().into());
143
144        api_chat_messages.insert(
145            0,
146            ApiChatMessage {
147                role: "system".to_string(),
148                content: configuration.prompt.clone(),
149            },
150        );
151
152        let data_sources = if configuration.use_azure_search {
153            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
154                anyhow::anyhow!("Azure configuration is missing from the application configuration")
155            })?;
156
157            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
158                anyhow::anyhow!(
159                    "Azure search configuration is missing from the Azure configuration"
160                )
161            })?;
162
163            let query_type = if configuration.use_semantic_reranking {
164                "vector_semantic_hybrid"
165            } else {
166                "vector_simple_hybrid"
167            };
168
169            vec![DataSource {
170                data_type: "azure_search".to_string(),
171                parameters: DataSourceParameters {
172                    endpoint: search_config.search_endpoint.to_string(),
173                    authentication: DataSourceParametersAuthentication {
174                        auth_type: "api_key".to_string(),
175                        key: search_config.search_api_key.clone(),
176                    },
177                    index_name,
178                    query_type: query_type.to_string(),
179                    semantic_configuration: "default".to_string(),
180                    embedding_dependency: EmbeddingDependency {
181                        dep_type: "deployment_name".to_string(),
182                        deployment_name: search_config.vectorizer_deployment_id.clone(),
183                    },
184                    in_scope: false,
185                    top_n_documents: 5,
186                    strictness: 3,
187                    filter: Some(
188                        SearchFilter::eq("course_id", configuration.course_id.to_string())
189                            .to_odata()?,
190                    ),
191                    fields_mapping: FieldsMapping {
192                        content_fields_separator: ",".to_string(),
193                        content_fields: vec!["chunk".to_string()],
194                        filepath_field: "filepath".to_string(),
195                        title_field: "title".to_string(),
196                        url_field: "url".to_string(),
197                        vector_fields: vec!["text_vector".to_string()],
198                    },
199                },
200            }]
201        } else {
202            Vec::new()
203        };
204
205        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
206        let request_estimated_tokens = estimate_tokens(&serialized_messages);
207
208        Ok((
209            Self {
210                messages: api_chat_messages,
211                data_sources,
212                temperature: configuration.temperature,
213                top_p: configuration.top_p,
214                frequency_penalty: configuration.frequency_penalty,
215                presence_penalty: configuration.presence_penalty,
216                max_tokens: configuration.response_max_tokens,
217                stop: None,
218                stream: true,
219            },
220            new_message,
221            request_estimated_tokens,
222        ))
223    }
224}
225
/// An Azure "on your data" data source entry in the chat request.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    // Serialized as "type"; set to "azure_search" in this module.
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}
232
/// Parameters for an Azure search data source.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    // OData filter string; omitted from the JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}
248
/// Authentication section of the data source parameters.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    // Serialized as "type"; set to "api_key" in this module.
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}
255
/// Reference to the embedding deployment used for vectorized search.
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    // Serialized as "type"; set to "deployment_name" in this module.
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}
262
/// Maps the search index's field names to the roles the chat API expects.
#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}
272
/// One newline-delimited JSON object emitted to the client per content delta.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}
277
/// Custom stream that encapsulates both the response stream and the cancellation guard. Makes sure that the guard is always dropped when the stream is dropped.
#[pin_project]
struct GuardedStream<S> {
    // Dropped together with the stream; its Drop impl runs cleanup if the stream
    // did not complete normally.
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}
285
286impl<S> GuardedStream<S> {
287    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
288        Self { guard, stream }
289    }
290}
291
impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    // Delegates polling to the inner stream via pin projection; the guard is not
    // touched here — it only matters when the whole GuardedStream is dropped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}
303
/// Everything needed to clean up a chatbot response message if the client
/// disconnects before the stream completes. See the `Drop` impl for the logic.
struct RequestCancelledGuard {
    response_message_id: Uuid,
    // Content fragments received from the LLM so far; shared with the response stream.
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    // Set to true by the stream once the upstream response completed normally.
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}
311
312impl Drop for RequestCancelledGuard {
313    fn drop(&mut self) {
314        if self.done.load(atomic::Ordering::Relaxed) {
315            return;
316        }
317        warn!("Request was not cancelled. Cleaning up.");
318        let response_message_id = self.response_message_id;
319        let received_string = self.received_string.clone();
320        let pool = self.pool.clone();
321        let request_estimated_tokens = self.request_estimated_tokens;
322        tokio::spawn(async move {
323            info!("Verifying the received message has been handled");
324            let mut conn = pool.acquire().await.expect("Could not acquire connection");
325            let full_response_text = received_string.lock().await;
326            if full_response_text.is_empty() {
327                info!("No response received. Deleting the response message");
328                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
329                    .await
330                    .expect("Could not delete response message");
331                return;
332            }
333            info!("Response received but not completed. Saving the text received so far.");
334            let full_response_as_string = full_response_text.join("");
335            let estimated_cost = estimate_tokens(&full_response_as_string);
336            info!(
337                "End of chatbot response stream. Estimated cost: {}. Response: {}",
338                estimated_cost, full_response_as_string
339            );
340
341            // Update with request_estimated_tokens + estimated_cost
342            models::chatbot_conversation_messages::update(
343                &mut conn,
344                response_message_id,
345                &full_response_as_string,
346                true,
347                request_estimated_tokens + estimated_cost,
348            )
349            .await
350            .expect("Could not update response message");
351        });
352    }
353}
354
/// Sends `message` to the Azure chat completions endpoint and returns a stream of
/// newline-delimited JSON objects (`ChatResponse`) containing the reply as it arrives.
///
/// Side effects: inserts the incoming user message and an (initially incomplete)
/// chatbot response row into the database. The response row is finalized when the
/// upstream stream reports `[DONE]`; if the returned stream is dropped earlier,
/// `RequestCancelledGuard` cleans up the partial message in a background task.
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (chat_request, new_message, request_estimated_tokens) =
        ChatRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    // Shared state between the response stream and the cancellation guard.
    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let done = Arc::new(AtomicBool::new(false));

    let azure_config = app_config
        .azure_configuration
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("Azure configuration not found"))?;

    let chatbot_config = azure_config
        .chatbot_config
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("Chatbot configuration not found"))?;

    let api_key = chatbot_config.api_key.clone();
    let mut url = chatbot_config.api_endpoint.clone();

    // Always set the API version so that we actually use the API that the code is written for
    url.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));

    let headers = build_llm_headers(&api_key)?;

    let response_order_number = new_message.order_number + 1;

    // Placeholder row for the chatbot's reply; marked complete only once the
    // upstream signals `[DONE]`.
    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: None,
            is_from_chatbot: true,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
        },
    )
    .await?;

    // Instantiate the guard before creating the stream.
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    let request = REQWEST_CLIENT
        .post(url)
        .headers(headers)
        .json(&chat_request)
        .send();

    let response = request.await?;

    info!("Receiving chat response with {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    // Read the server-sent-events style body line by line.
    let stream = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(stream);
    let mut lines = reader.lines();

    let response_stream = async_stream::try_stream! {
        while let Some(line) = lines.next_line().await? {
            // Only `data: ...` lines carry payloads; skip everything else.
            if !line.starts_with("data: ") {
                continue;
            }
            let json_str = line.trim_start_matches("data: ");
            let mut full_response_text = full_response_text.lock().await;
            if json_str.trim() == "[DONE]" {
                // Upstream finished: persist the full response and token estimate,
                // and flag `done` so the guard's Drop becomes a no-op.
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                info!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;
            // Forward each content delta to the client as newline-delimited JSON,
            // while also accumulating it for the final database update.
            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }

                }
            }
        }

        // Reached only if the loop ended without seeing `[DONE]`.
        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard into the stream and ensures that it is dropped when the stream is dropped.
    // This way we do cleanup only when the stream is dropped and not when this function returns.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    // Box and pin the GuardedStream to satisfy the Unpin requirement
    Ok(Box::pin(guarded_stream))
}