headless_lms_chatbot/
azure_chatbot.rs

1use std::path::PathBuf;
2use std::pin::Pin;
3use std::sync::{
4    Arc,
5    atomic::{self, AtomicBool},
6};
7use std::task::{Context, Poll};
8
9use anyhow::{Error, Ok, anyhow};
10use bytes::Bytes;
11use chrono::Utc;
12use futures::stream::{BoxStream, Peekable};
13use futures::{Stream, StreamExt, TryStreamExt};
14use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
15use headless_lms_models::chatbot_conversation_messages::{
16    self, ChatbotConversationMessage, MessageRole,
17};
18use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
19use headless_lms_utils::ApplicationConfiguration;
20use pin_project::pin_project;
21use serde::{Deserialize, Serialize};
22use sqlx::PgPool;
23use tokio::{io::AsyncBufReadExt, sync::Mutex};
24use tokio_stream::wrappers::LinesStream;
25use tokio_util::io::StreamReader;
26use tracing::trace;
27use url::Url;
28
29use crate::chatbot_error::ChatbotResult;
30use crate::chatbot_tools::{
31    AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
32};
33use crate::llm_utils::{
34    APIMessage, APIMessageKind, APIMessageText, APIMessageToolCall, APIMessageToolResponse,
35    APITool, APIToolCall, estimate_tokens, make_streaming_llm_request,
36};
37use headless_lms_utils::url_encoding::url_decode;
38
39use crate::prelude::*;
40use crate::search_filter::SearchFilter;
41
42const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
43
44/// Context about the user and course for a chatbot interaction.
45/// Passed to tool implementations so they can access user-specific data.
46pub struct ChatbotUserContext {
47    pub user_id: Uuid,
48    pub course_id: Uuid,
49    pub course_name: String,
50}
51
52#[derive(Deserialize, Serialize, Debug)]
53pub struct ContentFilterResults {
54    pub hate: Option<ContentFilter>,
55    pub self_harm: Option<ContentFilter>,
56    pub sexual: Option<ContentFilter>,
57    pub violence: Option<ContentFilter>,
58}
59
60#[derive(Deserialize, Serialize, Debug)]
61pub struct ContentFilter {
62    pub filtered: bool,
63    pub severity: String,
64}
65
66/// Data in a streamed response chunk
67#[derive(Deserialize, Serialize, Debug)]
68pub struct Choice {
69    pub content_filter_results: Option<ContentFilterResults>,
70    pub delta: Option<Delta>,
71    pub finish_reason: Option<String>,
72    pub index: i32,
73}
74
75/// Content in a streamed response chunk Choice
76#[derive(Deserialize, Serialize, Debug)]
77pub struct Delta {
78    pub content: Option<String>,
79    pub context: Option<DeltaContext>,
80    pub tool_calls: Option<Vec<ToolCallInDelta>>,
81}
82
83#[derive(Deserialize, Serialize, Debug)]
84pub struct DeltaContext {
85    pub citations: Vec<Citation>,
86}
87
88/// A streamed tool call from Azure
89#[derive(Deserialize, Serialize, Debug)]
90pub struct ToolCallInDelta {
91    pub id: Option<String>,
92    pub function: DeltaTool,
93    #[serde(rename = "type")]
94    pub tool_type: Option<ToolCallType>,
95}
96
97/// Streamed tool call content
98#[derive(Deserialize, Serialize, Debug, Clone)]
99pub struct DeltaTool {
100    #[serde(default)]
101    pub arguments: String,
102    pub name: Option<String>,
103}
104
105#[derive(Serialize, Deserialize, Debug, Clone)]
106#[serde(rename_all = "snake_case")]
107pub enum ToolCallType {
108    Function,
109}
110
111#[derive(Deserialize, Serialize, Debug)]
112pub struct Citation {
113    pub content: String,
114    pub title: String,
115    pub url: String,
116    pub filepath: String,
117}
118
119/// Response received from LLM API
120#[derive(Deserialize, Serialize, Debug)]
121pub struct ResponseChunk {
122    pub choices: Vec<Choice>,
123    pub created: u64,
124    pub id: String,
125    pub model: String,
126    pub object: String,
127    pub system_fingerprint: Option<String>,
128}
129
130#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
131#[serde(rename_all = "snake_case")]
132pub enum LLMToolChoice {
133    Auto,
134}
135
136#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
137pub struct ThinkingParams {
138    pub max_completion_tokens: Option<i32>,
139    pub verbosity: Option<VerbosityLevel>,
140    pub reasoning_effort: Option<ReasoningEffortLevel>,
141    #[serde(skip_serializing_if = "Vec::is_empty")]
142    pub tools: Vec<AzureLLMToolDefinition>,
143    pub tool_choice: Option<LLMToolChoice>,
144}
145
146#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
147pub struct NonThinkingParams {
148    pub max_tokens: Option<i32>,
149    pub temperature: Option<f32>,
150    pub top_p: Option<f32>,
151    pub frequency_penalty: Option<f32>,
152    pub presence_penalty: Option<f32>,
153}
154
155#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
156#[serde(untagged)]
157pub enum LLMRequestParams {
158    Thinking(ThinkingParams),
159    NonThinking(NonThinkingParams),
160}
161
162#[derive(Serialize, Deserialize, Debug, Clone)]
163pub struct LLMRequest {
164    pub messages: Vec<APIMessage>,
165    #[serde(skip_serializing_if = "Vec::is_empty")]
166    pub data_sources: Vec<DataSource>,
167    #[serde(flatten)]
168    pub params: LLMRequestParams,
169    pub stop: Option<String>,
170}
171
172impl LLMRequest {
173    pub async fn build_and_insert_incoming_message_to_db(
174        conn: &mut PgConnection,
175        chatbot_configuration_id: Uuid,
176        conversation_id: Uuid,
177        message: &str,
178        app_config: &ApplicationConfiguration,
179    ) -> anyhow::Result<(Self, i32, i32)> {
180        let index_name = Url::parse(&app_config.base_url)?
181            .host_str()
182            .expect("BASE_URL must have a host")
183            .replace(".", "-");
184
185        let configuration =
186            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
187
188        let conversation_messages =
189            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
190                .await?;
191
192        let new_order_number = conversation_messages
193            .iter()
194            .map(|m| m.order_number)
195            .max()
196            .unwrap_or(0)
197            + 1;
198
199        let new_message = models::chatbot_conversation_messages::insert(
200            conn,
201            ChatbotConversationMessage {
202                id: Uuid::new_v4(),
203                created_at: Utc::now(),
204                updated_at: Utc::now(),
205                deleted_at: None,
206                conversation_id,
207                message: Some(message.to_string()),
208                message_role: MessageRole::User,
209                message_is_complete: true,
210                used_tokens: estimate_tokens(message),
211                order_number: new_order_number,
212                tool_call_fields: vec![],
213                tool_output: None,
214            },
215        )
216        .await?;
217
218        let mut api_chat_messages: Vec<APIMessage> = conversation_messages
219            .into_iter()
220            .map(APIMessage::try_from)
221            .collect::<ChatbotResult<Vec<_>>>()?;
222
223        // put new user message into the messages list
224        api_chat_messages.push(new_message.clone().try_into()?);
225
226        api_chat_messages.insert(
227            0,
228            APIMessage {
229                role: MessageRole::System,
230                fields: APIMessageKind::Text(APIMessageText {
231                    content: configuration.prompt.clone(),
232                }),
233            },
234        );
235
236        let data_sources = if configuration.use_azure_search {
237            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
238                anyhow::anyhow!("Azure configuration is missing from the application configuration")
239            })?;
240
241            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
242                anyhow::anyhow!(
243                    "Azure search configuration is missing from the Azure configuration"
244                )
245            })?;
246
247            let query_type = if configuration.use_semantic_reranking {
248                "vector_semantic_hybrid"
249            } else {
250                "vector_simple_hybrid"
251            };
252
253            // if there are data sources, the message history might contain incompatible
254            // tool call and result messages. Remove tool call messages and turn tool
255            // response messages into role=assistant messages with the tool output as
256            // text content.
257            api_chat_messages = api_chat_messages
258                .into_iter()
259                .filter(|m| !matches!(m.fields, APIMessageKind::ToolCall(_)))
260                .map(|m| match m.fields {
261                    APIMessageKind::ToolResponse(r) => APIMessage {
262                        role: MessageRole::Assistant,
263                        fields: APIMessageKind::Text(APIMessageText { content: r.content }),
264                    },
265                    _ => m,
266                })
267                .collect();
268
269            vec![DataSource {
270                data_type: "azure_search".to_string(),
271                parameters: DataSourceParameters {
272                    endpoint: search_config.search_endpoint.to_string(),
273                    authentication: DataSourceParametersAuthentication {
274                        auth_type: "api_key".to_string(),
275                        key: search_config.search_api_key.clone(),
276                    },
277                    index_name,
278                    query_type: query_type.to_string(),
279                    semantic_configuration: "default".to_string(),
280                    embedding_dependency: EmbeddingDependency {
281                        dep_type: "deployment_name".to_string(),
282                        deployment_name: search_config.vectorizer_deployment_id.clone(),
283                    },
284                    in_scope: false,
285                    top_n_documents: 15,
286                    strictness: 3,
287                    filter: Some(
288                        SearchFilter::eq("course_id", configuration.course_id.to_string())
289                            .to_odata()?,
290                    ),
291                    fields_mapping: FieldsMapping {
292                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
293                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
294                        filepath_field: "filepath".to_string(),
295                        title_field: "title".to_string(),
296                        url_field: "url".to_string(),
297                        vector_fields: vec!["text_vector".to_string()],
298                    },
299                },
300            }]
301        } else {
302            Vec::new()
303        };
304
305        let tools = if configuration.use_tools {
306            get_chatbot_tool_definitions()
307        } else {
308            Vec::new()
309        };
310
311        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
312        let request_estimated_tokens = estimate_tokens(&serialized_messages);
313
314        let params = if configuration.thinking_model {
315            LLMRequestParams::Thinking(ThinkingParams {
316                max_completion_tokens: Some(configuration.max_completion_tokens),
317                reasoning_effort: Some(configuration.reasoning_effort),
318                verbosity: Some(configuration.verbosity),
319                tools,
320                tool_choice: if configuration.use_tools {
321                    Some(LLMToolChoice::Auto)
322                } else {
323                    None
324                },
325            })
326        } else {
327            LLMRequestParams::NonThinking(NonThinkingParams {
328                max_tokens: Some(configuration.response_max_tokens),
329                temperature: Some(configuration.temperature),
330                top_p: Some(configuration.top_p),
331                frequency_penalty: Some(configuration.frequency_penalty),
332                presence_penalty: Some(configuration.presence_penalty),
333            })
334        };
335
336        Ok((
337            Self {
338                messages: api_chat_messages,
339                data_sources,
340                params,
341                stop: None,
342            },
343            new_message.order_number,
344            request_estimated_tokens,
345        ))
346    }
347
348    pub async fn update_messages_to_db(
349        mut self,
350        conn: &mut PgConnection,
351        new_msgs: Vec<APIMessage>,
352        conversation_id: Uuid,
353        mut order_number: i32,
354    ) -> anyhow::Result<(Self, i32)> {
355        for m in new_msgs {
356            let converted_msg = m.to_chatbot_conversation_message(conversation_id, order_number)?;
357            chatbot_conversation_messages::insert(conn, converted_msg).await?;
358            self.messages.push(m);
359            order_number += 1;
360        }
361        Ok((self, order_number))
362    }
363}
364
365#[derive(Serialize, Deserialize, Debug, Clone)]
366pub struct DataSource {
367    #[serde(rename = "type")]
368    pub data_type: String,
369    pub parameters: DataSourceParameters,
370}
371
372#[derive(Serialize, Deserialize, Debug, Clone)]
373pub struct DataSourceParameters {
374    pub endpoint: String,
375    pub authentication: DataSourceParametersAuthentication,
376    pub index_name: String,
377    pub query_type: String,
378    pub embedding_dependency: EmbeddingDependency,
379    pub in_scope: bool,
380    pub top_n_documents: i32,
381    pub strictness: i32,
382    #[serde(skip_serializing_if = "Option::is_none")]
383    pub filter: Option<String>,
384    pub fields_mapping: FieldsMapping,
385    pub semantic_configuration: String,
386}
387
388#[derive(Serialize, Deserialize, Debug, Clone)]
389pub struct DataSourceParametersAuthentication {
390    #[serde(rename = "type")]
391    pub auth_type: String,
392    pub key: String,
393}
394
395#[derive(Serialize, Deserialize, Debug, Clone)]
396pub struct EmbeddingDependency {
397    #[serde(rename = "type")]
398    pub dep_type: String,
399    pub deployment_name: String,
400}
401
402#[derive(Serialize, Deserialize, Debug, Clone)]
403pub struct FieldsMapping {
404    pub content_fields_separator: String,
405    pub content_fields: Vec<String>,
406    pub filepath_field: String,
407    pub title_field: String,
408    pub url_field: String,
409    pub vector_fields: Vec<String>,
410}
411
412#[derive(Serialize, Deserialize, Debug)]
413pub struct ChatResponse {
414    pub text: String,
415}
416
417/// Custom stream that encapsulates both the response stream and the cancellation guard. Makes sure that the guard is always dropped when the stream is dropped.
418#[pin_project]
419struct GuardedStream<S> {
420    guard: RequestCancelledGuard,
421    #[pin]
422    stream: S,
423}
424
425impl<S> GuardedStream<S> {
426    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
427        Self { guard, stream }
428    }
429}
430
431impl<S> Stream for GuardedStream<S>
432where
433    S: Stream<Item = anyhow::Result<Bytes>> + Send,
434{
435    type Item = S::Item;
436
437    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
438        let this = self.project();
439        this.stream.poll_next(cx)
440    }
441}
442
443/// A LinesStream that is peekable. Needed to determine which type of LLM response is
444/// being received.
445type PeekableLinesStream<'a> = Pin<
446    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
447>;
448pub enum ResponseStreamType<'a> {
449    Toolcall(PeekableLinesStream<'a>),
450    TextResponse(PeekableLinesStream<'a>),
451}
452
453struct RequestCancelledGuard {
454    response_message_id: Uuid,
455    received_string: Arc<Mutex<Vec<String>>>,
456    pool: PgPool,
457    done: Arc<AtomicBool>,
458    request_estimated_tokens: i32,
459}
460
461impl Drop for RequestCancelledGuard {
462    fn drop(&mut self) {
463        if self.done.load(atomic::Ordering::Relaxed) {
464            return;
465        }
466        warn!("Request was not cancelled. Cleaning up.");
467        let response_message_id = self.response_message_id;
468        let received_string = self.received_string.clone();
469        let pool = self.pool.clone();
470        let request_estimated_tokens = self.request_estimated_tokens;
471        tokio::spawn(async move {
472            info!("Verifying the received message has been handled");
473            let mut conn = pool.acquire().await.expect("Could not acquire connection");
474            let full_response_text = received_string.lock().await;
475            if full_response_text.is_empty() {
476                info!("No response received. Deleting the response message");
477                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
478                    .await
479                    .expect("Could not delete response message");
480                return;
481            }
482            info!("Response received but not completed. Saving the text received so far.");
483            let full_response_as_string = full_response_text.join("");
484            let estimated_cost = estimate_tokens(&full_response_as_string);
485            info!(
486                "End of chatbot response stream. Estimated cost: {}. Response: {}",
487                estimated_cost, full_response_as_string
488            );
489
490            // Update with request_estimated_tokens + estimated_cost
491            models::chatbot_conversation_messages::update(
492                &mut conn,
493                response_message_id,
494                &full_response_as_string,
495                true,
496                request_estimated_tokens + estimated_cost,
497            )
498            .await
499            .expect("Could not update response message");
500        });
501    }
502}
503
504pub async fn make_request_and_stream<'a>(
505    chat_request: LLMRequest,
506    model_name: &str,
507    app_config: &ApplicationConfiguration,
508) -> anyhow::Result<ResponseStreamType<'a>> {
509    let response = make_streaming_llm_request(chat_request, model_name, app_config).await?;
510
511    trace!("Receiving chat response with {:?}", response.version());
512
513    if !response.status().is_success() {
514        let status = response.status();
515        let error_message = response.text().await?;
516        return Err(anyhow::anyhow!(
517            "Failed to send chat request. Status: {}. Error: {}",
518            status,
519            error_message
520        ));
521    }
522
523    let stream = response
524        .bytes_stream()
525        .map_err(std::io::Error::other)
526        .boxed();
527    let reader = StreamReader::new(stream);
528    let lines = reader.lines();
529    let lines_stream = LinesStream::new(lines);
530    let peekable_lines_stream = lines_stream.peekable();
531    let mut pinned_lines = Box::pin(peekable_lines_stream);
532
533    loop {
534        let line_res = pinned_lines.as_mut().peek().await;
535        match line_res {
536            None => {
537                break;
538            }
539            Some(Err(e)) => {
540                return Err(anyhow!(
541                    "There was an error streaming response from Azure: {}",
542                    e
543                ));
544            }
545            Some(Result::Ok(line)) => {
546                if !line.starts_with("data: ") {
547                    pinned_lines.next().await;
548                    continue;
549                }
550                let json_str = line.trim_start_matches("data: ");
551                let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
552                    .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;
553                for choice in &response_chunk.choices {
554                    if let Some(d) = &choice.delta {
555                        if d.content.is_some() || d.context.is_some() {
556                            return Ok(ResponseStreamType::TextResponse(pinned_lines));
557                        } else if let Some(_calls) = &d.tool_calls {
558                            return Ok(ResponseStreamType::Toolcall(pinned_lines));
559                        } else if d.content.is_none() {
560                            pinned_lines.next().await;
561                            continue;
562                        }
563                    }
564                }
565                pinned_lines.next().await;
566            }
567        }
568    }
569    Err(Error::msg(
570        "The response received from Azure had an unexpected shape and couldn't be parsed"
571            .to_string(),
572    ))
573}
574
575/// Streams and parses a LLM response from Azure that contains function calls.
576/// Calls the functions and returns a Vec of function results to be sent to Azure.
577pub async fn parse_tool<'a>(
578    conn: &mut PgConnection,
579    mut lines: PeekableLinesStream<'a>,
580    user_context: &ChatbotUserContext,
581) -> anyhow::Result<Vec<APIMessage>> {
582    let mut function_name_id_args: Vec<(String, String, String)> = vec![];
583    let mut currently_streamed_function_name_id: Option<(String, String)> = None;
584    let mut currently_streamed_function_args = vec![];
585    let mut messages = vec![];
586
587    trace!("Parsing tool calls...");
588
589    while let Some(val) = lines.next().await {
590        let line = val?;
591        if !line.to_owned().starts_with("data: ") {
592            continue;
593        }
594        let json_str = line.trim_start_matches("data: ");
595        if json_str.trim() == "[DONE]" {
596            // the stream ended
597            if function_name_id_args.is_empty() {
598                return Err(anyhow::anyhow!(
599                    "The LLM response was supposed to contain function calls, but no function calls were found"
600                ));
601            }
602            let mut assistant_tool_calls = Vec::new();
603            let mut tool_result_msgs = Vec::new();
604
605            for (name, id, args) in function_name_id_args.iter() {
606                let tool = get_chatbot_tool(conn, name, args, user_context).await?;
607
608                assistant_tool_calls.push(APIToolCall {
609                    function: APITool {
610                        name: name.to_owned(),
611                        arguments: serde_json::to_string(tool.get_arguments())?,
612                    },
613                    id: id.to_owned(),
614                    tool_type: ToolCallType::Function,
615                });
616                tool_result_msgs.push(APIMessage {
617                    role: MessageRole::Tool,
618                    fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
619                        content: tool.get_tool_output(),
620                        name: name.to_owned(),
621                        tool_call_id: id.to_owned(),
622                    }),
623                })
624            }
625            // insert all tool calls made by the bot as one message into the messages
626            messages.push(APIMessage {
627                role: MessageRole::Assistant,
628                fields: APIMessageKind::ToolCall(APIMessageToolCall {
629                    tool_calls: assistant_tool_calls,
630                }),
631            });
632            // add tool call output messages to the messages
633            messages.extend(tool_result_msgs);
634            break;
635        }
636        let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
637            .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {} {}", e, json_str))?;
638        for choice in &response_chunk.choices {
639            if Some("tool_calls".to_string()) == choice.finish_reason {
640                // the stream is finished for now because of "tool_calls"
641                // so if we're still streaming some func call, finish it and store it
642                if let Some((name, id)) = &currently_streamed_function_name_id {
643                    // we have streamed some func call and args so let's join the args
644                    // and save the call
645                    let fn_args = currently_streamed_function_args.join("");
646                    function_name_id_args.push((
647                        name.to_owned(),
648                        id.to_owned(),
649                        fn_args.to_owned(),
650                    ));
651                    currently_streamed_function_args.clear();
652                    currently_streamed_function_name_id = None;
653                    // after this chunk, there is assumed to be a chunk that just has
654                    // content "[DONE]", we'll process the func calls at that point.
655                }
656            }
657            if let Some(delta) = &choice.delta
658                && let Some(tool_calls) = &delta.tool_calls
659            {
660                // this chunk has tool call data
661                for call in tool_calls {
662                    if let (Some(name), Some(id)) = (&call.function.name, &call.id) {
663                        // if this chunk has a tool name and id, then a new call is made.
664                        // if there is previously streamed args, then their streaming is
665                        // complete, let's join and save them before processing this new
666                        // call.
667                        if let Some((name_prev, id_prev)) = currently_streamed_function_name_id {
668                            let fn_args = currently_streamed_function_args.join("");
669                            function_name_id_args.push((
670                                name_prev.to_owned(),
671                                id_prev.to_owned(),
672                                fn_args,
673                            ));
674                            currently_streamed_function_args.clear();
675                        }
676                        // set the tool name and id from this chunk to currently_streamed
677                        // and save any arguments to currently_streamed_function_args
678                        // until the stream is complete or a new call is made.
679                        currently_streamed_function_name_id =
680                            Some((name.to_owned(), id.to_owned()));
681                    };
682                    // always save any streamed function args. it can be an empty string
683                    // but that's ok.
684                    currently_streamed_function_args.push(call.function.arguments.clone());
685                }
686            }
687        }
688    }
689    Ok(messages)
690}
691
692/// Streams and parses a LLM response from Azure that contains a text response.
693pub async fn parse_and_stream_to_user<'a>(
694    conn: &mut PgConnection,
695    mut lines: PeekableLinesStream<'a>,
696    conversation_id: Uuid,
697    response_order_number: i32,
698    pool: PgPool,
699    request_estimated_tokens: i32,
700) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
701    // insert the to-be-streamed bot text response to db
702    let response_message = models::chatbot_conversation_messages::insert(
703        conn,
704        ChatbotConversationMessage {
705            id: Uuid::new_v4(),
706            created_at: Utc::now(),
707            updated_at: Utc::now(),
708            deleted_at: None,
709            conversation_id,
710            message: Some("".to_string()),
711            message_role: MessageRole::Assistant,
712            message_is_complete: false,
713            used_tokens: request_estimated_tokens,
714            order_number: response_order_number,
715            tool_call_fields: vec![],
716            tool_output: None,
717        },
718    )
719    .await?;
720
721    let done = Arc::new(AtomicBool::new(false));
722    let full_response_text = Arc::new(Mutex::new(Vec::new()));
723    // Instantiate the guard before creating the stream.
724    let guard = RequestCancelledGuard {
725        response_message_id: response_message.id,
726        received_string: full_response_text.clone(),
727        pool: pool.clone(),
728        done: done.clone(),
729        request_estimated_tokens,
730    };
731
732    trace!("Parsing stream to user...");
733
734    let response_stream = async_stream::try_stream! {
735        while let Some(val) = lines.next().await {
736            let line = val?;
737            if !line.starts_with("data: ") {
738                continue;
739            }
740            let mut full_response_text = full_response_text.lock().await;
741            let json_str = line.trim_start_matches("data: ");
742            if json_str.trim() == "[DONE]" {
743                let full_response_as_string = full_response_text.join("");
744                let estimated_cost = estimate_tokens(&full_response_as_string);
745                trace!(
746                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
747                    estimated_cost, full_response_as_string
748                );
749                done.store(true, atomic::Ordering::Relaxed);
750                let mut conn = pool.acquire().await?;
751                models::chatbot_conversation_messages::update(
752                    &mut conn,
753                    response_message.id,
754                    &full_response_as_string,
755                    true,
756                    request_estimated_tokens + estimated_cost,
757                ).await?;
758                break;
759            }
760            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
761                anyhow::anyhow!("Failed to parse response chunk: {}", e)
762            })?;
763
764            for choice in &response_chunk.choices {
765                if let Some(delta) = &choice.delta {
766                    if let Some(content) = &delta.content {
767                        full_response_text.push(content.clone());
768                        let response = ChatResponse { text: content.clone() };
769                        let response_as_string = serde_json::to_string(&response)?;
770                        yield Bytes::from(response_as_string);
771                        yield Bytes::from("\n");
772                    }
773                    if let Some(context) = &delta.context {
774                        let mut conn = pool.acquire().await?;
775                        for (idx, cit) in context.citations.iter().enumerate() {
776                            let content = if cit.content.len() < 255 {cit.content.clone()} else {cit.content[0..255].to_string()};
777                            let split = content.split_once(CONTENT_FIELD_SEPARATOR);
778                            if split.is_none() {
779                                error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
780                            }
781                            let cleaned_content: String = split.unwrap_or(("","")).1.to_string();
782
783                            // The title and URL come from Azure Blob Storage metadata, which was URL-encoded
784                            // (percent-encoded) because Azure Blob Storage metadata values must be ASCII-only.
785                            // We decode them back to their original UTF-8 strings before storing in the database.
786                            let decoded_title = url_decode(&cit.title)?;
787                            let decoded_url = url_decode(&cit.url)?;
788
789                            let mut page_path = PathBuf::from(&cit.filepath);
790                            page_path.set_extension("");
791                            let page_id_str = page_path.file_name();
792                            let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
793                            let course_material_chapter_number = if let Some(id) = page_id {
794                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
795                                chapter.map(|c| c.chapter_number)
796                            } else {
797                                None
798                            };
799
800                            models::chatbot_conversation_messages_citations::insert(
801                                &mut conn, ChatbotConversationMessageCitation {
802                                    id: Uuid::new_v4(),
803                                    created_at: Utc::now(),
804                                    updated_at: Utc::now(),
805                                    deleted_at: None,
806                                    conversation_message_id: response_message.id,
807                                    conversation_id: response_message.conversation_id,
808                                    course_material_chapter_number,
809                                    title: decoded_title,
810                                    content: cleaned_content,
811                                    document_url: decoded_url,
812                                    citation_number: (idx+1) as i32,
813                                }
814                            ).await?;
815                        }
816                    }
817                }
818            }
819        }
820
821        if !done.load(atomic::Ordering::Relaxed) {
822            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
823        }
824    };
825
826    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard into the stream and ensures that it is dropped when the stream is dropped.
827    // This way we do cleanup only when the stream is dropped and not when this function returns.
828    let guarded_stream = GuardedStream::new(guard, response_stream);
829
830    // Box and pin the GuardedStream to satisfy the Unpin requirement
831    Ok(Box::pin(guarded_stream))
832}
833
834pub async fn send_chat_request_and_parse_stream(
835    conn: &mut PgConnection,
836    pool: PgPool,
837    app_config: &ApplicationConfiguration,
838    chatbot_configuration_id: Uuid,
839    conversation_id: Uuid,
840    message: &str,
841    user_context: ChatbotUserContext,
842) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
843    let (mut chat_request, new_message_order_number, request_estimated_tokens) =
844        LLMRequest::build_and_insert_incoming_message_to_db(
845            conn,
846            chatbot_configuration_id,
847            conversation_id,
848            message,
849            app_config,
850        )
851        .await?;
852
853    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
854        conn,
855        chatbot_configuration_id,
856    )
857    .await?;
858
859    let mut next_message_order_number = new_message_order_number + 1;
860    let mut max_iterations_left = 15;
861
862    loop {
863        max_iterations_left -= 1;
864        if max_iterations_left == 0 {
865            error!("Maximum tool call iterations exceeded");
866            return Err(anyhow::anyhow!(
867                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
868            ));
869        }
870
871        let response_type =
872            make_request_and_stream(chat_request.clone(), &model.deployment_name, app_config)
873                .await?;
874
875        let new_tool_msgs = match response_type {
876            ResponseStreamType::Toolcall(stream) => parse_tool(conn, stream, &user_context).await?,
877            ResponseStreamType::TextResponse(stream) => {
878                return parse_and_stream_to_user(
879                    conn,
880                    stream,
881                    conversation_id,
882                    next_message_order_number,
883                    pool,
884                    request_estimated_tokens,
885                )
886                .await;
887            }
888        };
889        (chat_request, next_message_order_number) = chat_request
890            .update_messages_to_db(
891                conn,
892                new_tool_msgs,
893                conversation_id,
894                next_message_order_number,
895            )
896            .await?;
897    }
898}