
headless_lms_chatbot/azure_chatbot.rs

use std::collections::HashMap;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use anyhow::{Error, Ok, anyhow};
use bytes::Bytes;
use chrono::Utc;
use futures::stream::{BoxStream, Peekable};
use futures::{Stream, StreamExt, TryStreamExt};
use headless_lms_base::config::ApplicationConfiguration;
use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
use headless_lms_models::chatbot_conversation_messages::{
    self, ChatbotConversationMessage, MessageRole,
};
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_stream::wrappers::LinesStream;
use tokio_util::io::StreamReader;
use tracing::trace;
use url::Url;

use crate::chatbot_error::ChatbotResult;
use crate::chatbot_tools::{
    AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
};
use crate::llm_utils::{
    APIMessage, APIMessageKind, APIMessageText, APIMessageToolCall, APIMessageToolResponse,
    APITool, APIToolCall, estimate_tokens, make_streaming_llm_request,
};
use headless_lms_utils::strings::truncate_utf8_at_boundary;
use headless_lms_utils::url_encoding::url_decode;

use crate::prelude::*;
use crate::search_filter::SearchFilter;

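/// Separator that Azure places between the mapped content fields (`chunk_context` and
/// `chunk`, see `FieldsMapping`) when it concatenates them into a citation's content.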
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

/// Context about the user and course for a chatbot interaction.
/// Passed to tool implementations so they can access user-specific data.
pub struct ChatbotUserContext {
    pub user_id: Uuid,
    pub course_id: Uuid,
    pub course_name: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

/// Data in a streamed response chunk
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

/// Content in a streamed response chunk Choice
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
    pub tool_calls: Option<Vec<ToolCallInDelta>>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

/// A streamed tool call from Azure
#[derive(Deserialize, Serialize, Debug)]
pub struct ToolCallInDelta {
    pub id: Option<String>,
    pub function: DeltaTool,
    #[serde(rename = "type")]
    pub tool_type: Option<ToolCallType>,
}

/// Streamed tool call content
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DeltaTool {
    #[serde(default)]
    pub arguments: String,
    pub name: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
    Function,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

/// Response received from LLM API
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolChoice {
    Auto,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
    // `default` is needed alongside `skip_serializing_if` so a request serialized
    // without `tools` can round-trip through Deserialize.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<AzureLLMToolDefinition>,
    pub tool_choice: Option<LLMToolChoice>,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

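/// Model parameters for the request. `LLMRequest` marks its `params` field with
/// `#[serde(flatten)]`, so the chosen variant's fields are serialized at the top
/// level of the request body.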
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JSONType {
    JsonSchema,
    Object,
    Array,
    String,
}

/// Schema for defining structured LLM output
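///
/// As a sketch, a schema with a single array-of-strings property named `sentences`
/// (an illustrative name) serializes like this, following the serde renames below:
///
/// ```json
/// {
///   "name": "sentences",
///   "strict": true,
///   "schema": {
///     "type": "object",
///     "properties": {
///       "sentences": { "type": "array", "items": { "type": "string" } }
///     },
///     "required": ["sentences"],
///     "additionalProperties": false
///   }
/// }
/// ```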
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct JSONSchema {
    pub name: String,
    pub strict: bool,
    pub schema: Schema,
}

/// Defines LLM structured output shape and types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
    #[serde(rename = "type")]
    /// Type of the schema, should be Object
    pub type_field: JSONType,
    // only array-type properties are supported for now
    pub properties: HashMap<String, ArrayProperty>,
    /// All 'properties' keys must be included in this 'required' list
    pub required: Vec<String>,
    /// additionalProperties should always be 'false'
    pub additional_properties: bool,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayProperty {
    #[serde(rename = "type")]
    pub type_field: JSONType,
    pub items: ArrayItem,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayItem {
    #[serde(rename = "type")]
    pub type_field: JSONType,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMRequestResponseFormatParam {
    #[serde(rename = "type")]
    pub format_type: JSONType, // should be JsonSchema
    pub json_schema: JSONSchema,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<LLMRequestResponseFormatParam>,
    pub stop: Option<String>,
}

impl LLMRequest {
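    /// Builds the message history for the LLM request, inserts the incoming user
    /// message into the database, and, when enabled in the chatbot configuration,
    /// attaches the Azure AI Search data source and the chatbot tool definitions.
    /// Returns the request, the order number of the inserted user message, and the
    /// estimated token count of the serialized message history.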
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, i32, i32)> {
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .expect("BASE_URL must have a host")
            .replace(".", "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                message_role: MessageRole::User,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
                tool_call_fields: vec![],
                tool_output: None,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<APIMessage> = conversation_messages
            .into_iter()
            .map(APIMessage::try_from)
            .collect::<ChatbotResult<Vec<_>>>()?;

        // put new user message into the messages list
        api_chat_messages.push(new_message.clone().try_into()?);

        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                fields: APIMessageKind::Text(APIMessageText {
                    content: configuration.prompt.clone(),
                }),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            // if there are data sources, the message history might contain incompatible
            // tool call and result messages. Remove tool call messages and turn tool
            // response messages into role=assistant messages with the tool output as
            // text content.
            api_chat_messages = api_chat_messages
                .into_iter()
                .filter(|m| !matches!(m.fields, APIMessageKind::ToolCall(_)))
                .map(|m| match m.fields {
                    APIMessageKind::ToolResponse(r) => APIMessage {
                        role: MessageRole::Assistant,
                        fields: APIMessageKind::Text(APIMessageText { content: r.content }),
                    },
                    _ => m,
                })
                .collect();

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    semantic_configuration: format!("{}-semantic-configuration", &index_name),
                    index_name,
                    query_type: query_type.to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let tools = if configuration.use_tools {
            get_chatbot_tool_definitions()
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
                tools,
                tool_choice: if configuration.use_tools {
                    Some(LLMToolChoice::Auto)
                } else {
                    None
                },
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                response_format: None,
                stop: None,
            },
            new_message.order_number,
            request_estimated_tokens,
        ))
    }

    pub async fn update_messages_to_db(
        mut self,
        conn: &mut PgConnection,
        new_msgs: Vec<APIMessage>,
        conversation_id: Uuid,
        mut order_number: i32,
    ) -> anyhow::Result<(Self, i32)> {
        for m in new_msgs {
            let converted_msg = m.to_chatbot_conversation_message(conversation_id, order_number)?;
            chatbot_conversation_messages::insert(conn, converted_msg).await?;
            self.messages.push(m);
            order_number += 1;
        }
        Ok((self, order_number))
    }
}

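/// An `azure_search` data source attached to the chat completions request so that the
/// model can ground its answers in the course's search index.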
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

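/// A single chunk of chatbot output; chunks are streamed to the client as
/// newline-delimited JSON.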
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

/// Custom stream that encapsulates both the response stream and the cancellation
/// guard. Makes sure that the guard is always dropped when the stream is dropped.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

/// A LinesStream that is peekable. Needed to determine which type of LLM response is
/// being received.
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
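/// The detected kind of a streamed LLM response: a round of tool calls or text meant
/// for the user. Both variants carry the stream so the caller can continue consuming it.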
pub enum ResponseStreamType<'a> {
    Toolcall(PeekableLinesStream<'a>),
    TextResponse(PeekableLinesStream<'a>),
}

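/// Dropped together with the response stream. If the stream is dropped before the
/// response is complete, `Drop` spawns a task that either deletes the empty response
/// message or saves whatever text was received so far.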
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Chatbot response stream was dropped before completion. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            // Update with request_estimated_tokens + estimated_cost
            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

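/// Sends the chat request and peeks at the first data chunks of the response stream to
/// decide whether the model is making tool calls or producing a text response. Returns
/// the still-unconsumed stream wrapped in the matching `ResponseStreamType` variant.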
pub async fn make_request_and_stream<'a>(
    chat_request: LLMRequest,
    model_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<ResponseStreamType<'a>> {
    let response = make_streaming_llm_request(chat_request, model_name, app_config).await?;

    trace!("Receiving chat response with {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response
        .bytes_stream()
        .map_err(std::io::Error::other)
        .boxed();
    let reader = StreamReader::new(stream);
    let lines = reader.lines();
    let lines_stream = LinesStream::new(lines);
    let peekable_lines_stream = lines_stream.peekable();
    let mut pinned_lines = Box::pin(peekable_lines_stream);

    loop {
        let line_res = pinned_lines.as_mut().peek().await;
        match line_res {
            None => {
                break;
            }
            Some(Err(e)) => {
                return Err(anyhow!(
                    "There was an error streaming response from Azure: {}",
                    e
                ));
            }
            Some(Result::Ok(line)) => {
                if !line.starts_with("data: ") {
                    pinned_lines.next().await;
                    continue;
                }
                let json_str = line.trim_start_matches("data: ");
                let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
                    .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;
                for choice in &response_chunk.choices {
                    if let Some(d) = &choice.delta {
                        if d.content.is_some() || d.context.is_some() {
                            return Ok(ResponseStreamType::TextResponse(pinned_lines));
                        } else if d.tool_calls.is_some() {
                            return Ok(ResponseStreamType::Toolcall(pinned_lines));
                        }
                        // The delta carried neither content, context, nor tool calls
                        // (e.g. a keep-alive chunk); fall through so the line below is
                        // consumed exactly once.
                    }
                }
                pinned_lines.next().await;
            }
        }
    }
    Err(Error::msg(
        "The response received from Azure had an unexpected shape and couldn't be parsed",
    ))
}

/// Streams and parses an LLM response from Azure that contains function calls.
/// Calls the functions and returns a Vec of function results to be sent to Azure.
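///
/// As an abbreviated illustration (real chunks carry more fields, and the values here
/// are made up), a tool call arrives as a chunk naming the call, chunks carrying
/// argument fragments, a `finish_reason` of `tool_calls`, and finally `[DONE]`:
///
/// ```text
/// data: {"choices":[{"delta":{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"example_tool","arguments":""}}]}}]}
/// data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"{\"query\":\"foo\"}"}}]}}]}
/// data: {"choices":[{"finish_reason":"tool_calls"}]}
/// data: [DONE]
/// ```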
pub async fn parse_tool<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    user_context: &ChatbotUserContext,
) -> anyhow::Result<Vec<APIMessage>> {
    let mut function_name_id_args: Vec<(String, String, String)> = vec![];
    let mut currently_streamed_function_name_id: Option<(String, String)> = None;
    let mut currently_streamed_function_args = vec![];
    let mut messages = vec![];

    trace!("Parsing tool calls...");

    while let Some(val) = lines.next().await {
        let line = val?;
        if !line.starts_with("data: ") {
            continue;
        }
        let json_str = line.trim_start_matches("data: ");
        if json_str.trim() == "[DONE]" {
            // the stream ended
            if function_name_id_args.is_empty() {
                return Err(anyhow::anyhow!(
                    "The LLM response was supposed to contain function calls, but no function calls were found"
                ));
            }
            let mut assistant_tool_calls = Vec::new();
            let mut tool_result_msgs = Vec::new();

            for (name, id, args) in function_name_id_args.iter() {
                let tool = get_chatbot_tool(conn, name, args, user_context).await?;

                assistant_tool_calls.push(APIToolCall {
                    function: APITool {
                        name: name.to_owned(),
                        arguments: serde_json::to_string(tool.get_arguments())?,
                    },
                    id: id.to_owned(),
                    tool_type: ToolCallType::Function,
                });
                tool_result_msgs.push(APIMessage {
                    role: MessageRole::Tool,
                    fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
                        content: tool.get_tool_output(),
                        name: name.to_owned(),
                        tool_call_id: id.to_owned(),
                    }),
                })
            }
            // insert all tool calls made by the bot as one message into the messages
            messages.push(APIMessage {
                role: MessageRole::Assistant,
                fields: APIMessageKind::ToolCall(APIMessageToolCall {
                    tool_calls: assistant_tool_calls,
                }),
            });
            // add tool call output messages to the messages
            messages.extend(tool_result_msgs);
            break;
        }
        let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
            .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {} {}", e, json_str))?;
        for choice in &response_chunk.choices {
            if choice.finish_reason.as_deref() == Some("tool_calls") {
                // the stream is finished for now because of "tool_calls",
                // so if we're still streaming some func call, finish it and store it
                if let Some((name, id)) = &currently_streamed_function_name_id {
                    // we have streamed some func call and args, so join the args
                    // and save the call
                    let fn_args = currently_streamed_function_args.join("");
                    function_name_id_args.push((name.to_owned(), id.to_owned(), fn_args));
                    currently_streamed_function_args.clear();
                    currently_streamed_function_name_id = None;
                    // after this chunk, there is assumed to be a chunk that just has
                    // content "[DONE]"; we'll process the func calls at that point.
                }
            }
            if let Some(delta) = &choice.delta
                && let Some(tool_calls) = &delta.tool_calls
            {
                // this chunk has tool call data
                for call in tool_calls {
                    if let (Some(name), Some(id)) = (&call.function.name, &call.id) {
                        // if this chunk has a tool name and id, then a new call is made.
                        // if there are previously streamed args, their streaming is
                        // complete; join and save them before processing this new call.
                        if let Some((name_prev, id_prev)) = currently_streamed_function_name_id {
                            let fn_args = currently_streamed_function_args.join("");
                            function_name_id_args.push((name_prev, id_prev, fn_args));
                            currently_streamed_function_args.clear();
                        }
                        // set the tool name and id from this chunk to currently_streamed
                        // and save any arguments to currently_streamed_function_args
                        // until the stream is complete or a new call is made.
                        currently_streamed_function_name_id =
                            Some((name.to_owned(), id.to_owned()));
                    };
                    // always save any streamed function args; they can be empty strings,
                    // but that's ok.
                    currently_streamed_function_args.push(call.function.arguments.clone());
                }
            }
        }
    }
    Ok(messages)
}

/// Streams and parses an LLM response from Azure that contains a text response.
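/// Inserts a placeholder assistant message, then forwards each content delta to the
/// client as newline-delimited JSON `ChatResponse` chunks, persisting any citations
/// found in the delta context along the way. When the `[DONE]` sentinel arrives, the
/// database row is updated with the full text and the total estimated token count.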
pub async fn parse_and_stream_to_user<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    conversation_id: Uuid,
    response_order_number: i32,
    pool: PgPool,
    request_estimated_tokens: i32,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
    // insert the to-be-streamed bot text response to db
    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: Some("".to_string()),
            message_role: MessageRole::Assistant,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
            tool_call_fields: vec![],
            tool_output: None,
        },
    )
    .await?;

    let done = Arc::new(AtomicBool::new(false));
    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    // Instantiate the guard before creating the stream.
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    trace!("Parsing stream to user...");

    let response_stream = async_stream::try_stream! {
        while let Some(val) = lines.next().await {
            let line = val?;
            if !line.starts_with("data: ") {
                continue;
            }
            let mut full_response_text = full_response_text.lock().await;
            let json_str = line.trim_start_matches("data: ");
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                trace!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            let content = truncate_utf8_at_boundary(&cit.content, 255).to_string();
                            // `content` holds `chunk_context` and `chunk` joined by the
                            // separator; only the chunk itself is stored in the citation.
                            let cleaned_content: String =
                                match content.split_once(CONTENT_FIELD_SEPARATOR) {
                                    Some((_chunk_context, chunk)) => chunk.to_string(),
                                    None => {
                                        error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                                        // Fall back to the whole content rather than
                                        // storing an empty citation.
                                        content.clone()
                                    }
                                };

                            // The title and URL come from Azure Blob Storage metadata, which was URL-encoded
                            // (percent-encoded) because Azure Blob Storage metadata values must be ASCII-only.
                            // We decode them back to their original UTF-8 strings before storing in the database.
                            let decoded_title = url_decode(&cit.title)?;
                            let decoded_url = url_decode(&cit.url)?;

                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id_str = page_path.file_name();
                            let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn, ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: response_message.id,
                                    conversation_id: response_message.conversation_id,
                                    course_material_chapter_number,
                                    title: decoded_title,
                                    content: cleaned_content,
                                    document_url: decoded_url,
                                    citation_number: (idx + 1) as i32,
                                }
                            ).await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Encapsulate the stream and the guard within GuardedStream. This moves the request
    // guard into the stream and ensures that it is dropped when the stream is dropped.
    // This way we do cleanup only when the stream is dropped and not when this function
    // returns.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    // Box and pin the GuardedStream to satisfy the Unpin requirement
    Ok(Box::pin(guarded_stream))
}

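/// Top-level entry point for a chatbot message: builds and stores the request, then
/// loops, letting the model make tool calls (bounded by a maximum iteration count)
/// until it produces a text response, which is returned as a stream of NDJSON chunks.
///
/// A minimal caller sketch; the bindings (`conn`, `pool`, ids, and the context fields)
/// are illustrative and assumed to exist:
///
/// ```rust,ignore
/// let stream = send_chat_request_and_parse_stream(
///     &mut conn,
///     pool.clone(),
///     &app_config,
///     chatbot_configuration_id,
///     conversation_id,
///     "What does chapter 1 cover?",
///     ChatbotUserContext { user_id, course_id, course_name },
/// )
/// .await?;
/// // Each stream item is one newline-delimited JSON `ChatResponse` chunk.
/// ```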
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
    user_context: ChatbotUserContext,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (mut chat_request, new_message_order_number, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let mut next_message_order_number = new_message_order_number + 1;
    let mut max_iterations_left = 15;

    loop {
        max_iterations_left -= 1;
        if max_iterations_left == 0 {
            error!("Maximum tool call iterations exceeded");
            return Err(anyhow::anyhow!(
                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
            ));
        }

        let response_type =
            make_request_and_stream(chat_request.clone(), &model.deployment_name, app_config)
                .await?;

        let new_tool_msgs = match response_type {
            ResponseStreamType::Toolcall(stream) => parse_tool(conn, stream, &user_context).await?,
            ResponseStreamType::TextResponse(stream) => {
                return parse_and_stream_to_user(
                    conn,
                    stream,
                    conversation_id,
                    next_message_order_number,
                    pool,
                    request_estimated_tokens,
                )
                .await;
            }
        };
        (chat_request, next_message_order_number) = chat_request
            .update_messages_to_db(
                conn,
                new_tool_msgs,
                conversation_id,
                next_message_order_number,
            )
            .await?;
    }
}