1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::{
4 Arc,
5 atomic::{self, AtomicBool},
6};
7use std::task::{Context, Poll};
8
9use anyhow::{Error, anyhow};
10use bytes::Bytes;
11use chrono::Utc;
12use futures::stream::{BoxStream, Peekable};
13use futures::{Stream, StreamExt, TryStreamExt};
14use headless_lms_base::config::ApplicationConfiguration;
15use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
16use headless_lms_models::chatbot_conversation_message_messages::{
17 ChatbotConversationMessageMessage, MessageRole,
18};
19use headless_lms_models::chatbot_conversation_messages::{
20 self, ChatbotConversationMessage, Message,
21};
22use pin_project::pin_project;
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use sqlx::PgPool;
26use tokio::{io::AsyncBufReadExt, sync::Mutex};
27use tokio_stream::wrappers::LinesStream;
28use tokio_util::io::StreamReader;
29use tracing::trace;
30use url::Url;
31
32use crate::chatbot_error::ChatbotResult;
33use crate::chatbot_tools::provider_tools::azure_ai_search::get_azure_ai_search_tool_definition;
34use crate::chatbot_tools::{
35 AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
36};
37use crate::citations::chatbot_cited_documents_to_citations;
38use crate::llm_utils::{
39 APIInputMessage, APIOutputMessage, MessageContent, estimate_tokens, get_params_for_model,
40 make_streaming_llm_request,
41};
42
43use crate::prelude::*;
44
/// Separator string `",|||,"`.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used elsewhere to
/// delimit content fields — confirm against callers.
pub const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
46
/// One recognised line of the streamed response: either an `event: ` line or a `data: ` line.
enum ParsedResponseLine {
    /// The text that followed the `event: ` prefix.
    Event(String),
    /// The JSON payload that followed the `data: ` prefix, deserialized into a [`ResponseOutput`].
    Data(ResponseOutput),
}
51
52impl ParsedResponseLine {
53 pub fn parse(input: &str) -> ChatbotResult<Option<Self>> {
54 if input.starts_with("event: ") {
55 let event_type = input.trim_start_matches("event: ").to_string();
56 Ok(Some(ParsedResponseLine::Event(event_type)))
57 } else if input.starts_with("data: ") {
58 let data = input.trim_start_matches("data: ").to_string();
59 let response_output =
60 serde_json::from_str::<ResponseOutput>(&data).map_err(ChatbotError::from)?;
61 Ok(Some(ParsedResponseLine::Data(response_output)))
62 } else {
63 Ok(None)
64 }
65 }
66}
67
/// Identifies the user and the course a chatbot request is made for.
///
/// In this file it is forwarded to `get_chatbot_tool` when tool calls are resolved.
pub struct ChatbotUserContext {
    /// Id of the user talking to the chatbot.
    pub user_id: Uuid,
    /// Id of the course.
    pub course_id: Uuid,
    /// Name of the course.
    pub course_name: String,
}
75
/// Content filter entries grouped by category; every category is optional.
///
/// NOTE(review): each category holds a full [`ContentFilter`], which itself holds a list of
/// `ContentFilterResults`, while [`ContentFilterResult`] is not referenced in the visible code —
/// confirm this nesting is what the API actually returns.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}
84
/// A single content filter entry.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    /// Whether the content was blocked.
    pub blocked: bool,
    /// Which side of the exchange the filter applies to (prompt or completion).
    pub source_type: ContentFilterSource,
    /// Per-category results attached to this filter entry.
    pub content_filter_results: Vec<ContentFilterResults>,
}
/// Outcome of one content filter check.
///
/// NOTE(review): not referenced by any other type in the visible part of this file.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResult {
    /// Whether the content was filtered.
    pub filtered: bool,
    /// Severity label as a free-form string.
    pub severity: String,
}
96
/// Origin of the content a filter was applied to.
///
/// Serialized in `snake_case`, i.e. `"prompt"` and `"completion"`.
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ContentFilterSource {
    Prompt,
    Completion,
}
103
/// The `response` object carried by a streamed data line.
#[derive(Deserialize, Serialize, Debug)]
pub struct Response {
    /// Response id; the streaming code in this file stores it after a `response.created` event.
    pub id: String,
    /// Error text, if the response carries one.
    pub error: Option<String>,
}
110
/// A response that was not completed, with the reason and the content filters attached to it.
///
/// NOTE(review): not referenced in the visible part of this file.
#[derive(Deserialize, Serialize, Debug)]
pub struct IncompleteResponse {
    pub id: String,
    /// Why the response is incomplete.
    pub incomplete_details: IncompleteReason,
    pub content_filters: Vec<ContentFilter>,
}
118
/// Reason attached to an [`IncompleteResponse`].
#[derive(Deserialize, Serialize, Debug)]
pub struct IncompleteReason {
    /// Reason as a free-form string.
    pub reason: String,
}
124
/// Payload of a streamed `data: ` line. Every part is optional; which ones are present depends
/// on the event the line belongs to.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseOutput {
    /// A fragment of streamed text.
    pub delta: Option<String>,
    /// A finished output item.
    pub item: Option<OutputItem>,
    /// The response object (id and optional error).
    pub response: Option<Response>,
}
132
/// An output item received from the LLM API.
///
/// Internally tagged: the variant is selected by the `type` field, whose value is the variant
/// name in `snake_case` (for example `function_call_output`).
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum OutputItem {
    /// A message with a role and content.
    Message {
        response_id: String,
        role: MessageRole,
        content: MessageContent,
    },
    /// Reasoning output with its summary parts.
    Reasoning {
        response_id: String,
        summary: Vec<ReasoningOutput>,
    },
    /// A call to the Azure AI search tool.
    AzureAiSearchCall {
        response_id: String,
        call_id: String,
        arguments: String,
    },
    /// Output of an Azure AI search call; `output` is a JSON string that this file deserializes
    /// as [`AiSearchOutput`].
    AzureAiSearchCallOutput {
        response_id: String,
        call_id: String,
        output: String,
    },
    /// A function (tool) call requested by the model; `arguments` is a JSON string.
    FunctionCall {
        response_id: String,
        call_id: String,
        /// Serialized under the key `name`.
        #[serde(rename = "name")]
        tool_name: String,
        arguments: String,
    },
    /// The output produced for a function call.
    FunctionCallOutput {
        response_id: String,
        call_id: String,
        output: String,
    },
}
172
/// An input item sent to the LLM API.
///
/// Internally tagged by the `type` field with `snake_case` variant names. Mirrors the
/// message and function-call variants of [`OutputItem`] without the `response_id` field.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum InputItem {
    /// A message with a role and content.
    Message {
        role: MessageRole,
        content: MessageContent,
    },
    /// A function (tool) call.
    FunctionCall {
        call_id: String,
        /// Serialized under the key `name`.
        #[serde(rename = "name")]
        tool_name: String,
        arguments: String,
    },
    /// The output produced for a function call.
    FunctionCallOutput {
        call_id: String,
        output: String,
    },
}
192
/// Structure of the `output` JSON string of an [`OutputItem::AzureAiSearchCallOutput`].
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct AiSearchOutput {
    /// URLs from the search output; this file hands them to the citation handling.
    pub get_urls: Vec<Url>,
}
197
/// Value of the request's `tool_choice` field. Serialized in `snake_case` (`"auto"`, `"none"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolChoice {
    Auto,
    None,
}
204
/// Request parameters carried by the [`LLMRequestParams::GPTThinking`] variant.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    /// Reasoning settings; left out of the serialized request when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
}
210
/// Value of the request's `text` field. Both options are left out of the serialized request
/// when `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct RequestTextOptions {
    /// Requested verbosity level.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<VerbosityLevel>,
    /// Requested response format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<LLMRequestResponseFormatParam>,
}
/// Reasoning settings of a request.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Reasoning {
    /// Reasoning effort level.
    pub effort: ReasoningEffortLevel,
    /// Requested summary type. There is no `skip_serializing_if` here, so `None` is serialized
    /// as `null`.
    pub summary: Option<SummaryType>,
}
224
225#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
226#[serde(untagged)]
227pub enum SummaryType {
228 Concise,
229 Detailed,
230 Auto,
231}
232
/// One part of a reasoning summary (see [`OutputItem::Reasoning`]).
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ReasoningOutput {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub output_type: String,
    /// The text of this part.
    pub text: String,
}
239
/// Request parameters carried by the [`LLMRequestParams::GPTNonThinking`] variant.
///
/// Every field is optional and is left out of the serialized request when `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
}
251
/// Request parameters carried by the [`LLMRequestParams::Mistral`] variant.
///
/// NOTE(review): the single `test` flag looks like a placeholder — confirm its purpose.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct MistralParams {
    pub test: bool,
}
257
/// Model-family specific request parameters.
///
/// Untagged: serialized as the fields of the inner struct with no variant marker. [`LLMRequest`]
/// flattens these fields into the request object.
///
/// NOTE(review): on deserialization serde tries the variants in order, and `ThinkingParams` has
/// only an optional field, so it presumably matches any object — confirm whether deserializing
/// this enum is relied on anywhere.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    GPTThinking(ThinkingParams),
    GPTNonThinking(NonThinkingParams),
    Mistral(MistralParams),
}
265
/// Type names used in the response format description and its schema.
///
/// Serialized in `snake_case` (for example `JsonSchema` becomes `"json_schema"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JSONType {
    JsonSchema,
    Object,
    Array,
    String,
}
274
/// Schema part of a [`LLMRequestResponseFormatParam`].
///
/// Field names are serialized in `camelCase`, so `additional_properties` becomes
/// `additionalProperties`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub type_field: JSONType,
    /// Property definitions keyed by property name.
    pub properties: HashMap<String, ArrayProperty>,
    /// Names of the required properties.
    pub required: Vec<String>,
    pub additional_properties: bool,
}
289
/// A property definition inside a [`Schema`], with the type of its items.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayProperty {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub type_field: JSONType,
    pub items: ArrayItem,
}
296
/// Item type of an [`ArrayProperty`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayItem {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub type_field: JSONType,
}
302
/// Response format description, used as the `format` option of [`RequestTextOptions`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMRequestResponseFormatParam {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub format_type: JSONType,
    pub name: String,
    pub schema: Schema,
    pub strict: bool,
}
311
/// Body of a request sent to the LLM API.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMRequest {
    /// Conversation items sent as input.
    pub input: Vec<APIInputMessage>,
    /// Model name.
    pub model: String,
    /// Tool definitions; left out when empty and defaulted to empty when missing.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tools: Vec<AzureLLMToolDefinition>,
    /// Tool choice; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<LLMToolChoice>,
    /// Output token limit; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<i32>,
    /// Text options; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<RequestTextOptions>,
    /// Model-family specific parameters, flattened into this object.
    #[serde(flatten)]
    pub params: LLMRequestParams,
}
327
328impl LLMRequest {
329 pub async fn build_and_insert_incoming_message_to_db(
330 conn: &mut PgConnection,
331 chatbot_configuration_id: Uuid,
332 conversation_id: Uuid,
333 message: &str,
334 app_config: &ApplicationConfiguration,
335 ) -> anyhow::Result<(Self, i32)> {
336 let configuration =
337 models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
338
339 let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
340 conn,
341 chatbot_configuration_id,
342 )
343 .await?;
344
345 let conversation_messages =
346 models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
347 .await?;
348
349 let new_order_number = conversation_messages
350 .iter()
351 .map(|m| m.order_number)
352 .max()
353 .unwrap_or(0)
354 + 1;
355
356 let new_message = models::chatbot_conversation_messages::insert(
357 conn,
358 ChatbotConversationMessage {
359 id: Uuid::new_v4(),
360 order_number: new_order_number,
361 created_at: Utc::now(),
362 updated_at: Utc::now(),
363 deleted_at: None,
364 conversation_id,
365 message: Message::Text(ChatbotConversationMessageMessage {
366 text: message.to_string(),
367 message_role: MessageRole::User,
368 message_is_complete: true,
369 used_tokens: estimate_tokens(message),
370 ..Default::default()
371 }),
372 },
373 )
374 .await?;
375
376 let mut api_chat_messages: Vec<APIInputMessage> = conversation_messages
377 .into_iter()
378 .filter_map(|m| match m.message {
379 Message::Reasoning(..) => None,
380 _ => Some(APIInputMessage::try_from(m)),
381 })
382 .collect::<ChatbotResult<Vec<_>>>()?;
383
384 api_chat_messages.push(new_message.clone().try_into()?);
386
387 api_chat_messages.insert(
388 0,
389 APIInputMessage {
390 message_type: InputItem::Message {
391 role: MessageRole::System,
392 content: MessageContent::Text(configuration.prompt.clone()),
393 },
394 },
395 );
396
397 let mut tools = if configuration.use_tools {
398 get_chatbot_tool_definitions()
399 } else {
400 Vec::new()
401 };
402
403 if configuration.use_azure_search {
404 tools.extend(vec![AzureLLMToolDefinition::Search(
405 get_azure_ai_search_tool_definition(
406 app_config,
407 configuration.course_id,
408 configuration.use_semantic_reranking,
409 )?,
410 )]);
411 };
412
413 let tool_choice = if configuration.use_azure_search || configuration.use_tools {
414 Some(LLMToolChoice::Auto)
415 } else {
416 None
417 };
418
419 let serialized_messages = serde_json::to_string(&api_chat_messages)?;
420 let request_estimated_tokens = estimate_tokens(&serialized_messages);
421
422 let params = get_params_for_model(&model, &configuration);
423
424 Ok((
425 Self {
426 input: api_chat_messages,
427 model: model.model,
428 max_output_tokens: Some(configuration.max_output_tokens),
429 tools,
430 tool_choice,
431 text: Some(RequestTextOptions {
432 verbosity: Some(configuration.verbosity),
433 format: None,
434 }),
435 params,
436 },
437 request_estimated_tokens,
438 ))
439 }
440}
441
/// One chunk of response text sent to the user; serialized to JSON for each streamed delta.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatResponse {
    pub text: String,
}
446
/// A stream bundled with a [`RequestCancelledGuard`].
///
/// The guard is never read; it is stored so that it is dropped together with the stream, which
/// runs the guard's `Drop` cleanup.
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    /// The wrapped stream, structurally pinned.
    #[pin]
    stream: S,
}
454
455impl<S> GuardedStream<S> {
456 fn new(guard: RequestCancelledGuard, stream: S) -> Self {
457 Self { guard, stream }
458 }
459}
460
461impl<S> Stream for GuardedStream<S>
462where
463 S: Stream<Item = anyhow::Result<Bytes>> + Send,
464{
465 type Item = S::Item;
466
467 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
468 let this = self.project();
469 this.stream.poll_next(cx)
470 }
471}
472
/// A pinned, boxed, peekable stream of text lines read from a boxed byte stream.
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
/// The kind of response detected at the start of a stream, together with the remaining lines.
pub enum ResponseStreamType<'a> {
    /// The model started a function (tool) call.
    Toolcall(PeekableLinesStream<'a>),
    /// The model started a text response.
    TextResponse(PeekableLinesStream<'a>),
}
482
/// Cleans up the response message if it is dropped before `done` has been set.
///
/// See the `Drop` implementation: when `done` is still `false`, a background task either
/// deletes the response message (nothing received) or saves the text received so far.
struct RequestCancelledGuard {
    /// Id of the response message to delete or update.
    response_message_id: Uuid,
    /// Text fragments received so far, shared with the streaming code.
    received_string: Arc<Mutex<Vec<String>>>,
    /// Pool used to get a database connection for the cleanup.
    pool: PgPool,
    /// Set to `true` by the streaming code once the response has been fully handled.
    done: Arc<AtomicBool>,
    /// Estimated tokens of the request, added to the estimate of the received text.
    request_estimated_tokens: i32,
}
490
491impl Drop for RequestCancelledGuard {
492 fn drop(&mut self) {
493 if self.done.load(atomic::Ordering::Relaxed) {
494 return;
495 }
496 warn!("Request was not cancelled. Cleaning up.");
497 let response_message_id = self.response_message_id;
498 let received_string = self.received_string.clone();
499 let pool = self.pool.clone();
500 let request_estimated_tokens = self.request_estimated_tokens;
501 tokio::spawn(async move {
502 info!("Verifying the received message has been handled");
503 let mut conn = pool.acquire().await.expect("Could not acquire connection");
504 let full_response_text = received_string.lock().await;
505 if full_response_text.is_empty() {
506 info!("No response received. Deleting the response message");
507 models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
508 .await
509 .expect("Could not delete response message");
510 return;
511 }
512 info!("Response received but not completed. Saving the text received so far.");
513 let full_response_as_string = full_response_text.join("");
514 let estimated_cost = estimate_tokens(&full_response_as_string);
515 info!(
516 "End of chatbot response stream. Estimated cost: {}. Response: {}",
517 estimated_cost, full_response_as_string
518 );
519
520 models::chatbot_conversation_message_messages::update(
522 &mut conn,
523 response_message_id,
524 &full_response_as_string,
525 true,
526 request_estimated_tokens + estimated_cost,
527 )
528 .await
529 .expect("Could not update response message");
530 });
531 }
532}
533
/// Sends the request to the LLM and reads the start of the streamed response until the kind of
/// response is known.
///
/// Lines are peeked before they are consumed. Event lines set flags that decide how a following
/// data line is handled:
/// * `response.created` – the response id is taken from the data line.
/// * `response.output_item.done` – the item in the data line is passed to
///   [`process_output_item`].
/// * `response.error` – a data line carrying an error aborts with a streaming error.
/// * `response.function_call_arguments.delta` – returns [`ResponseStreamType::Toolcall`].
/// * `response.output_text.delta` – returns [`ResponseStreamType::TextResponse`].
///
/// In both returning cases the event line that triggered the return has only been peeked, so it
/// is still the first line of the returned stream.
///
/// Returns the response id and the remaining stream.
///
/// # Errors
///
/// Fails if the request fails or returns a non-success status, if a line cannot be read or
/// parsed, if an expected object is missing from a data line, if the API reports an error, or if
/// the stream ends before a response kind was detected.
pub async fn make_request_and_stream<'a>(
    conn: &mut PgConnection,
    chat_request: LLMRequest,
    conversation_id: Uuid,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<(String, ResponseStreamType<'a>)> {
    let response = make_streaming_llm_request(chat_request, app_config).await?;

    trace!("Receiving chat response with {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    // Turn the byte stream into a peekable stream of lines.
    let stream = response
        .bytes_stream()
        .map_err(std::io::Error::other)
        .boxed();
    let reader = StreamReader::new(stream);
    let lines = reader.lines();
    let lines_stream = LinesStream::new(lines);
    let peekable_lines_stream = lines_stream.peekable();
    let mut pinned_lines = Box::pin(peekable_lines_stream);

    let mut response_id = "".to_string();
    // Flags set by an event line and acted upon when a following data line arrives.
    let mut output_item_incoming = false;
    let mut response_created_incoming = false;
    let mut error_incoming = false;

    loop {
        // Peek so that the line is still in the stream if we return it to the caller.
        let line_res = pinned_lines.as_mut().peek().await;
        match line_res {
            None => {
                break;
            }
            Some(Err(e)) => {
                return Err(anyhow!(
                    "There was an error streaming response from Azure: {}. Response id: {}",
                    e,
                    response_id
                ));
            }
            Some(Result::Ok(line)) => {
                match ParsedResponseLine::parse(line)? {
                    Some(ParsedResponseLine::Event(event_type)) => {
                        trace!("Event: {event_type}");
                        match event_type.as_str() {
                            "response.created" => {
                                response_created_incoming = true;
                            }
                            "response.output_item.done" => {
                                output_item_incoming = true;
                            }
                            "response.function_call_arguments.delta" => {
                                if response_id.is_empty() {
                                    return Err(anyhow::anyhow!(
                                        "No response_id found! This should never happen!"
                                    ));
                                }
                                return Ok((
                                    response_id,
                                    ResponseStreamType::Toolcall(pinned_lines),
                                ));
                            }
                            "response.output_text.delta" => {
                                return Ok((
                                    response_id,
                                    ResponseStreamType::TextResponse(pinned_lines),
                                ));
                            }
                            "response.error" => {
                                error_incoming = true;
                            }
                            _ => {}
                        }
                    }
                    Some(ParsedResponseLine::Data(response_output)) => {
                        if error_incoming
                            && let Some(response) = &response_output.response
                            && let Some(error) = &response.error
                        {
                            Err(chatbot_err!(
                                StreamingError,
                                format!(
                                    "Error received from the API: {}. Response id: {}",
                                    error, response.id
                                )
                            ))?
                        };
                        if response_created_incoming {
                            let res = response_output.response.ok_or(chatbot_err!(
                                DeserializationError,
                                "Expected response object"
                            ))?;
                            response_id = res.id;
                            response_created_incoming = false;
                        }
                        if output_item_incoming {
                            let item = response_output.item.ok_or(chatbot_err!(
                                DeserializationError,
                                "Expected response output item"
                            ))?;
                            process_output_item(conn, item, conversation_id, app_config).await?;
                            output_item_incoming = false;
                        }
                    }
                    None => {}
                }
                // The peeked line has been handled; consume it.
                pinned_lines.next().await;
                continue;
            }
        }
    }
    // The stream ended without a text delta or a function call delta.
    Err(Error::msg(format!(
        "The response received from Azure ended unexpectedly. Response id: {response_id}"
    )))
}
665
666pub async fn process_output_item(
670 conn: &mut PgConnection,
671 item: OutputItem,
672 conversation_id: Uuid,
673 app_config: &ApplicationConfiguration,
674) -> ChatbotResult<ChatbotConversationMessage> {
675 match item {
676 OutputItem::AzureAiSearchCall { .. } | OutputItem::Reasoning { .. } => {
677 let message = APIOutputMessage { message_type: item }
678 .to_chatbot_conversation_message(conversation_id)?;
679
680 ChatbotResult::Ok(chatbot_conversation_messages::insert(conn, message).await?)
681 }
682 OutputItem::AzureAiSearchCallOutput {
683 call_id,
684 output,
685 response_id,
686 } => {
687 let search_output: AiSearchOutput = serde_json::from_str(&output)?;
688 let api_key = if let Some(azure_config) = &app_config.azure_configuration
689 && let Some(search_config) = &azure_config.search_config
690 {
691 &search_config.search_api_key
692 } else {
693 return ChatbotResult::Err(chatbot_err!(
694 Other,
695 "Azure search configuration not found, cannot process Azure AI search output item.".to_string()
696 ));
697 };
698 let get_urls = search_output.get_urls.to_owned();
699
700 let message = APIOutputMessage {
701 message_type: OutputItem::AzureAiSearchCallOutput {
702 call_id,
703 output,
704 response_id,
705 },
706 }
707 .to_chatbot_conversation_message(conversation_id)?;
708
709 let conversation_message = chatbot_conversation_messages::insert(conn, message).await?;
710
711 chatbot_cited_documents_to_citations(
712 conn,
713 app_config.test_chatbot,
714 get_urls,
715 api_key,
716 conversation_message.id,
717 conversation_id,
718 )
719 .await?;
720
721 ChatbotResult::Ok(conversation_message)
722 }
723 OutputItem::Message { .. } => {
724 Err(chatbot_err!(
726 StreamingError,
727 "Unexpected message output item, it should have been streamed.".to_string()
728 ))
729 }
730 OutputItem::FunctionCall { .. } => {
731 Err(chatbot_err!(
733 StreamingError,
734 "Unexpected function call output item, it should have been processed.".to_string()
735 ))
736 }
737 OutputItem::FunctionCallOutput { .. } => {
738 Err(chatbot_err!(
742 StreamingError,
743 "Unexpected function call output item, this shouldn't happen.".to_string()
744 ))
745 }
746 }
747}
748
/// Reads a tool call response from the stream, runs the requested tools and stores the results.
///
/// Event lines only set flags:
/// * `response.completed` marks that the response is finished,
/// * `response.output_text.delta` is an error here,
/// * `response.error` marks that an error payload may follow.
///
/// Data lines are handled as follows:
/// * While the response is not finished, function call items are collected (name, call id and
///   parsed arguments), message items are an error, and other items are passed to
///   [`process_output_item`] and added to the returned messages.
/// * On the first data line after `response.completed`, each collected function call is
///   resolved with `get_chatbot_tool`; a function call message and a function call output
///   message are created for it, inserted into the database and added to the returned messages.
///   Parsing then stops.
///
/// If the stream ends before `response.completed` was seen, the messages collected so far are
/// returned and the collected function calls are not run.
///
/// # Errors
///
/// Fails if a line cannot be read or parsed, if the API reports an error, if text is received,
/// if the finished response contains no function calls or no response id, or if running a tool
/// or inserting a message fails.
pub async fn parse_tool<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    conversation_id: Uuid,
    user_context: &ChatbotUserContext,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Vec<APIOutputMessage>> {
    // (tool name, call id, parsed arguments) of every function call seen so far.
    let mut function_name_id_args: Vec<(String, String, Value)> = vec![];
    let mut messages = vec![];
    // Response id taken from the most recent function call item.
    let mut common_response_id = "".to_string();
    let mut response_received = false;
    let mut error_incoming = false;

    trace!("Parsing tool calls...");

    while let Some(val) = lines.next().await {
        let line = val?;
        let response_output = match ParsedResponseLine::parse(&line)? {
            Some(ParsedResponseLine::Event(event_type)) => {
                match event_type.as_str() {
                    "response.completed" => {
                        response_received = true;
                    }
                    "response.output_text.delta" => {
                        return Err(anyhow::anyhow!(
                            "Error: Received response text while parsing tool calls. Either the tool call parsing failed or the LLM responded in an unexpected way."
                        ));
                    }
                    "response.error" => {
                        error_incoming = true;
                    }
                    _ => {}
                };
                continue;
            }
            Some(ParsedResponseLine::Data(data)) => data,
            None => {
                continue;
            }
        };

        if error_incoming
            && let Some(response) = &response_output.response
            && let Some(error) = &response.error
        {
            Err(chatbot_err!(
                StreamingError,
                format!("Error received from the API: {}.", error)
            ))?
        };

        if response_received {
            if function_name_id_args.is_empty() {
                return Err(anyhow::anyhow!(
                    "The LLM response was supposed to contain function calls, but no function calls were found"
                ));
            }
            if common_response_id.is_empty() {
                return Err(anyhow::anyhow!(
                    "Received tool response but response id not found, this shouldn't happen."
                ));
            };
            let mut tool_msgs = Vec::new();

            // Run every collected tool call and pair each call with its output.
            for (name, id, args) in function_name_id_args.iter() {
                let tool = get_chatbot_tool(conn, name, args, user_context).await?;

                tool_msgs.push(APIOutputMessage {
                    message_type: OutputItem::FunctionCall {
                        response_id: (common_response_id).to_owned(),
                        call_id: id.to_owned(),
                        tool_name: name.to_owned(),
                        arguments: serde_json::to_string(tool.get_arguments())?,
                    },
                });
                tool_msgs.push(APIOutputMessage {
                    message_type: OutputItem::FunctionCallOutput {
                        call_id: id.to_owned(),
                        output: tool.get_tool_output(),
                        response_id: (common_response_id).to_owned(),
                    },
                });
            }
            // Store the calls and outputs only after all tools have been run.
            for m in &tool_msgs {
                chatbot_conversation_messages::insert(
                    conn,
                    m.to_chatbot_conversation_message(conversation_id)?,
                )
                .await?;
            }
            messages.extend(tool_msgs);
            break;
        } else if let Some(item) = response_output.item {
            match item {
                OutputItem::FunctionCall {
                    call_id,
                    tool_name,
                    arguments,
                    response_id,
                } => {
                    common_response_id = response_id;
                    function_name_id_args.push((
                        tool_name,
                        call_id,
                        serde_json::from_str::<Value>(&arguments)?,
                    ));
                }
                OutputItem::Message { .. } => Err(chatbot_err!(
                    StreamingError,
                    "Error: unexpected message item !!!".to_string()
                ))?,
                _ => {
                    process_output_item(conn, item.clone(), conversation_id, app_config).await?;
                    messages.push(APIOutputMessage { message_type: item });
                }
            }
        }
    }
    Ok(messages)
}
876
/// Creates the assistant's response message and returns a stream that forwards the response
/// text to the user while collecting it.
///
/// Before the stream is returned, an empty, incomplete assistant message is inserted and
/// `update_citation_message_ids` is called with the response id and the id of that message.
///
/// The returned stream, for every data line:
/// * after a `response.completed` or `response.incomplete` event: joins the collected text,
///   marks the response as done, updates the response message with the full text and the token
///   estimate, and ends;
/// * after a `response.error` event, when the payload carries an error: fails with a streaming
///   error;
/// * with a `delta`: stores the fragment and yields it as a JSON [`ChatResponse`] followed by a
///   newline;
/// * with an `item`: ignores message items, fails on function call items, and passes any other
///   item to [`process_output_item`] using a connection from the pool.
///
/// If the lines end before the response was marked as done, the stream fails with
/// "Stream ended unexpectedly". The stream is wrapped together with a `RequestCancelledGuard`,
/// whose `Drop` implementation performs cleanup when the stream is dropped before being done.
pub async fn parse_and_stream_to_user<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    conversation_id: Uuid,
    pool: PgPool,
    request_estimated_tokens: i32,
    response_id: String,
    app_config: ApplicationConfiguration,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
    // Placeholder message that is filled in when the stream finishes (or by the guard).
    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            conversation_id,
            message: Message::Text(ChatbotConversationMessageMessage {
                text: "".to_string(),
                message_role: MessageRole::Assistant,
                message_is_complete: false,
                used_tokens: request_estimated_tokens,
                response_id: Some(response_id.to_owned()),
                ..Default::default()
            }),
            ..Default::default()
        },
    )
    .await?;
    models::chatbot_conversation_messages_citations::update_citation_message_ids(
        conn,
        response_id,
        response_message.id,
    )
    .await?;

    // `done` and the collected text are shared between the stream and the guard.
    let done = Arc::new(AtomicBool::new(false));
    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    trace!("Parsing stream to user...");

    let mut response_received = false;
    let mut error_incoming = false;

    let response_stream = async_stream::try_stream! {
        while let Some(val) = lines.next().await {
            let line = val?;
            let response_output: ResponseOutput = match ParsedResponseLine::parse(&line)? {
                Some(ParsedResponseLine::Event(event_type)) => {
                    match event_type.as_str() {
                        "response.completed" | "response.incomplete" => {response_received = true;},
                        "response.output_text.delta" => {
                        },
                        "response.function_call_arguments.delta" => {
                            error!("ERROR, function call received but can't be processed while streaming to user.");
                            return Err(chatbot_err!(StreamingError, format!("Unexpected function call while streaming to user")))?
                        },
                        "response.error" => {error_incoming = true;},
                        _ => {},
                    };
                    continue;
                },
                Some(ParsedResponseLine::Data(data)) => data,
                None => {continue;},
            };

            // This shadows the `Arc` with the lock guard for the rest of the iteration.
            let mut full_response_text = full_response_text.lock().await;

            if response_received {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                trace!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                // NOTE(review): `done` is set before the update below has succeeded, so the
                // guard will not clean up if acquiring the connection or the update fails.
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }

            if error_incoming &&
                let Some(response) = &response_output.response && let Some(error) = &response.error
            {
                Err(chatbot_err!(StreamingError, format!("Error received from the API: {}.", error)))?

            };

            if let Some(delta) = &response_output.delta {
                full_response_text.push(delta.to_owned());
                let response = ChatResponse { text: delta.clone() };
                let response_as_string = serde_json::to_string(&response)?;
                yield Bytes::from(response_as_string);
                yield Bytes::from("\n");
            }

            if let Some(item) = &response_output.item {
                match item {
                    OutputItem::Message { .. } => continue,
                    OutputItem::FunctionCall { .. } => Err(chatbot_err!(StreamingError, "Error: unexpected function call after / during a text response.".to_string()))?,
                    _ => {
                        let mut conn = pool.acquire().await?;
                        process_output_item(&mut conn, item.to_owned(), conversation_id, &app_config).await?;
                        continue;
                    },
                };
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    let guarded_stream = GuardedStream::new(guard, response_stream);

    Ok(Box::pin(guarded_stream))
}
1012
1013pub async fn send_chat_request_and_parse_stream(
1014 conn: &mut PgConnection,
1015 pool: PgPool,
1016 app_config: &ApplicationConfiguration,
1017 chatbot_configuration_id: Uuid,
1018 conversation_id: Uuid,
1019 message: &str,
1020 user_context: ChatbotUserContext,
1021) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
1022 let (mut chat_request, request_estimated_tokens) =
1023 LLMRequest::build_and_insert_incoming_message_to_db(
1024 conn,
1025 chatbot_configuration_id,
1026 conversation_id,
1027 message,
1028 app_config,
1029 )
1030 .await?;
1031
1032 let mut max_iterations_left = 15;
1033
1034 loop {
1035 max_iterations_left -= 1;
1036 if max_iterations_left == 0 {
1037 error!("Maximum tool call iterations exceeded");
1038 return Err(anyhow::anyhow!(
1039 "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
1040 ));
1041 }
1042
1043 let (response_id, response_type) =
1044 make_request_and_stream(conn, chat_request.clone(), conversation_id, app_config)
1045 .await?;
1046
1047 let new_conversation_items = match response_type {
1048 ResponseStreamType::Toolcall(stream) => {
1049 parse_tool(conn, stream, conversation_id, &user_context, app_config).await?
1050 }
1051 ResponseStreamType::TextResponse(stream) => {
1052 return parse_and_stream_to_user(
1053 conn,
1054 stream,
1055 conversation_id,
1056 pool,
1057 request_estimated_tokens,
1058 response_id,
1059 app_config.to_owned(),
1060 )
1061 .await;
1062 }
1063 };
1064 chat_request.input.extend(
1065 new_conversation_items
1066 .into_iter()
1067 .map(APIInputMessage::try_from)
1068 .collect::<ChatbotResult<Vec<APIInputMessage>>>()?,
1069 );
1070 }
1071}