1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::sync::{
5 Arc,
6 atomic::{self, AtomicBool},
7};
8use std::task::{Context, Poll};
9
10use anyhow::{Error, Ok, anyhow};
11use bytes::Bytes;
12use chrono::Utc;
13use futures::stream::{BoxStream, Peekable};
14use futures::{Stream, StreamExt, TryStreamExt};
15use headless_lms_base::config::ApplicationConfiguration;
16use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
17use headless_lms_models::chatbot_conversation_messages::{
18 self, ChatbotConversationMessage, MessageRole,
19};
20use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
21use pin_project::pin_project;
22use serde::{Deserialize, Serialize};
23use sqlx::PgPool;
24use tokio::{io::AsyncBufReadExt, sync::Mutex};
25use tokio_stream::wrappers::LinesStream;
26use tokio_util::io::StreamReader;
27use tracing::trace;
28use url::Url;
29
30use crate::chatbot_error::ChatbotResult;
31use crate::chatbot_tools::{
32 AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
33};
34use crate::llm_utils::{
35 APIMessage, APIMessageKind, APIMessageText, APIMessageToolCall, APIMessageToolResponse,
36 APITool, APIToolCall, estimate_tokens, make_streaming_llm_request,
37};
38use headless_lms_utils::strings::truncate_utf8_at_boundary;
39use headless_lms_utils::url_encoding::url_decode;
40
41use crate::prelude::*;
42use crate::search_filter::SearchFilter;
43
44const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
45
/// Identifies the user and course a chatbot interaction belongs to.
///
/// Passed to tool implementations (see [`parse_tool`]) so they can scope
/// their behavior to the correct user and course.
pub struct ChatbotUserContext {
    /// Id of the user chatting with the bot.
    pub user_id: Uuid,
    /// Id of the course the conversation belongs to.
    pub course_id: Uuid,
    /// Human-readable name of the course.
    pub course_name: String,
}
53
/// Per-category content-moderation results attached to a streamed choice.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

/// Result of a single content-moderation category check.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    /// True when content was flagged by this category.
    pub filtered: bool,
    /// Severity level reported by the service.
    pub severity: String,
}
67
/// One choice (candidate completion) inside a streamed response chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    /// Content-moderation results for this chunk, when present.
    pub content_filter_results: Option<ContentFilterResults>,
    /// The incremental payload carried by this chunk.
    pub delta: Option<Delta>,
    /// Set on a choice's final chunk; the value "tool_calls" indicates the
    /// model finished emitting a tool call (see [`parse_tool`]).
    pub finish_reason: Option<String>,
    /// Index of this choice within the response.
    pub index: i32,
}

/// Incremental piece of the model's output carried by one chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    /// Next fragment of the assistant's text, when streaming text.
    pub content: Option<String>,
    /// Extra context such as search citations, when present.
    pub context: Option<DeltaContext>,
    /// Incremental tool-call fragments, when the model is calling tools.
    pub tool_calls: Option<Vec<ToolCallInDelta>>,
}

/// Context attached to a delta; carries the document citations.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

/// Fragment of a tool call spread across streamed chunks.
///
/// NOTE(review): `id` and `function.name` appear to be present only on the
/// fragment that starts a call (this is what [`parse_tool`] relies on) —
/// confirm against the Azure streaming format.
#[derive(Deserialize, Serialize, Debug)]
pub struct ToolCallInDelta {
    pub id: Option<String>,
    pub function: DeltaTool,
    #[serde(rename = "type")]
    pub tool_type: Option<ToolCallType>,
}

/// Function name and incremental argument fragment of a streamed tool call.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DeltaTool {
    /// Fragment of the JSON-encoded arguments; empty string when absent.
    #[serde(default)]
    pub arguments: String,
    pub name: Option<String>,
}

/// Kind of tool call; only plain function calls are supported.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
    Function,
}

/// A document citation returned alongside a text delta.
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

/// One server-sent chunk of a streamed chat completion response.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    /// Unix timestamp of when the chunk was created.
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}
131
/// How the model may decide to use tools; only automatic selection is used.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolChoice {
    Auto,
}

/// Request parameters used for reasoning ("thinking") models.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
    /// Tool definitions offered to the model; omitted from the serialized
    /// payload when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<AzureLLMToolDefinition>,
    pub tool_choice: Option<LLMToolChoice>,
}

/// Request parameters used for conventional (non-reasoning) models.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

/// Parameter set attached to a request. Untagged, so the active variant's
/// fields are serialized directly into the request body (combined with
/// `#[serde(flatten)]` on [`LLMRequest::params`]).
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
}
163
/// JSON type names used in schema definitions; serialized in snake_case
/// (e.g. `json_schema`, `object`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JSONType {
    JsonSchema,
    Object,
    Array,
    String,
}

/// A named JSON schema the model's response must conform to.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct JSONSchema {
    pub name: String,
    /// When true, the model must follow the schema exactly.
    pub strict: bool,
    pub schema: Schema,
}

/// Body of a JSON schema describing an object and its properties.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
    #[serde(rename = "type")]
    pub type_field: JSONType,
    pub properties: HashMap<String, ArrayProperty>,
    pub required: Vec<String>,
    /// Serialized as `additionalProperties` due to the camelCase rename.
    pub additional_properties: bool,
}

/// Schema of an array-valued property.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayProperty {
    #[serde(rename = "type")]
    pub type_field: JSONType,
    pub items: ArrayItem,
}

/// Schema of the items contained in an array property.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayItem {
    #[serde(rename = "type")]
    pub type_field: JSONType,
}
208
209#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
210pub struct LLMRequestResponseFormatParam {
211 #[serde(rename = "type")]
212 pub format_type: JSONType, pub json_schema: JSONSchema,
214}
215
/// Body of a chat completion request sent to the LLM.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMRequest {
    /// Conversation so far, including the system prompt as the first entry.
    pub messages: Vec<APIMessage>,
    /// Azure "on your data" sources; omitted from the payload when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub data_sources: Vec<DataSource>,
    /// Model parameters, flattened directly into the request body.
    #[serde(flatten)]
    pub params: LLMRequestParams,
    /// Optional structured-output constraint; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<LLMRequestResponseFormatParam>,
    /// Optional stop sequence.
    pub stop: Option<String>,
}
227
228impl LLMRequest {
229 pub async fn build_and_insert_incoming_message_to_db(
230 conn: &mut PgConnection,
231 chatbot_configuration_id: Uuid,
232 conversation_id: Uuid,
233 message: &str,
234 app_config: &ApplicationConfiguration,
235 ) -> anyhow::Result<(Self, i32, i32)> {
236 let index_name = Url::parse(&app_config.base_url)?
237 .host_str()
238 .expect("BASE_URL must have a host")
239 .replace(".", "-");
240
241 let configuration =
242 models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
243
244 let conversation_messages =
245 models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
246 .await?;
247
248 let new_order_number = conversation_messages
249 .iter()
250 .map(|m| m.order_number)
251 .max()
252 .unwrap_or(0)
253 + 1;
254
255 let new_message = models::chatbot_conversation_messages::insert(
256 conn,
257 ChatbotConversationMessage {
258 id: Uuid::new_v4(),
259 created_at: Utc::now(),
260 updated_at: Utc::now(),
261 deleted_at: None,
262 conversation_id,
263 message: Some(message.to_string()),
264 message_role: MessageRole::User,
265 message_is_complete: true,
266 used_tokens: estimate_tokens(message),
267 order_number: new_order_number,
268 tool_call_fields: vec![],
269 tool_output: None,
270 },
271 )
272 .await?;
273
274 let mut api_chat_messages: Vec<APIMessage> = conversation_messages
275 .into_iter()
276 .map(APIMessage::try_from)
277 .collect::<ChatbotResult<Vec<_>>>()?;
278
279 api_chat_messages.push(new_message.clone().try_into()?);
281
282 api_chat_messages.insert(
283 0,
284 APIMessage {
285 role: MessageRole::System,
286 fields: APIMessageKind::Text(APIMessageText {
287 content: configuration.prompt.clone(),
288 }),
289 },
290 );
291
292 let data_sources = if configuration.use_azure_search {
293 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
294 anyhow::anyhow!("Azure configuration is missing from the application configuration")
295 })?;
296
297 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
298 anyhow::anyhow!(
299 "Azure search configuration is missing from the Azure configuration"
300 )
301 })?;
302
303 let query_type = if configuration.use_semantic_reranking {
304 "vector_semantic_hybrid"
305 } else {
306 "vector_simple_hybrid"
307 };
308
309 api_chat_messages = api_chat_messages
314 .into_iter()
315 .filter(|m| !matches!(m.fields, APIMessageKind::ToolCall(_)))
316 .map(|m| match m.fields {
317 APIMessageKind::ToolResponse(r) => APIMessage {
318 role: MessageRole::Assistant,
319 fields: APIMessageKind::Text(APIMessageText { content: r.content }),
320 },
321 _ => m,
322 })
323 .collect();
324
325 vec![DataSource {
326 data_type: "azure_search".to_string(),
327 parameters: DataSourceParameters {
328 endpoint: search_config.search_endpoint.to_string(),
329 authentication: DataSourceParametersAuthentication {
330 auth_type: "api_key".to_string(),
331 key: search_config.search_api_key.clone(),
332 },
333 semantic_configuration: format!("{}-semantic-configuration", &index_name),
334 index_name,
335 query_type: query_type.to_string(),
336 embedding_dependency: EmbeddingDependency {
337 dep_type: "deployment_name".to_string(),
338 deployment_name: search_config.vectorizer_deployment_id.clone(),
339 },
340 in_scope: false,
341 top_n_documents: 15,
342 strictness: 3,
343 filter: Some(
344 SearchFilter::eq("course_id", configuration.course_id.to_string())
345 .to_odata()?,
346 ),
347 fields_mapping: FieldsMapping {
348 content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
349 content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
350 filepath_field: "filepath".to_string(),
351 title_field: "title".to_string(),
352 url_field: "url".to_string(),
353 vector_fields: vec!["text_vector".to_string()],
354 },
355 },
356 }]
357 } else {
358 Vec::new()
359 };
360
361 let tools = if configuration.use_tools {
362 get_chatbot_tool_definitions()
363 } else {
364 Vec::new()
365 };
366
367 let serialized_messages = serde_json::to_string(&api_chat_messages)?;
368 let request_estimated_tokens = estimate_tokens(&serialized_messages);
369
370 let params = if configuration.thinking_model {
371 LLMRequestParams::Thinking(ThinkingParams {
372 max_completion_tokens: Some(configuration.max_completion_tokens),
373 reasoning_effort: Some(configuration.reasoning_effort),
374 verbosity: Some(configuration.verbosity),
375 tools,
376 tool_choice: if configuration.use_tools {
377 Some(LLMToolChoice::Auto)
378 } else {
379 None
380 },
381 })
382 } else {
383 LLMRequestParams::NonThinking(NonThinkingParams {
384 max_tokens: Some(configuration.response_max_tokens),
385 temperature: Some(configuration.temperature),
386 top_p: Some(configuration.top_p),
387 frequency_penalty: Some(configuration.frequency_penalty),
388 presence_penalty: Some(configuration.presence_penalty),
389 })
390 };
391
392 Ok((
393 Self {
394 messages: api_chat_messages,
395 data_sources,
396 params,
397 response_format: None,
398 stop: None,
399 },
400 new_message.order_number,
401 request_estimated_tokens,
402 ))
403 }
404
405 pub async fn update_messages_to_db(
406 mut self,
407 conn: &mut PgConnection,
408 new_msgs: Vec<APIMessage>,
409 conversation_id: Uuid,
410 mut order_number: i32,
411 ) -> anyhow::Result<(Self, i32)> {
412 for m in new_msgs {
413 let converted_msg = m.to_chatbot_conversation_message(conversation_id, order_number)?;
414 chatbot_conversation_messages::insert(conn, converted_msg).await?;
415 self.messages.push(m);
416 order_number += 1;
417 }
418 Ok((self, order_number))
419 }
420}
421
/// An Azure "on your data" data source attached to a request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSource {
    /// Serialized as `type`; set to "azure_search" in this module.
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

/// Parameters of an Azure search data source.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParameters {
    /// Azure search service endpoint URL.
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    /// Search mode, e.g. "vector_semantic_hybrid" or "vector_simple_hybrid".
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    // NOTE(review): presumably restricts answers to indexed content when
    // true — confirm against the Azure "on your data" documentation.
    pub in_scope: bool,
    /// Number of documents retrieved for grounding.
    pub top_n_documents: i32,
    pub strictness: i32,
    /// Optional OData filter; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

/// Authentication parameters for the search data source.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParametersAuthentication {
    /// Serialized as `type`; set to "api_key" in this module.
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

/// Embedding model used for vectorizing search queries.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmbeddingDependency {
    /// Serialized as `type`; set to "deployment_name" in this module.
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

/// Maps search index fields to the roles the chat service expects.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FieldsMapping {
    /// Separator used when joining multiple content fields
    /// (see [`CONTENT_FIELD_SEPARATOR`]).
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}
468
/// A single chunk of chatbot text sent to the client; serialized as one JSON
/// object per line in the response stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}
473
/// A stream that owns a [`RequestCancelledGuard`] so that dropping the
/// stream (e.g. when the client disconnects) also drops the guard, which
/// then persists any partially received response.
#[pin_project]
struct GuardedStream<S> {
    // Held only for its Drop behavior; never accessed while streaming.
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}
481
482impl<S> GuardedStream<S> {
483 fn new(guard: RequestCancelledGuard, stream: S) -> Self {
484 Self { guard, stream }
485 }
486}
487
/// Forwards polling to the inner stream. The guard plays no role here; it
/// exists only so that dropping the stream also drops the guard.
impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Pin-project to poll the inner stream without moving it.
        let this = self.project();
        this.stream.poll_next(cx)
    }
}
499
/// A pinned, peekable stream over the text lines of the LLM's streamed
/// (server-sent events) response body.
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
/// What kind of response the stream turned out to contain, decided by
/// peeking at the first decisive chunk in [`make_request_and_stream`].
pub enum ResponseStreamType<'a> {
    /// The model is calling tools; handled by [`parse_tool`].
    Toolcall(PeekableLinesStream<'a>),
    /// The model is answering with text; handled by [`parse_and_stream_to_user`].
    TextResponse(PeekableLinesStream<'a>),
}
509
/// Guard that persists a partially received chatbot response if the stream
/// is dropped before completing (see the `Drop` impl below).
struct RequestCancelledGuard {
    // Id of the placeholder response message row to update or delete.
    response_message_id: Uuid,
    // Text fragments received so far, shared with the response stream.
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    // Set to true when the stream finished normally; disables the cleanup.
    done: Arc<AtomicBool>,
    // Token estimate of the request, added to the saved token count.
    request_estimated_tokens: i32,
}
517
518impl Drop for RequestCancelledGuard {
519 fn drop(&mut self) {
520 if self.done.load(atomic::Ordering::Relaxed) {
521 return;
522 }
523 warn!("Request was not cancelled. Cleaning up.");
524 let response_message_id = self.response_message_id;
525 let received_string = self.received_string.clone();
526 let pool = self.pool.clone();
527 let request_estimated_tokens = self.request_estimated_tokens;
528 tokio::spawn(async move {
529 info!("Verifying the received message has been handled");
530 let mut conn = pool.acquire().await.expect("Could not acquire connection");
531 let full_response_text = received_string.lock().await;
532 if full_response_text.is_empty() {
533 info!("No response received. Deleting the response message");
534 models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
535 .await
536 .expect("Could not delete response message");
537 return;
538 }
539 info!("Response received but not completed. Saving the text received so far.");
540 let full_response_as_string = full_response_text.join("");
541 let estimated_cost = estimate_tokens(&full_response_as_string);
542 info!(
543 "End of chatbot response stream. Estimated cost: {}. Response: {}",
544 estimated_cost, full_response_as_string
545 );
546
547 models::chatbot_conversation_messages::update(
549 &mut conn,
550 response_message_id,
551 &full_response_as_string,
552 true,
553 request_estimated_tokens + estimated_cost,
554 )
555 .await
556 .expect("Could not update response message");
557 });
558 }
559}
560
561pub async fn make_request_and_stream<'a>(
562 chat_request: LLMRequest,
563 model_name: &str,
564 app_config: &ApplicationConfiguration,
565) -> anyhow::Result<ResponseStreamType<'a>> {
566 let response = make_streaming_llm_request(chat_request, model_name, app_config).await?;
567
568 trace!("Receiving chat response with {:?}", response.version());
569
570 if !response.status().is_success() {
571 let status = response.status();
572 let error_message = response.text().await?;
573 return Err(anyhow::anyhow!(
574 "Failed to send chat request. Status: {}. Error: {}",
575 status,
576 error_message
577 ));
578 }
579
580 let stream = response
581 .bytes_stream()
582 .map_err(std::io::Error::other)
583 .boxed();
584 let reader = StreamReader::new(stream);
585 let lines = reader.lines();
586 let lines_stream = LinesStream::new(lines);
587 let peekable_lines_stream = lines_stream.peekable();
588 let mut pinned_lines = Box::pin(peekable_lines_stream);
589
590 loop {
591 let line_res = pinned_lines.as_mut().peek().await;
592 match line_res {
593 None => {
594 break;
595 }
596 Some(Err(e)) => {
597 return Err(anyhow!(
598 "There was an error streaming response from Azure: {}",
599 e
600 ));
601 }
602 Some(Result::Ok(line)) => {
603 if !line.starts_with("data: ") {
604 pinned_lines.next().await;
605 continue;
606 }
607 let json_str = line.trim_start_matches("data: ");
608 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
609 .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;
610 for choice in &response_chunk.choices {
611 if let Some(d) = &choice.delta {
612 if d.content.is_some() || d.context.is_some() {
613 return Ok(ResponseStreamType::TextResponse(pinned_lines));
614 } else if let Some(_calls) = &d.tool_calls {
615 return Ok(ResponseStreamType::Toolcall(pinned_lines));
616 } else if d.content.is_none() {
617 pinned_lines.next().await;
618 continue;
619 }
620 }
621 }
622 pinned_lines.next().await;
623 }
624 }
625 }
626 Err(Error::msg(
627 "The response received from Azure had an unexpected shape and couldn't be parsed"
628 .to_string(),
629 ))
630}
631
632pub async fn parse_tool<'a>(
635 conn: &mut PgConnection,
636 mut lines: PeekableLinesStream<'a>,
637 user_context: &ChatbotUserContext,
638) -> anyhow::Result<Vec<APIMessage>> {
639 let mut function_name_id_args: Vec<(String, String, String)> = vec![];
640 let mut currently_streamed_function_name_id: Option<(String, String)> = None;
641 let mut currently_streamed_function_args = vec![];
642 let mut messages = vec![];
643
644 trace!("Parsing tool calls...");
645
646 while let Some(val) = lines.next().await {
647 let line = val?;
648 if !line.to_owned().starts_with("data: ") {
649 continue;
650 }
651 let json_str = line.trim_start_matches("data: ");
652 if json_str.trim() == "[DONE]" {
653 if function_name_id_args.is_empty() {
655 return Err(anyhow::anyhow!(
656 "The LLM response was supposed to contain function calls, but no function calls were found"
657 ));
658 }
659 let mut assistant_tool_calls = Vec::new();
660 let mut tool_result_msgs = Vec::new();
661
662 for (name, id, args) in function_name_id_args.iter() {
663 let tool = get_chatbot_tool(conn, name, args, user_context).await?;
664
665 assistant_tool_calls.push(APIToolCall {
666 function: APITool {
667 name: name.to_owned(),
668 arguments: serde_json::to_string(tool.get_arguments())?,
669 },
670 id: id.to_owned(),
671 tool_type: ToolCallType::Function,
672 });
673 tool_result_msgs.push(APIMessage {
674 role: MessageRole::Tool,
675 fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
676 content: tool.get_tool_output(),
677 name: name.to_owned(),
678 tool_call_id: id.to_owned(),
679 }),
680 })
681 }
682 messages.push(APIMessage {
684 role: MessageRole::Assistant,
685 fields: APIMessageKind::ToolCall(APIMessageToolCall {
686 tool_calls: assistant_tool_calls,
687 }),
688 });
689 messages.extend(tool_result_msgs);
691 break;
692 }
693 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
694 .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {} {}", e, json_str))?;
695 for choice in &response_chunk.choices {
696 if Some("tool_calls".to_string()) == choice.finish_reason {
697 if let Some((name, id)) = ¤tly_streamed_function_name_id {
700 let fn_args = currently_streamed_function_args.join("");
703 function_name_id_args.push((
704 name.to_owned(),
705 id.to_owned(),
706 fn_args.to_owned(),
707 ));
708 currently_streamed_function_args.clear();
709 currently_streamed_function_name_id = None;
710 }
713 }
714 if let Some(delta) = &choice.delta
715 && let Some(tool_calls) = &delta.tool_calls
716 {
717 for call in tool_calls {
719 if let (Some(name), Some(id)) = (&call.function.name, &call.id) {
720 if let Some((name_prev, id_prev)) = currently_streamed_function_name_id {
725 let fn_args = currently_streamed_function_args.join("");
726 function_name_id_args.push((
727 name_prev.to_owned(),
728 id_prev.to_owned(),
729 fn_args,
730 ));
731 currently_streamed_function_args.clear();
732 }
733 currently_streamed_function_name_id =
737 Some((name.to_owned(), id.to_owned()));
738 };
739 currently_streamed_function_args.push(call.function.arguments.clone());
742 }
743 }
744 }
745 }
746 Ok(messages)
747}
748
/// Inserts a placeholder assistant message and returns a byte stream that
/// relays the LLM's text response to the client as newline-delimited JSON
/// ([`ChatResponse`]) objects, persisting the full text and any citations.
///
/// The stream is wrapped in a [`GuardedStream`]; if it is dropped before the
/// `[DONE]` marker is seen (e.g. the client disconnected), the guard's
/// `Drop` impl saves whatever was received so far.
pub async fn parse_and_stream_to_user<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    conversation_id: Uuid,
    response_order_number: i32,
    pool: PgPool,
    request_estimated_tokens: i32,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
    // Placeholder row; its text is filled in and it is marked complete once
    // the whole response has been received.
    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: Some("".to_string()),
            message_role: MessageRole::Assistant,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
            tool_call_fields: vec![],
            tool_output: None,
        },
    )
    .await?;

    let done = Arc::new(AtomicBool::new(false));
    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    trace!("Parsing stream to user...");

    let response_stream = async_stream::try_stream! {
        while let Some(val) = lines.next().await {
            let line = val?;
            // SSE payload lines are prefixed with "data: "; skip the rest.
            if !line.starts_with("data: ") {
                continue;
            }
            let mut full_response_text = full_response_text.lock().await;
            let json_str = line.trim_start_matches("data: ");
            if json_str.trim() == "[DONE]" {
                // End of stream: persist the complete response and mark it
                // done so neither the error path nor the guard fires.
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                trace!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        // Relay the text fragment to the client as one JSON
                        // object followed by a newline.
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        // Persist each citation attached to this chunk.
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            let content = truncate_utf8_at_boundary(&cit.content, 255).to_string();
                            // Citation content is "chunk_context<sep>chunk"
                            // (see FieldsMapping); keep only the chunk part.
                            // Best-effort: a malformed citation is logged and
                            // stored with empty content, not fatal.
                            let split = content.split_once(CONTENT_FIELD_SEPARATOR);
                            if split.is_none() {
                                error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                            }
                            let cleaned_content: String = split.unwrap_or(("","")).1.to_string();

                            let decoded_title = url_decode(&cit.title)?;
                            let decoded_url = url_decode(&cit.url)?;

                            // The filepath stem is expected to be a page id
                            // (UUID); if it parses, use it to look up the
                            // chapter number for the citation.
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id_str = page_path.file_name();
                            let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn, ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: response_message.id,
                                    conversation_id: response_message.conversation_id,
                                    course_material_chapter_number,
                                    title: decoded_title,
                                    content: cleaned_content,
                                    document_url: decoded_url,
                                    // Citations are numbered from 1.
                                    citation_number: (idx+1) as i32,
                                }
                            ).await?;
                        }
                    }
                }
            }
        }

        // The loop ended without seeing "[DONE]": surface it as an error.
        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    // Tie the guard's lifetime to the stream so that dropping the stream
    // triggers the guard's cleanup.
    let guarded_stream = GuardedStream::new(guard, response_stream);

    Ok(Box::pin(guarded_stream))
}
890
/// Top-level entry point for a chatbot message: persists the user's message,
/// sends the request to the LLM, and resolves tool-call rounds in a loop
/// until the model produces a text response, which is then streamed back.
///
/// # Errors
/// Fails if building or sending the request fails, a tool call fails, or the
/// model keeps requesting tools past the iteration cap.
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
    user_context: ChatbotUserContext,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (mut chat_request, new_message_order_number, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let mut next_message_order_number = new_message_order_number + 1;
    // Safety valve against the model requesting tools indefinitely.
    let mut max_iterations_left = 15;

    loop {
        max_iterations_left -= 1;
        if max_iterations_left == 0 {
            error!("Maximum tool call iterations exceeded");
            return Err(anyhow::anyhow!(
                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
            ));
        }

        let response_type =
            make_request_and_stream(chat_request.clone(), &model.deployment_name, app_config)
                .await?;

        let new_tool_msgs = match response_type {
            // Tool calls: execute them and loop again with their results
            // appended to the request.
            ResponseStreamType::Toolcall(stream) => parse_tool(conn, stream, &user_context).await?,
            // Text response: stream it to the user and finish.
            ResponseStreamType::TextResponse(stream) => {
                return parse_and_stream_to_user(
                    conn,
                    stream,
                    conversation_id,
                    next_message_order_number,
                    pool,
                    request_estimated_tokens,
                )
                .await;
            }
        };
        (chat_request, next_message_order_number) = chat_request
            .update_messages_to_db(
                conn,
                new_tool_msgs,
                conversation_id,
                next_message_order_number,
            )
            .await?;
    }
}