headless_lms_chatbot/azure_chatbot.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::sync::{
5    Arc,
6    atomic::{self, AtomicBool},
7};
8use std::task::{Context, Poll};
9
10use anyhow::{Error, Ok, anyhow};
11use bytes::Bytes;
12use chrono::Utc;
13use futures::stream::{BoxStream, Peekable};
14use futures::{Stream, StreamExt, TryStreamExt};
15use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
16use headless_lms_models::chatbot_conversation_messages::{
17    self, ChatbotConversationMessage, MessageRole,
18};
19use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
20use headless_lms_utils::ApplicationConfiguration;
21use pin_project::pin_project;
22use serde::{Deserialize, Serialize};
23use sqlx::PgPool;
24use tokio::{io::AsyncBufReadExt, sync::Mutex};
25use tokio_stream::wrappers::LinesStream;
26use tokio_util::io::StreamReader;
27use tracing::trace;
28use url::Url;
29
30use crate::chatbot_error::ChatbotResult;
31use crate::chatbot_tools::{
32    AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
33};
34use crate::llm_utils::{
35    APIMessage, APIMessageKind, APIMessageText, APIMessageToolCall, APIMessageToolResponse,
36    APITool, APIToolCall, estimate_tokens, make_streaming_llm_request,
37};
38use headless_lms_utils::url_encoding::url_decode;
39
40use crate::prelude::*;
41use crate::search_filter::SearchFilter;
42
/// Separator Azure uses to join multiple citation content fields into one string
/// (configured via `FieldsMapping::content_fields_separator`); also used when
/// splitting the chunk context from the chunk text in streamed citations.
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
44
/// Context about the user and course for a chatbot interaction.
/// Passed to tool implementations so they can access user-specific data.
pub struct ChatbotUserContext {
    /// Id of the user having the conversation.
    pub user_id: Uuid,
    /// Id of the course the conversation belongs to.
    pub course_id: Uuid,
    /// Human-readable name of the course.
    pub course_name: String,
}
52
/// Per-category content-moderation results attached to a streamed response chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

/// Result of a single content-moderation category check.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    /// Whether the content was flagged by this category.
    pub filtered: bool,
    /// Severity of the category as reported by the API.
    pub severity: String,
}
66
/// Data in a streamed response chunk
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    /// Content-moderation results for this choice, when present.
    pub content_filter_results: Option<ContentFilterResults>,
    /// The incremental payload of this chunk.
    pub delta: Option<Delta>,
    /// Why the stream finished; set to "tool_calls" when the model is done
    /// emitting tool calls.
    pub finish_reason: Option<String>,
    /// Index of this choice within the chunk.
    pub index: i32,
}

/// Content in a streamed response chunk Choice
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    /// A fragment of the text response, when the model answers with text.
    pub content: Option<String>,
    /// Extra context such as citations, when Azure search is used.
    pub context: Option<DeltaContext>,
    /// Streamed tool call fragments, when the model is calling tools.
    pub tool_calls: Option<Vec<ToolCallInDelta>>,
}

/// Citation data attached to a streamed response chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}
88
/// A streamed tool call from Azure
#[derive(Deserialize, Serialize, Debug)]
pub struct ToolCallInDelta {
    /// Id of the call; the parser treats a chunk carrying both an id and a
    /// function name as the start of a new call.
    pub id: Option<String>,
    /// Streamed function name / argument data.
    pub function: DeltaTool,
    #[serde(rename = "type")]
    pub tool_type: Option<ToolCallType>,
}

/// Streamed tool call content
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DeltaTool {
    /// Fragment of the JSON argument string; may be empty (hence the default).
    #[serde(default)]
    pub arguments: String,
    /// Function name; present in the chunk that starts a new call.
    pub name: Option<String>,
}

/// Kind of tool being called; serialized in snake_case (e.g. "function").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
    Function,
}
111
/// A citation returned by Azure search alongside a text response.
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    /// Cited content; contains the chunk context and the chunk text joined
    /// with `CONTENT_FIELD_SEPARATOR`.
    pub content: String,
    /// Title of the cited document (URL-encoded in blob storage metadata).
    pub title: String,
    /// URL of the cited document (URL-encoded in blob storage metadata).
    pub url: String,
    /// Path of the cited document within the search index.
    pub filepath: String,
}
119
/// Response received from LLM API
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    /// Creation time of the chunk as reported by the API.
    pub created: u64,
    pub id: String,
    /// Name of the model that produced the chunk.
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

/// How the model may choose tools; serialized in snake_case (e.g. "auto").
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolChoice {
    /// The model decides on its own whether to call tools.
    Auto,
}
136
/// Request parameters used with thinking (reasoning) models.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
    /// Tool definitions offered to the model; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<AzureLLMToolDefinition>,
    pub tool_choice: Option<LLMToolChoice>,
}

/// Request parameters used with non-thinking models.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

/// Model-type-specific parameters of an LLM request. Untagged so the chosen
/// variant's fields serialize directly (no wrapper key) in the request body.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
}
162
/// JSON type names used in structured-output schemas; serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JSONType {
    JsonSchema,
    Object,
    Array,
    String,
}

/// Schema for defining structured LLM output
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct JSONSchema {
    pub name: String,
    /// Whether the model must follow the schema exactly.
    pub strict: bool,
    pub schema: Schema,
}

/// Defines LLM structured output shape and types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
    #[serde(rename = "type")]
    /// Type of the schema, should be Object
    pub type_field: JSONType,
    // only array-type properties are supported for now
    pub properties: HashMap<String, ArrayProperty>,
    /// All 'properties' keys must be included in this 'required' list
    pub required: Vec<String>,
    /// additionalProperties should always be 'false'
    pub additional_properties: bool,
}

/// An array-typed property in a structured-output schema.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayProperty {
    #[serde(rename = "type")]
    pub type_field: JSONType,
    /// Type of the array's elements.
    pub items: ArrayItem,
}

/// Element type of an array property.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayItem {
    #[serde(rename = "type")]
    pub type_field: JSONType,
}

/// The `response_format` parameter of an LLM request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMRequestResponseFormatParam {
    #[serde(rename = "type")]
    pub format_type: JSONType, //should be JsonSchema
    pub json_schema: JSONSchema,
}
214
/// A chat completion request sent to the LLM API.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMRequest {
    /// Conversation history, starting with the system prompt.
    pub messages: Vec<APIMessage>,
    /// Azure search data sources; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub data_sources: Vec<DataSource>,
    /// Model-specific parameters, flattened into the request body.
    #[serde(flatten)]
    pub params: LLMRequestParams,
    /// Optional structured-output format; omitted when not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<LLMRequestResponseFormatParam>,
    /// Optional stop sequence.
    pub stop: Option<String>,
}
226
227impl LLMRequest {
228    pub async fn build_and_insert_incoming_message_to_db(
229        conn: &mut PgConnection,
230        chatbot_configuration_id: Uuid,
231        conversation_id: Uuid,
232        message: &str,
233        app_config: &ApplicationConfiguration,
234    ) -> anyhow::Result<(Self, i32, i32)> {
235        let index_name = Url::parse(&app_config.base_url)?
236            .host_str()
237            .expect("BASE_URL must have a host")
238            .replace(".", "-");
239
240        let configuration =
241            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
242
243        let conversation_messages =
244            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
245                .await?;
246
247        let new_order_number = conversation_messages
248            .iter()
249            .map(|m| m.order_number)
250            .max()
251            .unwrap_or(0)
252            + 1;
253
254        let new_message = models::chatbot_conversation_messages::insert(
255            conn,
256            ChatbotConversationMessage {
257                id: Uuid::new_v4(),
258                created_at: Utc::now(),
259                updated_at: Utc::now(),
260                deleted_at: None,
261                conversation_id,
262                message: Some(message.to_string()),
263                message_role: MessageRole::User,
264                message_is_complete: true,
265                used_tokens: estimate_tokens(message),
266                order_number: new_order_number,
267                tool_call_fields: vec![],
268                tool_output: None,
269            },
270        )
271        .await?;
272
273        let mut api_chat_messages: Vec<APIMessage> = conversation_messages
274            .into_iter()
275            .map(APIMessage::try_from)
276            .collect::<ChatbotResult<Vec<_>>>()?;
277
278        // put new user message into the messages list
279        api_chat_messages.push(new_message.clone().try_into()?);
280
281        api_chat_messages.insert(
282            0,
283            APIMessage {
284                role: MessageRole::System,
285                fields: APIMessageKind::Text(APIMessageText {
286                    content: configuration.prompt.clone(),
287                }),
288            },
289        );
290
291        let data_sources = if configuration.use_azure_search {
292            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
293                anyhow::anyhow!("Azure configuration is missing from the application configuration")
294            })?;
295
296            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
297                anyhow::anyhow!(
298                    "Azure search configuration is missing from the Azure configuration"
299                )
300            })?;
301
302            let query_type = if configuration.use_semantic_reranking {
303                "vector_semantic_hybrid"
304            } else {
305                "vector_simple_hybrid"
306            };
307
308            // if there are data sources, the message history might contain incompatible
309            // tool call and result messages. Remove tool call messages and turn tool
310            // response messages into role=assistant messages with the tool output as
311            // text content.
312            api_chat_messages = api_chat_messages
313                .into_iter()
314                .filter(|m| !matches!(m.fields, APIMessageKind::ToolCall(_)))
315                .map(|m| match m.fields {
316                    APIMessageKind::ToolResponse(r) => APIMessage {
317                        role: MessageRole::Assistant,
318                        fields: APIMessageKind::Text(APIMessageText { content: r.content }),
319                    },
320                    _ => m,
321                })
322                .collect();
323
324            vec![DataSource {
325                data_type: "azure_search".to_string(),
326                parameters: DataSourceParameters {
327                    endpoint: search_config.search_endpoint.to_string(),
328                    authentication: DataSourceParametersAuthentication {
329                        auth_type: "api_key".to_string(),
330                        key: search_config.search_api_key.clone(),
331                    },
332                    index_name,
333                    query_type: query_type.to_string(),
334                    semantic_configuration: "default".to_string(),
335                    embedding_dependency: EmbeddingDependency {
336                        dep_type: "deployment_name".to_string(),
337                        deployment_name: search_config.vectorizer_deployment_id.clone(),
338                    },
339                    in_scope: false,
340                    top_n_documents: 15,
341                    strictness: 3,
342                    filter: Some(
343                        SearchFilter::eq("course_id", configuration.course_id.to_string())
344                            .to_odata()?,
345                    ),
346                    fields_mapping: FieldsMapping {
347                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
348                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
349                        filepath_field: "filepath".to_string(),
350                        title_field: "title".to_string(),
351                        url_field: "url".to_string(),
352                        vector_fields: vec!["text_vector".to_string()],
353                    },
354                },
355            }]
356        } else {
357            Vec::new()
358        };
359
360        let tools = if configuration.use_tools {
361            get_chatbot_tool_definitions()
362        } else {
363            Vec::new()
364        };
365
366        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
367        let request_estimated_tokens = estimate_tokens(&serialized_messages);
368
369        let params = if configuration.thinking_model {
370            LLMRequestParams::Thinking(ThinkingParams {
371                max_completion_tokens: Some(configuration.max_completion_tokens),
372                reasoning_effort: Some(configuration.reasoning_effort),
373                verbosity: Some(configuration.verbosity),
374                tools,
375                tool_choice: if configuration.use_tools {
376                    Some(LLMToolChoice::Auto)
377                } else {
378                    None
379                },
380            })
381        } else {
382            LLMRequestParams::NonThinking(NonThinkingParams {
383                max_tokens: Some(configuration.response_max_tokens),
384                temperature: Some(configuration.temperature),
385                top_p: Some(configuration.top_p),
386                frequency_penalty: Some(configuration.frequency_penalty),
387                presence_penalty: Some(configuration.presence_penalty),
388            })
389        };
390
391        Ok((
392            Self {
393                messages: api_chat_messages,
394                data_sources,
395                params,
396                response_format: None,
397                stop: None,
398            },
399            new_message.order_number,
400            request_estimated_tokens,
401        ))
402    }
403
404    pub async fn update_messages_to_db(
405        mut self,
406        conn: &mut PgConnection,
407        new_msgs: Vec<APIMessage>,
408        conversation_id: Uuid,
409        mut order_number: i32,
410    ) -> anyhow::Result<(Self, i32)> {
411        for m in new_msgs {
412            let converted_msg = m.to_chatbot_conversation_message(conversation_id, order_number)?;
413            chatbot_conversation_messages::insert(conn, converted_msg).await?;
414            self.messages.push(m);
415            order_number += 1;
416        }
417        Ok((self, order_number))
418    }
419}
420
/// An Azure "On Your Data" data source attached to a chat request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSource {
    /// Serialized as "type"; set to "azure_search" in this crate.
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

/// Parameters of an Azure search data source.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParameters {
    /// Endpoint of the Azure search service.
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    /// Name of the search index to query.
    pub index_name: String,
    /// Search mode, e.g. "vector_semantic_hybrid" or "vector_simple_hybrid".
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    /// How many documents to include as context.
    pub top_n_documents: i32,
    pub strictness: i32,
    /// Optional OData filter; omitted from the JSON when not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}
443
/// Authentication block of an Azure search data source.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParametersAuthentication {
    /// Serialized as "type"; set to "api_key" in this crate.
    #[serde(rename = "type")]
    pub auth_type: String,
    /// The API key itself.
    pub key: String,
}

/// The embedding deployment used for vectorizing queries.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmbeddingDependency {
    /// Serialized as "type"; set to "deployment_name" in this crate.
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

/// Maps search index fields to the roles Azure expects.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FieldsMapping {
    /// Separator used when joining `content_fields` into one string.
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}
467
/// A single chunk of chatbot text, streamed to the user as newline-separated JSON.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}
472
/// Custom stream that encapsulates both the response stream and the cancellation guard. Makes sure that the guard is always dropped when the stream is dropped.
#[pin_project]
struct GuardedStream<S> {
    /// Cleans up the in-progress response message if dropped before completion.
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    /// Wraps `stream` so that `guard` lives exactly as long as the stream.
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}
486
487impl<S> Stream for GuardedStream<S>
488where
489    S: Stream<Item = anyhow::Result<Bytes>> + Send,
490{
491    type Item = S::Item;
492
493    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
494        let this = self.project();
495        this.stream.poll_next(cx)
496    }
497}
498
/// A LinesStream that is peekable. Needed to determine which type of LLM response is
/// being received.
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
/// The kind of LLM response detected on the stream; carries the stream itself
/// so the caller can continue reading from where detection stopped.
pub enum ResponseStreamType<'a> {
    /// The model is calling tools; parse with `parse_tool`.
    Toolcall(PeekableLinesStream<'a>),
    /// The model is answering with text; parse with `parse_and_stream_to_user`.
    TextResponse(PeekableLinesStream<'a>),
}
508
/// Guard that cleans up the in-progress response message if the response
/// stream is dropped (e.g. the request is cancelled) before completion.
struct RequestCancelledGuard {
    /// Id of the placeholder response message in the database.
    response_message_id: Uuid,
    /// Text fragments received so far, shared with the response stream.
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    /// Set to true when the stream completes normally, which skips cleanup.
    done: Arc<AtomicBool>,
    /// Token estimate of the request, added to the saved token count.
    request_estimated_tokens: i32,
}
516
517impl Drop for RequestCancelledGuard {
518    fn drop(&mut self) {
519        if self.done.load(atomic::Ordering::Relaxed) {
520            return;
521        }
522        warn!("Request was not cancelled. Cleaning up.");
523        let response_message_id = self.response_message_id;
524        let received_string = self.received_string.clone();
525        let pool = self.pool.clone();
526        let request_estimated_tokens = self.request_estimated_tokens;
527        tokio::spawn(async move {
528            info!("Verifying the received message has been handled");
529            let mut conn = pool.acquire().await.expect("Could not acquire connection");
530            let full_response_text = received_string.lock().await;
531            if full_response_text.is_empty() {
532                info!("No response received. Deleting the response message");
533                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
534                    .await
535                    .expect("Could not delete response message");
536                return;
537            }
538            info!("Response received but not completed. Saving the text received so far.");
539            let full_response_as_string = full_response_text.join("");
540            let estimated_cost = estimate_tokens(&full_response_as_string);
541            info!(
542                "End of chatbot response stream. Estimated cost: {}. Response: {}",
543                estimated_cost, full_response_as_string
544            );
545
546            // Update with request_estimated_tokens + estimated_cost
547            models::chatbot_conversation_messages::update(
548                &mut conn,
549                response_message_id,
550                &full_response_as_string,
551                true,
552                request_estimated_tokens + estimated_cost,
553            )
554            .await
555            .expect("Could not update response message");
556        });
557    }
558}
559
560pub async fn make_request_and_stream<'a>(
561    chat_request: LLMRequest,
562    model_name: &str,
563    app_config: &ApplicationConfiguration,
564) -> anyhow::Result<ResponseStreamType<'a>> {
565    let response = make_streaming_llm_request(chat_request, model_name, app_config).await?;
566
567    trace!("Receiving chat response with {:?}", response.version());
568
569    if !response.status().is_success() {
570        let status = response.status();
571        let error_message = response.text().await?;
572        return Err(anyhow::anyhow!(
573            "Failed to send chat request. Status: {}. Error: {}",
574            status,
575            error_message
576        ));
577    }
578
579    let stream = response
580        .bytes_stream()
581        .map_err(std::io::Error::other)
582        .boxed();
583    let reader = StreamReader::new(stream);
584    let lines = reader.lines();
585    let lines_stream = LinesStream::new(lines);
586    let peekable_lines_stream = lines_stream.peekable();
587    let mut pinned_lines = Box::pin(peekable_lines_stream);
588
589    loop {
590        let line_res = pinned_lines.as_mut().peek().await;
591        match line_res {
592            None => {
593                break;
594            }
595            Some(Err(e)) => {
596                return Err(anyhow!(
597                    "There was an error streaming response from Azure: {}",
598                    e
599                ));
600            }
601            Some(Result::Ok(line)) => {
602                if !line.starts_with("data: ") {
603                    pinned_lines.next().await;
604                    continue;
605                }
606                let json_str = line.trim_start_matches("data: ");
607                let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
608                    .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;
609                for choice in &response_chunk.choices {
610                    if let Some(d) = &choice.delta {
611                        if d.content.is_some() || d.context.is_some() {
612                            return Ok(ResponseStreamType::TextResponse(pinned_lines));
613                        } else if let Some(_calls) = &d.tool_calls {
614                            return Ok(ResponseStreamType::Toolcall(pinned_lines));
615                        } else if d.content.is_none() {
616                            pinned_lines.next().await;
617                            continue;
618                        }
619                    }
620                }
621                pinned_lines.next().await;
622            }
623        }
624    }
625    Err(Error::msg(
626        "The response received from Azure had an unexpected shape and couldn't be parsed"
627            .to_string(),
628    ))
629}
630
631/// Streams and parses a LLM response from Azure that contains function calls.
632/// Calls the functions and returns a Vec of function results to be sent to Azure.
633pub async fn parse_tool<'a>(
634    conn: &mut PgConnection,
635    mut lines: PeekableLinesStream<'a>,
636    user_context: &ChatbotUserContext,
637) -> anyhow::Result<Vec<APIMessage>> {
638    let mut function_name_id_args: Vec<(String, String, String)> = vec![];
639    let mut currently_streamed_function_name_id: Option<(String, String)> = None;
640    let mut currently_streamed_function_args = vec![];
641    let mut messages = vec![];
642
643    trace!("Parsing tool calls...");
644
645    while let Some(val) = lines.next().await {
646        let line = val?;
647        if !line.to_owned().starts_with("data: ") {
648            continue;
649        }
650        let json_str = line.trim_start_matches("data: ");
651        if json_str.trim() == "[DONE]" {
652            // the stream ended
653            if function_name_id_args.is_empty() {
654                return Err(anyhow::anyhow!(
655                    "The LLM response was supposed to contain function calls, but no function calls were found"
656                ));
657            }
658            let mut assistant_tool_calls = Vec::new();
659            let mut tool_result_msgs = Vec::new();
660
661            for (name, id, args) in function_name_id_args.iter() {
662                let tool = get_chatbot_tool(conn, name, args, user_context).await?;
663
664                assistant_tool_calls.push(APIToolCall {
665                    function: APITool {
666                        name: name.to_owned(),
667                        arguments: serde_json::to_string(tool.get_arguments())?,
668                    },
669                    id: id.to_owned(),
670                    tool_type: ToolCallType::Function,
671                });
672                tool_result_msgs.push(APIMessage {
673                    role: MessageRole::Tool,
674                    fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
675                        content: tool.get_tool_output(),
676                        name: name.to_owned(),
677                        tool_call_id: id.to_owned(),
678                    }),
679                })
680            }
681            // insert all tool calls made by the bot as one message into the messages
682            messages.push(APIMessage {
683                role: MessageRole::Assistant,
684                fields: APIMessageKind::ToolCall(APIMessageToolCall {
685                    tool_calls: assistant_tool_calls,
686                }),
687            });
688            // add tool call output messages to the messages
689            messages.extend(tool_result_msgs);
690            break;
691        }
692        let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
693            .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {} {}", e, json_str))?;
694        for choice in &response_chunk.choices {
695            if Some("tool_calls".to_string()) == choice.finish_reason {
696                // the stream is finished for now because of "tool_calls"
697                // so if we're still streaming some func call, finish it and store it
698                if let Some((name, id)) = &currently_streamed_function_name_id {
699                    // we have streamed some func call and args so let's join the args
700                    // and save the call
701                    let fn_args = currently_streamed_function_args.join("");
702                    function_name_id_args.push((
703                        name.to_owned(),
704                        id.to_owned(),
705                        fn_args.to_owned(),
706                    ));
707                    currently_streamed_function_args.clear();
708                    currently_streamed_function_name_id = None;
709                    // after this chunk, there is assumed to be a chunk that just has
710                    // content "[DONE]", we'll process the func calls at that point.
711                }
712            }
713            if let Some(delta) = &choice.delta
714                && let Some(tool_calls) = &delta.tool_calls
715            {
716                // this chunk has tool call data
717                for call in tool_calls {
718                    if let (Some(name), Some(id)) = (&call.function.name, &call.id) {
719                        // if this chunk has a tool name and id, then a new call is made.
720                        // if there is previously streamed args, then their streaming is
721                        // complete, let's join and save them before processing this new
722                        // call.
723                        if let Some((name_prev, id_prev)) = currently_streamed_function_name_id {
724                            let fn_args = currently_streamed_function_args.join("");
725                            function_name_id_args.push((
726                                name_prev.to_owned(),
727                                id_prev.to_owned(),
728                                fn_args,
729                            ));
730                            currently_streamed_function_args.clear();
731                        }
732                        // set the tool name and id from this chunk to currently_streamed
733                        // and save any arguments to currently_streamed_function_args
734                        // until the stream is complete or a new call is made.
735                        currently_streamed_function_name_id =
736                            Some((name.to_owned(), id.to_owned()));
737                    };
738                    // always save any streamed function args. it can be an empty string
739                    // but that's ok.
740                    currently_streamed_function_args.push(call.function.arguments.clone());
741                }
742            }
743        }
744    }
745    Ok(messages)
746}
747
748/// Streams and parses a LLM response from Azure that contains a text response.
749pub async fn parse_and_stream_to_user<'a>(
750    conn: &mut PgConnection,
751    mut lines: PeekableLinesStream<'a>,
752    conversation_id: Uuid,
753    response_order_number: i32,
754    pool: PgPool,
755    request_estimated_tokens: i32,
756) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
757    // insert the to-be-streamed bot text response to db
758    let response_message = models::chatbot_conversation_messages::insert(
759        conn,
760        ChatbotConversationMessage {
761            id: Uuid::new_v4(),
762            created_at: Utc::now(),
763            updated_at: Utc::now(),
764            deleted_at: None,
765            conversation_id,
766            message: Some("".to_string()),
767            message_role: MessageRole::Assistant,
768            message_is_complete: false,
769            used_tokens: request_estimated_tokens,
770            order_number: response_order_number,
771            tool_call_fields: vec![],
772            tool_output: None,
773        },
774    )
775    .await?;
776
777    let done = Arc::new(AtomicBool::new(false));
778    let full_response_text = Arc::new(Mutex::new(Vec::new()));
779    // Instantiate the guard before creating the stream.
780    let guard = RequestCancelledGuard {
781        response_message_id: response_message.id,
782        received_string: full_response_text.clone(),
783        pool: pool.clone(),
784        done: done.clone(),
785        request_estimated_tokens,
786    };
787
788    trace!("Parsing stream to user...");
789
790    let response_stream = async_stream::try_stream! {
791        while let Some(val) = lines.next().await {
792            let line = val?;
793            if !line.starts_with("data: ") {
794                continue;
795            }
796            let mut full_response_text = full_response_text.lock().await;
797            let json_str = line.trim_start_matches("data: ");
798            if json_str.trim() == "[DONE]" {
799                let full_response_as_string = full_response_text.join("");
800                let estimated_cost = estimate_tokens(&full_response_as_string);
801                trace!(
802                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
803                    estimated_cost, full_response_as_string
804                );
805                done.store(true, atomic::Ordering::Relaxed);
806                let mut conn = pool.acquire().await?;
807                models::chatbot_conversation_messages::update(
808                    &mut conn,
809                    response_message.id,
810                    &full_response_as_string,
811                    true,
812                    request_estimated_tokens + estimated_cost,
813                ).await?;
814                break;
815            }
816            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
817                anyhow::anyhow!("Failed to parse response chunk: {}", e)
818            })?;
819
820            for choice in &response_chunk.choices {
821                if let Some(delta) = &choice.delta {
822                    if let Some(content) = &delta.content {
823                        full_response_text.push(content.clone());
824                        let response = ChatResponse { text: content.clone() };
825                        let response_as_string = serde_json::to_string(&response)?;
826                        yield Bytes::from(response_as_string);
827                        yield Bytes::from("\n");
828                    }
829                    if let Some(context) = &delta.context {
830                        let mut conn = pool.acquire().await?;
831                        for (idx, cit) in context.citations.iter().enumerate() {
832                            let content = if cit.content.len() < 255 {cit.content.clone()} else {cit.content[0..255].to_string()};
833                            let split = content.split_once(CONTENT_FIELD_SEPARATOR);
834                            if split.is_none() {
835                                error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
836                            }
837                            let cleaned_content: String = split.unwrap_or(("","")).1.to_string();
838
839                            // The title and URL come from Azure Blob Storage metadata, which was URL-encoded
840                            // (percent-encoded) because Azure Blob Storage metadata values must be ASCII-only.
841                            // We decode them back to their original UTF-8 strings before storing in the database.
842                            let decoded_title = url_decode(&cit.title)?;
843                            let decoded_url = url_decode(&cit.url)?;
844
845                            let mut page_path = PathBuf::from(&cit.filepath);
846                            page_path.set_extension("");
847                            let page_id_str = page_path.file_name();
848                            let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
849                            let course_material_chapter_number = if let Some(id) = page_id {
850                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
851                                chapter.map(|c| c.chapter_number)
852                            } else {
853                                None
854                            };
855
856                            models::chatbot_conversation_messages_citations::insert(
857                                &mut conn, ChatbotConversationMessageCitation {
858                                    id: Uuid::new_v4(),
859                                    created_at: Utc::now(),
860                                    updated_at: Utc::now(),
861                                    deleted_at: None,
862                                    conversation_message_id: response_message.id,
863                                    conversation_id: response_message.conversation_id,
864                                    course_material_chapter_number,
865                                    title: decoded_title,
866                                    content: cleaned_content,
867                                    document_url: decoded_url,
868                                    citation_number: (idx+1) as i32,
869                                }
870                            ).await?;
871                        }
872                    }
873                }
874            }
875        }
876
877        if !done.load(atomic::Ordering::Relaxed) {
878            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
879        }
880    };
881
882    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard into the stream and ensures that it is dropped when the stream is dropped.
883    // This way we do cleanup only when the stream is dropped and not when this function returns.
884    let guarded_stream = GuardedStream::new(guard, response_stream);
885
886    // Box and pin the GuardedStream to satisfy the Unpin requirement
887    Ok(Box::pin(guarded_stream))
888}
889
890pub async fn send_chat_request_and_parse_stream(
891    conn: &mut PgConnection,
892    pool: PgPool,
893    app_config: &ApplicationConfiguration,
894    chatbot_configuration_id: Uuid,
895    conversation_id: Uuid,
896    message: &str,
897    user_context: ChatbotUserContext,
898) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
899    let (mut chat_request, new_message_order_number, request_estimated_tokens) =
900        LLMRequest::build_and_insert_incoming_message_to_db(
901            conn,
902            chatbot_configuration_id,
903            conversation_id,
904            message,
905            app_config,
906        )
907        .await?;
908
909    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
910        conn,
911        chatbot_configuration_id,
912    )
913    .await?;
914
915    let mut next_message_order_number = new_message_order_number + 1;
916    let mut max_iterations_left = 15;
917
918    loop {
919        max_iterations_left -= 1;
920        if max_iterations_left == 0 {
921            error!("Maximum tool call iterations exceeded");
922            return Err(anyhow::anyhow!(
923                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
924            ));
925        }
926
927        let response_type =
928            make_request_and_stream(chat_request.clone(), &model.deployment_name, app_config)
929                .await?;
930
931        let new_tool_msgs = match response_type {
932            ResponseStreamType::Toolcall(stream) => parse_tool(conn, stream, &user_context).await?,
933            ResponseStreamType::TextResponse(stream) => {
934                return parse_and_stream_to_user(
935                    conn,
936                    stream,
937                    conversation_id,
938                    next_message_order_number,
939                    pool,
940                    request_estimated_tokens,
941                )
942                .await;
943            }
944        };
945        (chat_request, next_message_order_number) = chat_request
946            .update_messages_to_db(
947                conn,
948                new_tool_msgs,
949                conversation_id,
950                next_message_order_number,
951            )
952            .await?;
953    }
954}