Skip to main content

headless_lms_chatbot/
azure_chatbot.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::{
4    Arc,
5    atomic::{self, AtomicBool},
6};
7use std::task::{Context, Poll};
8
9use anyhow::{Error, anyhow};
10use bytes::Bytes;
11use chrono::Utc;
12use futures::stream::{BoxStream, Peekable};
13use futures::{Stream, StreamExt, TryStreamExt};
14use headless_lms_base::config::ApplicationConfiguration;
15use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
16use headless_lms_models::chatbot_conversation_message_messages::{
17    ChatbotConversationMessageMessage, MessageRole,
18};
19use headless_lms_models::chatbot_conversation_messages::{
20    self, ChatbotConversationMessage, Message,
21};
22use pin_project::pin_project;
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use sqlx::PgPool;
26use tokio::{io::AsyncBufReadExt, sync::Mutex};
27use tokio_stream::wrappers::LinesStream;
28use tokio_util::io::StreamReader;
29use tracing::trace;
30use url::Url;
31
32use crate::chatbot_error::ChatbotResult;
33use crate::chatbot_tools::provider_tools::azure_ai_search::get_azure_ai_search_tool_definition;
34use crate::chatbot_tools::{
35    AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
36};
37use crate::citations::chatbot_cited_documents_to_citations;
38use crate::llm_utils::{
39    APIInputMessage, APIOutputMessage, MessageContent, estimate_tokens, get_params_for_model,
40    make_streaming_llm_request,
41};
42
43use crate::prelude::*;
44
/// Separator string `,|||,`.
///
/// NOTE(review): presumably used to delimit content fields that are joined into a single
/// string — no use is visible in this part of the file, confirm against callers.
pub const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
46
47enum ParsedResponseLine {
48    Event(String),
49    Data(ResponseOutput),
50}
51
52impl ParsedResponseLine {
53    pub fn parse(input: &str) -> ChatbotResult<Option<Self>> {
54        if input.starts_with("event: ") {
55            let event_type = input.trim_start_matches("event: ").to_string();
56            Ok(Some(ParsedResponseLine::Event(event_type)))
57        } else if input.starts_with("data: ") {
58            let data = input.trim_start_matches("data: ").to_string();
59            let response_output =
60                serde_json::from_str::<ResponseOutput>(&data).map_err(ChatbotError::from)?;
61            Ok(Some(ParsedResponseLine::Data(response_output)))
62        } else {
63            Ok(None)
64        }
65    }
66}
67
/// Context about the user and course for a chatbot interaction.
///
/// Passed by reference to the tool lookup (`get_chatbot_tool`) when tool calls are parsed,
/// so that tool implementations can access user-specific data.
pub struct ChatbotUserContext {
    /// Id of the user.
    pub user_id: Uuid,
    /// Id of the course.
    pub course_id: Uuid,
    /// Name of the course.
    pub course_name: String,
}
75
76#[derive(Deserialize, Serialize, Debug)]
77pub struct ContentFilterResults {
78    pub hate: Option<ContentFilter>,
79    pub self_harm: Option<ContentFilter>,
80    pub sexual: Option<ContentFilter>,
81    pub violence: Option<ContentFilter>,
82    //pub jailbreak: Option<ContentFilter>,
83}
84
85#[derive(Deserialize, Serialize, Debug)]
86pub struct ContentFilter {
87    pub blocked: bool,
88    pub source_type: ContentFilterSource,
89    pub content_filter_results: Vec<ContentFilterResults>,
90}
91#[derive(Deserialize, Serialize, Debug)]
92pub struct ContentFilterResult {
93    pub filtered: bool,
94    pub severity: String,
95}
96
97#[derive(Deserialize, Serialize, Debug)]
98#[serde(rename_all = "snake_case")]
99pub enum ContentFilterSource {
100    Prompt,
101    Completion,
102}
103
104/// Response received from LLM API
105#[derive(Deserialize, Serialize, Debug)]
106pub struct Response {
107    pub id: String,
108    pub error: Option<String>,
109}
110
111/// Incomplete response received from LLM API
112#[derive(Deserialize, Serialize, Debug)]
113pub struct IncompleteResponse {
114    pub id: String,
115    pub incomplete_details: IncompleteReason,
116    pub content_filters: Vec<ContentFilter>,
117}
118
119/// Response received from LLM API
120#[derive(Deserialize, Serialize, Debug)]
121pub struct IncompleteReason {
122    pub reason: String,
123}
124
125/// Streamed token of the response text
126#[derive(Deserialize, Serialize, Debug)]
127pub struct ResponseOutput {
128    pub delta: Option<String>,
129    pub item: Option<OutputItem>,
130    pub response: Option<Response>,
131}
132
133#[derive(Deserialize, Serialize, Debug, Clone)]
134#[serde(tag = "type")]
135#[serde(rename_all = "snake_case")]
136pub enum OutputItem {
137    Message {
138        response_id: String,
139        role: MessageRole,
140        content: MessageContent,
141    },
142    Reasoning {
143        response_id: String,
144        summary: Vec<ReasoningOutput>,
145    },
146    AzureAiSearchCall {
147        response_id: String,
148        call_id: String,
149        /// JSON string
150        arguments: String,
151    },
152    AzureAiSearchCallOutput {
153        response_id: String,
154        call_id: String,
155        /// JSON string
156        output: String,
157    },
158    FunctionCall {
159        response_id: String,
160        call_id: String,
161        #[serde(rename = "name")]
162        tool_name: String,
163        /// JSON string
164        arguments: String,
165    },
166    FunctionCallOutput {
167        response_id: String,
168        call_id: String,
169        output: String,
170    },
171}
172
173#[derive(Deserialize, Serialize, Debug, Clone)]
174#[serde(tag = "type")]
175#[serde(rename_all = "snake_case")]
176pub enum InputItem {
177    Message {
178        role: MessageRole,
179        content: MessageContent,
180    },
181    FunctionCall {
182        call_id: String,
183        #[serde(rename = "name")]
184        tool_name: String,
185        arguments: String,
186    },
187    FunctionCallOutput {
188        call_id: String,
189        output: String,
190    },
191}
192
193#[derive(Deserialize, Serialize, Debug, Clone)]
194pub struct AiSearchOutput {
195    pub get_urls: Vec<Url>,
196}
197
198#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
199#[serde(rename_all = "snake_case")]
200pub enum LLMToolChoice {
201    Auto,
202    None,
203}
204
205#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
206pub struct ThinkingParams {
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub reasoning: Option<Reasoning>,
209}
210
211#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
212pub struct RequestTextOptions {
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub verbosity: Option<VerbosityLevel>,
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub format: Option<LLMRequestResponseFormatParam>,
217}
218#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
219pub struct Reasoning {
220    pub effort: ReasoningEffortLevel,
221    /// Option to generate a reasoning summary with desired level of info
222    pub summary: Option<SummaryType>,
223}
224
225#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
226#[serde(untagged)]
227pub enum SummaryType {
228    Concise,
229    Detailed,
230    Auto,
231}
232
233#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
234pub struct ReasoningOutput {
235    #[serde(rename = "type")]
236    pub output_type: String, //summary_text
237    pub text: String,
238}
239
240#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
241pub struct NonThinkingParams {
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub temperature: Option<f32>,
244    #[serde(skip_serializing_if = "Option::is_none")]
245    pub top_p: Option<f32>,
246    #[serde(skip_serializing_if = "Option::is_none")]
247    pub frequency_penalty: Option<f32>,
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub presence_penalty: Option<f32>,
250}
251
252#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
253pub struct MistralParams {
254    // todo
255    pub test: bool,
256}
257
258#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
259#[serde(untagged)]
260pub enum LLMRequestParams {
261    GPTThinking(ThinkingParams),
262    GPTNonThinking(NonThinkingParams),
263    Mistral(MistralParams),
264}
265
266#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
267#[serde(rename_all = "snake_case")]
268pub enum JSONType {
269    JsonSchema,
270    Object,
271    Array,
272    String,
273}
274
275/// Defines LLM structured output shape and types
276#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
277#[serde(rename_all = "camelCase")]
278pub struct Schema {
279    #[serde(rename = "type")]
280    /// Type of the schema, should be Object
281    pub type_field: JSONType,
282    // only array-type properties are supported for now
283    pub properties: HashMap<String, ArrayProperty>,
284    /// All 'properties' keys must be included in this 'required' list
285    pub required: Vec<String>,
286    /// additionalProperties should always be 'false'
287    pub additional_properties: bool,
288}
289
290#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
291pub struct ArrayProperty {
292    #[serde(rename = "type")]
293    pub type_field: JSONType,
294    pub items: ArrayItem,
295}
296
297#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
298pub struct ArrayItem {
299    #[serde(rename = "type")]
300    pub type_field: JSONType,
301}
302
303#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
304pub struct LLMRequestResponseFormatParam {
305    #[serde(rename = "type")]
306    pub format_type: JSONType, //should be JsonSchema
307    pub name: String,
308    pub schema: Schema,
309    pub strict: bool, // should be true
310}
311
312#[derive(Serialize, Deserialize, Debug, Clone)]
313pub struct LLMRequest {
314    pub input: Vec<APIInputMessage>,
315    pub model: String,
316    #[serde(skip_serializing_if = "Vec::is_empty", default)]
317    pub tools: Vec<AzureLLMToolDefinition>,
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub tool_choice: Option<LLMToolChoice>,
320    #[serde(skip_serializing_if = "Option::is_none")]
321    pub max_output_tokens: Option<i32>,
322    #[serde(skip_serializing_if = "Option::is_none")]
323    pub text: Option<RequestTextOptions>,
324    #[serde(flatten)]
325    pub params: LLMRequestParams,
326}
327
328impl LLMRequest {
329    pub async fn build_and_insert_incoming_message_to_db(
330        conn: &mut PgConnection,
331        chatbot_configuration_id: Uuid,
332        conversation_id: Uuid,
333        message: &str,
334        app_config: &ApplicationConfiguration,
335    ) -> anyhow::Result<(Self, i32)> {
336        let configuration =
337            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
338
339        let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
340            conn,
341            chatbot_configuration_id,
342        )
343        .await?;
344
345        let conversation_messages =
346            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
347                .await?;
348
349        let new_order_number = conversation_messages
350            .iter()
351            .map(|m| m.order_number)
352            .max()
353            .unwrap_or(0)
354            + 1;
355
356        let new_message = models::chatbot_conversation_messages::insert(
357            conn,
358            ChatbotConversationMessage {
359                id: Uuid::new_v4(),
360                order_number: new_order_number,
361                created_at: Utc::now(),
362                updated_at: Utc::now(),
363                deleted_at: None,
364                conversation_id,
365                message: Message::Text(ChatbotConversationMessageMessage {
366                    text: message.to_string(),
367                    message_role: MessageRole::User,
368                    message_is_complete: true,
369                    used_tokens: estimate_tokens(message),
370                    ..Default::default()
371                }),
372            },
373        )
374        .await?;
375
376        let mut api_chat_messages: Vec<APIInputMessage> = conversation_messages
377            .into_iter()
378            .filter_map(|m| match m.message {
379                Message::Reasoning(..) => None,
380                _ => Some(APIInputMessage::try_from(m)),
381            })
382            .collect::<ChatbotResult<Vec<_>>>()?;
383
384        // put new user message into the messages list
385        api_chat_messages.push(new_message.clone().try_into()?);
386
387        api_chat_messages.insert(
388            0,
389            APIInputMessage {
390                message_type: InputItem::Message {
391                    role: MessageRole::System,
392                    content: MessageContent::Text(configuration.prompt.clone()),
393                },
394            },
395        );
396
397        let mut tools = if configuration.use_tools {
398            get_chatbot_tool_definitions()
399        } else {
400            Vec::new()
401        };
402
403        if configuration.use_azure_search {
404            tools.extend(vec![AzureLLMToolDefinition::Search(
405                get_azure_ai_search_tool_definition(
406                    app_config,
407                    configuration.course_id,
408                    configuration.use_semantic_reranking,
409                )?,
410            )]);
411        };
412
413        let tool_choice = if configuration.use_azure_search || configuration.use_tools {
414            Some(LLMToolChoice::Auto)
415        } else {
416            None
417        };
418
419        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
420        let request_estimated_tokens = estimate_tokens(&serialized_messages);
421
422        let params = get_params_for_model(&model, &configuration);
423
424        Ok((
425            Self {
426                input: api_chat_messages,
427                model: model.model,
428                max_output_tokens: Some(configuration.max_output_tokens),
429                tools,
430                tool_choice,
431                text: Some(RequestTextOptions {
432                    verbosity: Some(configuration.verbosity),
433                    format: None,
434                }),
435                params,
436            },
437            request_estimated_tokens,
438        ))
439    }
440}
441
442#[derive(Serialize, Deserialize, Debug, Clone)]
443pub struct ChatResponse {
444    pub text: String,
445}
446
447/// Custom stream that encapsulates both the response stream and the cancellation guard. Makes sure that the guard is always dropped when the stream is dropped.
448#[pin_project]
449struct GuardedStream<S> {
450    guard: RequestCancelledGuard,
451    #[pin]
452    stream: S,
453}
454
455impl<S> GuardedStream<S> {
456    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
457        Self { guard, stream }
458    }
459}
460
461impl<S> Stream for GuardedStream<S>
462where
463    S: Stream<Item = anyhow::Result<Bytes>> + Send,
464{
465    type Item = S::Item;
466
467    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
468        let this = self.project();
469        this.stream.poll_next(cx)
470    }
471}
472
/// A pinned, boxed `LinesStream` that is peekable.
///
/// Peeking is needed to determine which type of LLM response is being received without
/// consuming the line that reveals it.
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
/// The kind of response a stream contains, together with the stream itself.
///
/// The stream is handed over at the point where the kind was detected by peeking, so the
/// event line that revealed the kind has not been consumed yet.
pub enum ResponseStreamType<'a> {
    /// Detected from a `response.function_call_arguments.delta` event.
    Toolcall(PeekableLinesStream<'a>),
    /// Detected from a `response.output_text.delta` event.
    TextResponse(PeekableLinesStream<'a>),
}
482
483struct RequestCancelledGuard {
484    response_message_id: Uuid,
485    received_string: Arc<Mutex<Vec<String>>>,
486    pool: PgPool,
487    done: Arc<AtomicBool>,
488    request_estimated_tokens: i32,
489}
490
491impl Drop for RequestCancelledGuard {
492    fn drop(&mut self) {
493        if self.done.load(atomic::Ordering::Relaxed) {
494            return;
495        }
496        warn!("Request was not cancelled. Cleaning up.");
497        let response_message_id = self.response_message_id;
498        let received_string = self.received_string.clone();
499        let pool = self.pool.clone();
500        let request_estimated_tokens = self.request_estimated_tokens;
501        tokio::spawn(async move {
502            info!("Verifying the received message has been handled");
503            let mut conn = pool.acquire().await.expect("Could not acquire connection");
504            let full_response_text = received_string.lock().await;
505            if full_response_text.is_empty() {
506                info!("No response received. Deleting the response message");
507                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
508                    .await
509                    .expect("Could not delete response message");
510                return;
511            }
512            info!("Response received but not completed. Saving the text received so far.");
513            let full_response_as_string = full_response_text.join("");
514            let estimated_cost = estimate_tokens(&full_response_as_string);
515            info!(
516                "End of chatbot response stream. Estimated cost: {}. Response: {}",
517                estimated_cost, full_response_as_string
518            );
519
520            // Update with request_estimated_tokens + estimated_cost
521            models::chatbot_conversation_message_messages::update(
522                &mut conn,
523                response_message_id,
524                &full_response_as_string,
525                true,
526                request_estimated_tokens + estimated_cost,
527            )
528            .await
529            .expect("Could not update response message");
530        });
531    }
532}
533
534/// Creates a stream with the LLMRequest and processes received OutputItems until receiving
535/// a response text or tool call.
536/// Returns:
537///     response id created by Azure (String),
538///     ResponseStreamType (type: response text or tool call) containing the created stream
539pub async fn make_request_and_stream<'a>(
540    conn: &mut PgConnection,
541    chat_request: LLMRequest,
542    conversation_id: Uuid,
543    app_config: &ApplicationConfiguration,
544) -> anyhow::Result<(String, ResponseStreamType<'a>)> {
545    let response = make_streaming_llm_request(chat_request, app_config).await?;
546
547    trace!("Receiving chat response with {:?}", response.version());
548
549    if !response.status().is_success() {
550        let status = response.status();
551        let error_message = response.text().await?;
552        return Err(anyhow::anyhow!(
553            "Failed to send chat request. Status: {}. Error: {}",
554            status,
555            error_message
556        ));
557    }
558
559    let stream = response
560        .bytes_stream()
561        .map_err(std::io::Error::other)
562        .boxed();
563    let reader = StreamReader::new(stream);
564    let lines = reader.lines();
565    let lines_stream = LinesStream::new(lines);
566    let peekable_lines_stream = lines_stream.peekable();
567    let mut pinned_lines = Box::pin(peekable_lines_stream);
568
569    // empty string because when event: response.created, it will be set as the correct
570    // value, and this event is the first event of the stream.
571    let mut response_id = "".to_string();
572    let mut output_item_incoming = false;
573    let mut response_created_incoming = false;
574    let mut error_incoming = false;
575
576    loop {
577        let line_res = pinned_lines.as_mut().peek().await;
578        match line_res {
579            None => {
580                break;
581            }
582            Some(Err(e)) => {
583                return Err(anyhow!(
584                    "There was an error streaming response from Azure: {}. Response id: {}",
585                    e,
586                    response_id
587                ));
588            }
589            Some(Result::Ok(line)) => {
590                match ParsedResponseLine::parse(line)? {
591                    Some(ParsedResponseLine::Event(event_type)) => {
592                        trace!("Event: {event_type}");
593                        match event_type.as_str() {
594                            "response.created" => {
595                                response_created_incoming = true;
596                            }
597                            "response.output_item.done" => {
598                                output_item_incoming = true;
599                            }
600                            "response.function_call_arguments.delta" => {
601                                if response_id.is_empty() {
602                                    return Err(anyhow::anyhow!(
603                                        "No response_id found! This should never happen!"
604                                    ));
605                                }
606                                return Ok((
607                                    response_id,
608                                    ResponseStreamType::Toolcall(pinned_lines),
609                                ));
610                            }
611                            "response.output_text.delta" => {
612                                return Ok((
613                                    response_id,
614                                    ResponseStreamType::TextResponse(pinned_lines),
615                                ));
616                            }
617                            "response.error" => {
618                                error_incoming = true;
619                            }
620                            _ => {}
621                        }
622                    }
623                    Some(ParsedResponseLine::Data(response_output)) => {
624                        if error_incoming
625                            && let Some(response) = &response_output.response
626                            && let Some(error) = &response.error
627                        {
628                            Err(chatbot_err!(
629                                StreamingError,
630                                format!(
631                                    "Error received from the API: {}. Response id: {}",
632                                    error, response.id
633                                )
634                            ))?
635                        };
636                        if response_created_incoming {
637                            let res = response_output.response.ok_or(chatbot_err!(
638                                DeserializationError,
639                                "Expected response object"
640                            ))?;
641                            response_id = res.id;
642                            response_created_incoming = false;
643                        }
644                        if output_item_incoming {
645                            let item = response_output.item.ok_or(chatbot_err!(
646                                DeserializationError,
647                                "Expected response output item"
648                            ))?;
649                            // put in input
650                            process_output_item(conn, item, conversation_id, app_config).await?;
651                            output_item_incoming = false;
652                        }
653                    }
654                    None => {}
655                }
656                pinned_lines.next().await;
657                continue;
658            }
659        }
660    }
661    Err(Error::msg(format!(
662        "The response received from Azure ended unexpectedly. Response id: {response_id}"
663    )))
664}
665
666/// For saving output items that are not text messages or function calls, i.e. that
667/// don't need further processing and are not streamed to the user.
668/// Saves reasoning and Azure AI items.
669pub async fn process_output_item(
670    conn: &mut PgConnection,
671    item: OutputItem,
672    conversation_id: Uuid,
673    app_config: &ApplicationConfiguration,
674) -> ChatbotResult<ChatbotConversationMessage> {
675    match item {
676        OutputItem::AzureAiSearchCall { .. } | OutputItem::Reasoning { .. } => {
677            let message = APIOutputMessage { message_type: item }
678                .to_chatbot_conversation_message(conversation_id)?;
679
680            ChatbotResult::Ok(chatbot_conversation_messages::insert(conn, message).await?)
681        }
682        OutputItem::AzureAiSearchCallOutput {
683            call_id,
684            output,
685            response_id,
686        } => {
687            let search_output: AiSearchOutput = serde_json::from_str(&output)?;
688            let api_key = if let Some(azure_config) = &app_config.azure_configuration
689                && let Some(search_config) = &azure_config.search_config
690            {
691                &search_config.search_api_key
692            } else {
693                return ChatbotResult::Err(chatbot_err!(
694                    Other,
695                    "Azure search configuration not found, cannot process Azure AI search output item.".to_string()
696                ));
697            };
698            let get_urls = search_output.get_urls.to_owned();
699
700            let message = APIOutputMessage {
701                message_type: OutputItem::AzureAiSearchCallOutput {
702                    call_id,
703                    output,
704                    response_id,
705                },
706            }
707            .to_chatbot_conversation_message(conversation_id)?;
708
709            let conversation_message = chatbot_conversation_messages::insert(conn, message).await?;
710
711            chatbot_cited_documents_to_citations(
712                conn,
713                app_config.test_chatbot,
714                get_urls,
715                api_key,
716                conversation_message.id,
717                conversation_id,
718            )
719            .await?;
720
721            ChatbotResult::Ok(conversation_message)
722        }
723        OutputItem::Message { .. } => {
724            // this chunk has a text message and should be streamed!
725            Err(chatbot_err!(
726                StreamingError,
727                "Unexpected message output item, it should have been streamed.".to_string()
728            ))
729        }
730        OutputItem::FunctionCall { .. } => {
731            // this chunk has tool call data andit should already be saved!!
732            Err(chatbot_err!(
733                StreamingError,
734                "Unexpected function call output item, it should have been processed.".to_string()
735            ))
736        }
737        OutputItem::FunctionCallOutput { .. } => {
738            // this chunk has tool output data
739            // we shouldn't be receiving it from the LLM!
740            // tool output is created by us!
741            Err(chatbot_err!(
742                StreamingError,
743                "Unexpected function call output item, this shouldn't happen.".to_string()
744            ))
745        }
746    }
747}
748
749/// Streams and parses a LLM response from Azure that contains function calls.
750/// Calls the functions and returns a Vec of function results to be sent to Azure.
751pub async fn parse_tool<'a>(
752    conn: &mut PgConnection,
753    mut lines: PeekableLinesStream<'a>,
754    conversation_id: Uuid,
755    user_context: &ChatbotUserContext,
756    app_config: &ApplicationConfiguration,
757) -> anyhow::Result<Vec<APIOutputMessage>> {
758    let mut function_name_id_args: Vec<(String, String, Value)> = vec![];
759    let mut messages = vec![];
760    let mut common_response_id = "".to_string();
761    let mut response_received = false;
762    let mut error_incoming = false;
763
764    trace!("Parsing tool calls...");
765
766    while let Some(val) = lines.next().await {
767        let line = val?;
768        let response_output = match ParsedResponseLine::parse(&line)? {
769            Some(ParsedResponseLine::Event(event_type)) => {
770                match event_type.as_str() {
771                    "response.completed" => {
772                        response_received = true;
773                    }
774                    "response.output_text.delta" => {
775                        return Err(anyhow::anyhow!(
776                            "Error: Received response text while parsing tool calls. Either the tool call parsing failed or the LLM responded in an unexpected way."
777                        ));
778                    }
779                    "response.error" => {
780                        error_incoming = true;
781                    }
782                    _ => {}
783                };
784                continue;
785            }
786            Some(ParsedResponseLine::Data(data)) => data,
787            None => {
788                continue;
789            }
790        };
791
792        if error_incoming
793            && let Some(response) = &response_output.response
794            && let Some(error) = &response.error
795        {
796            Err(chatbot_err!(
797                StreamingError,
798                format!("Error received from the API: {}.", error)
799            ))?
800        };
801
802        if response_received {
803            // the stream ended
804            if function_name_id_args.is_empty() {
805                return Err(anyhow::anyhow!(
806                    "The LLM response was supposed to contain function calls, but no function calls were found"
807                ));
808            }
809            if common_response_id.is_empty() {
810                return Err(anyhow::anyhow!(
811                    "Received tool response but response id not found, this shouldn't happen."
812                ));
813            };
814            let mut tool_msgs = Vec::new();
815
816            for (name, id, args) in function_name_id_args.iter() {
817                let tool = get_chatbot_tool(conn, name, args, user_context).await?;
818
819                tool_msgs.push(APIOutputMessage {
820                    message_type: OutputItem::FunctionCall {
821                        response_id: (common_response_id).to_owned(),
822                        call_id: id.to_owned(),
823                        tool_name: name.to_owned(),
824                        arguments: serde_json::to_string(tool.get_arguments())?,
825                    },
826                });
827                tool_msgs.push(APIOutputMessage {
828                    message_type: OutputItem::FunctionCallOutput {
829                        call_id: id.to_owned(),
830                        output: tool.get_tool_output(),
831                        response_id: (common_response_id).to_owned(),
832                    },
833                });
834            }
835            // save tool_msgs to the db
836            for m in &tool_msgs {
837                chatbot_conversation_messages::insert(
838                    conn,
839                    m.to_chatbot_conversation_message(conversation_id)?,
840                )
841                .await?;
842            }
843            messages.extend(tool_msgs);
844            break;
845        } else if let Some(item) = response_output.item {
846            match item {
847                OutputItem::FunctionCall {
848                    call_id,
849                    tool_name,
850                    arguments,
851                    response_id,
852                } => {
853                    common_response_id = response_id;
854                    function_name_id_args.push((
855                        tool_name,
856                        call_id,
857                        serde_json::from_str::<Value>(&arguments)?,
858                    ));
859                }
860                OutputItem::Message { .. } => Err(chatbot_err!(
861                    StreamingError,
862                    "Error: unexpected message item !!!".to_string()
863                ))?,
864                _ => {
865                    // save this chunk's data
866                    process_output_item(conn, item.clone(), conversation_id, app_config).await?;
867                    // add this output item to the messages to be included in the next
868                    // LLMRequest
869                    messages.push(APIOutputMessage { message_type: item });
870                }
871            }
872        }
873    }
874    Ok(messages)
875}
876
877/// Streams and parses a LLM response from Azure that contains a text response.
878pub async fn parse_and_stream_to_user<'a>(
879    conn: &mut PgConnection,
880    mut lines: PeekableLinesStream<'a>,
881    conversation_id: Uuid,
882    pool: PgPool,
883    request_estimated_tokens: i32,
884    response_id: String,
885    app_config: ApplicationConfiguration,
886) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
887    // insert the to-be-streamed bot text response to db
888    let response_message = models::chatbot_conversation_messages::insert(
889        conn,
890        ChatbotConversationMessage {
891            conversation_id,
892            message: Message::Text(ChatbotConversationMessageMessage {
893                text: "".to_string(),
894                message_role: MessageRole::Assistant,
895                message_is_complete: false,
896                used_tokens: request_estimated_tokens,
897                response_id: Some(response_id.to_owned()),
898                ..Default::default()
899            }),
900            ..Default::default()
901        },
902    )
903    .await?;
904    models::chatbot_conversation_messages_citations::update_citation_message_ids(
905        conn,
906        response_id,
907        response_message.id,
908    )
909    .await?;
910
911    let done = Arc::new(AtomicBool::new(false));
912    let full_response_text = Arc::new(Mutex::new(Vec::new()));
913    // Instantiate the guard before creating the stream.
914    let guard = RequestCancelledGuard {
915        response_message_id: response_message.id,
916        received_string: full_response_text.clone(),
917        pool: pool.clone(),
918        done: done.clone(),
919        request_estimated_tokens,
920    };
921
922    trace!("Parsing stream to user...");
923
924    let mut response_received = false;
925    let mut error_incoming = false;
926
927    let response_stream = async_stream::try_stream! {
928        while let Some(val) = lines.next().await {
929            let line = val?;
930            let response_output: ResponseOutput = match ParsedResponseLine::parse(&line)? {
931                Some(ParsedResponseLine::Event(event_type)) => {
932                    match event_type.as_str() {
933                        "response.completed" | "response.incomplete" => {response_received = true;},
934                        "response.output_text.delta" => {
935                            // streaming
936                        },
937                        "response.function_call_arguments.delta" => {
938                            error!("ERROR, function call received but can't be processed while streaming to user.");
939                            return Err(chatbot_err!(StreamingError, format!("Unexpected function call while streaming to user")))?
940                        },
941                        "response.error" => {error_incoming = true;},
942                        _ => {},
943                    };
944                    continue;
945                },
946                Some(ParsedResponseLine::Data(data)) => data,
947                None => {continue;},
948            };
949
950            let mut full_response_text = full_response_text.lock().await;
951
952            if response_received {
953                let full_response_as_string = full_response_text.join("");
954                // todo: use the tokens given in the response
955                let estimated_cost = estimate_tokens(&full_response_as_string);
956                trace!(
957                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
958                    estimated_cost, full_response_as_string
959                );
960                done.store(true, atomic::Ordering::Relaxed);
961                let mut conn = pool.acquire().await?;
962                models::chatbot_conversation_messages::update(
963                    &mut conn,
964                    response_message.id,
965                    &full_response_as_string,
966                    true,
967                    request_estimated_tokens + estimated_cost,
968                ).await?;
969                break;
970            }
971
972            if error_incoming &&
973                let Some(response) = &response_output.response && let Some(error) = &response.error
974            {
975                Err(chatbot_err!(StreamingError, format!("Error received from the API: {}.", error)))?
976
977            };
978
979            if let Some(delta) = &response_output.delta {
980                full_response_text.push(delta.to_owned());
981                let response = ChatResponse { text: delta.clone() };
982                let response_as_string = serde_json::to_string(&response)?;
983                yield Bytes::from(response_as_string);
984                yield Bytes::from("\n");
985            }
986
987            if let Some(item) = &response_output.item {
988                match item {
989                    OutputItem::Message { .. } => continue,
990                    OutputItem::FunctionCall { .. } => Err(chatbot_err!(StreamingError, "Error: unexpected function call after / during a text response.".to_string()))?,
991                    _ => {
992                        let mut conn = pool.acquire().await?;
993                        process_output_item(&mut conn, item.to_owned(), conversation_id, &app_config).await?;
994                        continue;
995                    },
996                };
997            }
998        }
999
1000        if !done.load(atomic::Ordering::Relaxed) {
1001            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
1002        }
1003    };
1004
1005    // Encapsulate the stream and the guard within GuardedStream. This moves the request guard into the stream and ensures that it is dropped when the stream is dropped.
1006    // This way we do cleanup only when the stream is dropped and not when this function returns.
1007    let guarded_stream = GuardedStream::new(guard, response_stream);
1008
1009    // Box and pin the GuardedStream to satisfy the Unpin requirement
1010    Ok(Box::pin(guarded_stream))
1011}
1012
1013pub async fn send_chat_request_and_parse_stream(
1014    conn: &mut PgConnection,
1015    pool: PgPool,
1016    app_config: &ApplicationConfiguration,
1017    chatbot_configuration_id: Uuid,
1018    conversation_id: Uuid,
1019    message: &str,
1020    user_context: ChatbotUserContext,
1021) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
1022    let (mut chat_request, request_estimated_tokens) =
1023        LLMRequest::build_and_insert_incoming_message_to_db(
1024            conn,
1025            chatbot_configuration_id,
1026            conversation_id,
1027            message,
1028            app_config,
1029        )
1030        .await?;
1031
1032    let mut max_iterations_left = 15;
1033
1034    loop {
1035        max_iterations_left -= 1;
1036        if max_iterations_left == 0 {
1037            error!("Maximum tool call iterations exceeded");
1038            return Err(anyhow::anyhow!(
1039                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
1040            ));
1041        }
1042
1043        let (response_id, response_type) =
1044            make_request_and_stream(conn, chat_request.clone(), conversation_id, app_config)
1045                .await?;
1046
1047        let new_conversation_items = match response_type {
1048            ResponseStreamType::Toolcall(stream) => {
1049                parse_tool(conn, stream, conversation_id, &user_context, app_config).await?
1050            }
1051            ResponseStreamType::TextResponse(stream) => {
1052                return parse_and_stream_to_user(
1053                    conn,
1054                    stream,
1055                    conversation_id,
1056                    pool,
1057                    request_estimated_tokens,
1058                    response_id,
1059                    app_config.to_owned(),
1060                )
1061                .await;
1062            }
1063        };
1064        chat_request.input.extend(
1065            new_conversation_items
1066                .into_iter()
1067                .map(APIInputMessage::try_from)
1068                .collect::<ChatbotResult<Vec<APIInputMessage>>>()?,
1069        );
1070    }
1071}