use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicBool},
};
use std::task::{Context, Poll};

use anyhow::{Error, Ok, anyhow};
use bytes::Bytes;
use chrono::Utc;
use futures::stream::{BoxStream, Peekable};
use futures::{Stream, StreamExt, TryStreamExt};
use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
use headless_lms_models::chatbot_conversation_messages::{
    self, ChatbotConversationMessage, MessageRole,
};
use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
use headless_lms_utils::ApplicationConfiguration;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::{io::AsyncBufReadExt, sync::Mutex};
use tokio_stream::wrappers::LinesStream;
use tokio_util::io::StreamReader;
use tracing::trace;
use url::Url;

use crate::chatbot_error::ChatbotResult;
use crate::chatbot_tools::{
    AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
};
use crate::llm_utils::{
    APIMessage, APIMessageKind, APIMessageText, APIMessageToolCall, APIMessageToolResponse,
    APITool, APIToolCall, estimate_tokens, make_streaming_llm_request,
};
use headless_lms_utils::url_encoding::url_decode;

use crate::prelude::*;
use crate::search_filter::SearchFilter;

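/// The configured content fields ("chunk_context" and "chunk") arrive joined with this
/// separator in citation content; citation parsing splits on it to keep only the chunk text.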
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";

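/// Identifies the user and course a chatbot request is made on behalf of; passed to chatbot
/// tools so they can scope their actions.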
pub struct ChatbotUserContext {
    pub user_id: Uuid,
    pub course_id: Uuid,
    pub course_name: String,
}

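/// Content-filter annotations that Azure attaches to each streamed choice.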
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}

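/// One completion choice within a streamed response chunk.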
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}

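/// The incremental payload of a streamed choice: new text, citation context, or tool-call
/// fragments.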
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
    pub tool_calls: Option<Vec<ToolCallInDelta>>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}

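/// A fragment of a streamed tool call. A fragment carrying both an id and a function name
/// starts a new call; subsequent fragments only append to the arguments string.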
#[derive(Deserialize, Serialize, Debug)]
pub struct ToolCallInDelta {
    pub id: Option<String>,
    pub function: DeltaTool,
    #[serde(rename = "type")]
    pub tool_type: Option<ToolCallType>,
}

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DeltaTool {
    #[serde(default)]
    pub arguments: String,
    pub name: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
    Function,
}

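/// A search citation delivered in the delta context when an Azure search data source is in use.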
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}

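/// One parsed `data:` chunk of the server-sent event stream.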
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolChoice {
    Auto,
}

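/// Parameters for reasoning ("thinking") models, which take verbosity, reasoning-effort, and
/// tool settings instead of the sampling parameters in `NonThinkingParams`.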
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<AzureLLMToolDefinition>,
    pub tool_choice: Option<LLMToolChoice>,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
}

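/// The request body sent to the LLM: the message history, optional Azure search data sources,
/// and either thinking or non-thinking parameters flattened into the top level.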
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    pub stop: Option<String>,
}

impl LLMRequest {
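    /// Builds an `LLMRequest` from the conversation history and the new user message, and
    /// inserts the user message into the database. Returns the request, the order number of
    /// the inserted message, and the estimated token count of the serialized messages.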
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, i32, i32)> {
        // The search index name is derived from the host part of the base URL.
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .expect("BASE_URL must have a host")
            .replace(".", "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                message_role: MessageRole::User,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
                tool_call_fields: vec![],
                tool_output: None,
            },
        )
        .await?;

        let mut api_chat_messages: Vec<APIMessage> = conversation_messages
            .into_iter()
            .map(APIMessage::try_from)
            .collect::<ChatbotResult<Vec<_>>>()?;

        api_chat_messages.push(new_message.clone().try_into()?);

        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                fields: APIMessageKind::Text(APIMessageText {
                    content: configuration.prompt.clone(),
                }),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            // Requests with Azure search data sources cannot carry tool-call messages:
            // drop them and flatten tool responses into plain assistant text.
            api_chat_messages = api_chat_messages
                .into_iter()
                .filter(|m| !matches!(m.fields, APIMessageKind::ToolCall(_)))
                .map(|m| match m.fields {
                    APIMessageKind::ToolResponse(r) => APIMessage {
                        role: MessageRole::Assistant,
                        fields: APIMessageKind::Text(APIMessageText { content: r.content }),
                    },
                    _ => m,
                })
                .collect();

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let tools = if configuration.use_tools {
            get_chatbot_tool_definitions()
        } else {
            Vec::new()
        };

        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
                tools,
                tool_choice: if configuration.use_tools {
                    Some(LLMToolChoice::Auto)
                } else {
                    None
                },
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                stop: None,
            },
            new_message.order_number,
            request_estimated_tokens,
        ))
    }

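    /// Appends the given messages to the request and persists each one, assigning consecutive
    /// order numbers. Returns the updated request and the next free order number.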
    pub async fn update_messages_to_db(
        mut self,
        conn: &mut PgConnection,
        new_msgs: Vec<APIMessage>,
        conversation_id: Uuid,
        mut order_number: i32,
    ) -> anyhow::Result<(Self, i32)> {
        for m in new_msgs {
            let converted_msg = m.to_chatbot_conversation_message(conversation_id, order_number)?;
            chatbot_conversation_messages::insert(conn, converted_msg).await?;
            self.messages.push(m);
            order_number += 1;
        }
        Ok((self, order_number))
    }
}

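/// An Azure search data source attached to a chat request.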
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}

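/// Wraps a response stream together with a `RequestCancelledGuard` so the guard's cleanup
/// runs if the stream is dropped before completing.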
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}

impl<S> GuardedStream<S> {
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}

impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}

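/// A peekable stream over the lines of the LLM response body. Peeking lets
/// `make_request_and_stream` classify the response without consuming the deciding line.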
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
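/// The two kinds of response the LLM can stream back: tool calls or a text answer. The
/// wrapped stream is handed on to the matching parser.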
pub enum ResponseStreamType<'a> {
    Toolcall(PeekableLinesStream<'a>),
    TextResponse(PeekableLinesStream<'a>),
}

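/// Cleans up the in-progress response message when the stream is dropped before finishing:
/// deletes the placeholder message if nothing was received, otherwise saves the partial text.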
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
    fn drop(&mut self) {
        if self.done.load(atomic::Ordering::Relaxed) {
            return;
        }
        warn!("Request was cancelled. Cleaning up.");
        let response_message_id = self.response_message_id;
        let received_string = self.received_string.clone();
        let pool = self.pool.clone();
        let request_estimated_tokens = self.request_estimated_tokens;
        tokio::spawn(async move {
            info!("Verifying the received message has been handled");
            let mut conn = pool.acquire().await.expect("Could not acquire connection");
            let full_response_text = received_string.lock().await;
            if full_response_text.is_empty() {
                info!("No response received. Deleting the response message");
                models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
                    .await
                    .expect("Could not delete response message");
                return;
            }
            info!("Response received but not completed. Saving the text received so far.");
            let full_response_as_string = full_response_text.join("");
            let estimated_cost = estimate_tokens(&full_response_as_string);
            info!(
                "End of chatbot response stream. Estimated cost: {}. Response: {}",
                estimated_cost, full_response_as_string
            );

            models::chatbot_conversation_messages::update(
                &mut conn,
                response_message_id,
                &full_response_as_string,
                true,
                request_estimated_tokens + estimated_cost,
            )
            .await
            .expect("Could not update response message");
        });
    }
}

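/// Sends the chat request and peeks at the start of the streamed response to decide whether
/// the model is answering with text or calling tools, returning the still-unconsumed stream
/// wrapped in the matching `ResponseStreamType` variant.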
pub async fn make_request_and_stream<'a>(
    chat_request: LLMRequest,
    model_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<ResponseStreamType<'a>> {
    let response = make_streaming_llm_request(chat_request, model_name, app_config).await?;

    trace!("Receiving chat response with {:?}", response.version());

    if !response.status().is_success() {
        let status = response.status();
        let error_message = response.text().await?;
        return Err(anyhow::anyhow!(
            "Failed to send chat request. Status: {}. Error: {}",
            status,
            error_message
        ));
    }

    let stream = response
        .bytes_stream()
        .map_err(std::io::Error::other)
        .boxed();
    let reader = StreamReader::new(stream);
    let lines = reader.lines();
    let lines_stream = LinesStream::new(lines);
    let peekable_lines_stream = lines_stream.peekable();
    let mut pinned_lines = Box::pin(peekable_lines_stream);

    // Peek at incoming lines until a delta reveals whether the model is streaming text or
    // tool calls; undecidable lines are consumed exactly once and skipped.
    loop {
        let line_res = pinned_lines.as_mut().peek().await;
        match line_res {
            None => {
                break;
            }
            Some(Err(e)) => {
                return Err(anyhow!(
                    "There was an error streaming response from Azure: {}",
                    e
                ));
            }
            Some(Result::Ok(line)) => {
                if !line.starts_with("data: ") {
                    pinned_lines.next().await;
                    continue;
                }
                let json_str = line.trim_start_matches("data: ");
                let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
                    .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;
                for choice in &response_chunk.choices {
                    if let Some(d) = &choice.delta {
                        if d.content.is_some() || d.context.is_some() {
                            return Ok(ResponseStreamType::TextResponse(pinned_lines));
                        } else if d.tool_calls.is_some() {
                            return Ok(ResponseStreamType::Toolcall(pinned_lines));
                        }
                    }
                }
                pinned_lines.next().await;
            }
        }
    }
    Err(Error::msg(
        "The response received from Azure had an unexpected shape and couldn't be parsed",
    ))
}

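/// Consumes a tool-call stream, reassembling each call's name, id, and argument fragments.
/// Executes the tools and returns the assistant tool-call message followed by one tool
/// response message per call.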
pub async fn parse_tool<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    user_context: &ChatbotUserContext,
) -> anyhow::Result<Vec<APIMessage>> {
    let mut function_name_id_args: Vec<(String, String, String)> = vec![];
    let mut currently_streamed_function_name_id: Option<(String, String)> = None;
    let mut currently_streamed_function_args = vec![];
    let mut messages = vec![];

    trace!("Parsing tool calls...");

    while let Some(val) = lines.next().await {
        let line = val?;
        if !line.starts_with("data: ") {
            continue;
        }
        let json_str = line.trim_start_matches("data: ");
        if json_str.trim() == "[DONE]" {
            if function_name_id_args.is_empty() {
                return Err(anyhow::anyhow!(
                    "The LLM response was supposed to contain function calls, but no function calls were found"
                ));
            }
            let mut assistant_tool_calls = Vec::new();
            let mut tool_result_msgs = Vec::new();

            for (name, id, args) in function_name_id_args.iter() {
                let tool = get_chatbot_tool(conn, name, args, user_context).await?;

                assistant_tool_calls.push(APIToolCall {
                    function: APITool {
                        name: name.to_owned(),
                        arguments: serde_json::to_string(tool.get_arguments())?,
                    },
                    id: id.to_owned(),
                    tool_type: ToolCallType::Function,
                });
                tool_result_msgs.push(APIMessage {
                    role: MessageRole::Tool,
                    fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
                        content: tool.get_tool_output(),
                        name: name.to_owned(),
                        tool_call_id: id.to_owned(),
                    }),
                })
            }
            messages.push(APIMessage {
                role: MessageRole::Assistant,
                fields: APIMessageKind::ToolCall(APIMessageToolCall {
                    tool_calls: assistant_tool_calls,
                }),
            });
            messages.extend(tool_result_msgs);
            break;
        }
        let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
            .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {} {}", e, json_str))?;
        for choice in &response_chunk.choices {
            if choice.finish_reason.as_deref() == Some("tool_calls") {
                // The tool-call stream is ending; flush the call currently being assembled.
                if let Some((name, id)) = currently_streamed_function_name_id.take() {
                    let fn_args = currently_streamed_function_args.join("");
                    function_name_id_args.push((name, id, fn_args));
                    currently_streamed_function_args.clear();
                }
            }
            if let Some(delta) = &choice.delta
                && let Some(tool_calls) = &delta.tool_calls
            {
                for call in tool_calls {
                    if let (Some(name), Some(id)) = (&call.function.name, &call.id) {
                        // A new call begins; flush the previous one, if any.
                        if let Some((name_prev, id_prev)) =
                            currently_streamed_function_name_id.take()
                        {
                            let fn_args = currently_streamed_function_args.join("");
                            function_name_id_args.push((name_prev, id_prev, fn_args));
                            currently_streamed_function_args.clear();
                        }
                        currently_streamed_function_name_id =
                            Some((name.to_owned(), id.to_owned()));
                    }
                    currently_streamed_function_args.push(call.function.arguments.clone());
                }
            }
        }
    }
    Ok(messages)
}

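/// Streams a text response to the user as newline-delimited JSON chunks while accumulating
/// the full text, persisting citations as they arrive and finalizing the response message in
/// the database when the stream completes. A `RequestCancelledGuard` salvages partial output
/// if the client disconnects.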
pub async fn parse_and_stream_to_user<'a>(
    conn: &mut PgConnection,
    mut lines: PeekableLinesStream<'a>,
    conversation_id: Uuid,
    response_order_number: i32,
    pool: PgPool,
    request_estimated_tokens: i32,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
    let response_message = models::chatbot_conversation_messages::insert(
        conn,
        ChatbotConversationMessage {
            id: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
            conversation_id,
            message: Some("".to_string()),
            message_role: MessageRole::Assistant,
            message_is_complete: false,
            used_tokens: request_estimated_tokens,
            order_number: response_order_number,
            tool_call_fields: vec![],
            tool_output: None,
        },
    )
    .await?;

    let done = Arc::new(AtomicBool::new(false));
    let full_response_text = Arc::new(Mutex::new(Vec::new()));
    let guard = RequestCancelledGuard {
        response_message_id: response_message.id,
        received_string: full_response_text.clone(),
        pool: pool.clone(),
        done: done.clone(),
        request_estimated_tokens,
    };

    trace!("Parsing stream to user...");

    let response_stream = async_stream::try_stream! {
        while let Some(val) = lines.next().await {
            let line = val?;
            if !line.starts_with("data: ") {
                continue;
            }
            let mut full_response_text = full_response_text.lock().await;
            let json_str = line.trim_start_matches("data: ");
            if json_str.trim() == "[DONE]" {
                let full_response_as_string = full_response_text.join("");
                let estimated_cost = estimate_tokens(&full_response_as_string);
                trace!(
                    "End of chatbot response stream. Estimated cost: {}. Response: {}",
                    estimated_cost, full_response_as_string
                );
                done.store(true, atomic::Ordering::Relaxed);
                let mut conn = pool.acquire().await?;
                models::chatbot_conversation_messages::update(
                    &mut conn,
                    response_message.id,
                    &full_response_as_string,
                    true,
                    request_estimated_tokens + estimated_cost,
                ).await?;
                break;
            }
            let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
                anyhow::anyhow!("Failed to parse response chunk: {}", e)
            })?;

            for choice in &response_chunk.choices {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        full_response_text.push(content.clone());
                        let response = ChatResponse { text: content.clone() };
                        let response_as_string = serde_json::to_string(&response)?;
                        yield Bytes::from(response_as_string);
                        yield Bytes::from("\n");
                    }
                    if let Some(context) = &delta.context {
                        let mut conn = pool.acquire().await?;
                        for (idx, cit) in context.citations.iter().enumerate() {
                            // Truncate by characters, not bytes, so the cut cannot land in
                            // the middle of a multi-byte UTF-8 sequence and panic.
                            let content: String = cit.content.chars().take(255).collect();
                            let split = content.split_once(CONTENT_FIELD_SEPARATOR);
                            if split.is_none() {
                                error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
                            }
                            let cleaned_content: String = split.unwrap_or(("", "")).1.to_string();

                            let decoded_title = url_decode(&cit.title)?;
                            let decoded_url = url_decode(&cit.url)?;

                            // The filepath is expected to be "<page id>.<extension>"; strip
                            // the extension and parse the page id to look up the chapter.
                            let mut page_path = PathBuf::from(&cit.filepath);
                            page_path.set_extension("");
                            let page_id_str = page_path.file_name();
                            let page_id = page_id_str
                                .and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
                            let course_material_chapter_number = if let Some(id) = page_id {
                                let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
                                chapter.map(|c| c.chapter_number)
                            } else {
                                None
                            };

                            models::chatbot_conversation_messages_citations::insert(
                                &mut conn,
                                ChatbotConversationMessageCitation {
                                    id: Uuid::new_v4(),
                                    created_at: Utc::now(),
                                    updated_at: Utc::now(),
                                    deleted_at: None,
                                    conversation_message_id: response_message.id,
                                    conversation_id: response_message.conversation_id,
                                    course_material_chapter_number,
                                    title: decoded_title,
                                    content: cleaned_content,
                                    document_url: decoded_url,
                                    citation_number: (idx + 1) as i32,
                                },
                            )
                            .await?;
                        }
                    }
                }
            }
        }

        if !done.load(atomic::Ordering::Relaxed) {
            Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
        }
    };

    let guarded_stream = GuardedStream::new(guard, response_stream);

    Ok(Box::pin(guarded_stream))
}

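/// Entry point for handling a chat message: stores the incoming message, then repeatedly
/// queries the LLM, executing tool calls and feeding their results back, until the model
/// produces a text response to stream to the user. Iterations are capped so a model stuck in
/// a tool-call loop cannot run forever.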
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
    user_context: ChatbotUserContext,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (mut chat_request, new_message_order_number, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let mut next_message_order_number = new_message_order_number + 1;
    let mut max_iterations_left = 15;

    loop {
        max_iterations_left -= 1;
        if max_iterations_left == 0 {
            error!("Maximum tool call iterations exceeded");
            return Err(anyhow::anyhow!(
                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
            ));
        }

        let response_type =
            make_request_and_stream(chat_request.clone(), &model.deployment_name, app_config)
                .await?;

        let new_tool_msgs = match response_type {
            ResponseStreamType::Toolcall(stream) => parse_tool(conn, stream, &user_context).await?,
            ResponseStreamType::TextResponse(stream) => {
                return parse_and_stream_to_user(
                    conn,
                    stream,
                    conversation_id,
                    next_message_order_number,
                    pool,
                    request_estimated_tokens,
                )
                .await;
            }
        };
        (chat_request, next_message_order_number) = chat_request
            .update_messages_to_db(
                conn,
                new_tool_msgs,
                conversation_id,
                next_message_order_number,
            )
            .await?;
    }
}