1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::sync::{
5 Arc,
6 atomic::{self, AtomicBool},
7};
8use std::task::{Context, Poll};
9
10use anyhow::{Error, Ok, anyhow};
11use bytes::Bytes;
12use chrono::Utc;
13use futures::stream::{BoxStream, Peekable};
14use futures::{Stream, StreamExt, TryStreamExt};
15use headless_lms_models::chatbot_configurations::{ReasoningEffortLevel, VerbosityLevel};
16use headless_lms_models::chatbot_conversation_messages::{
17 self, ChatbotConversationMessage, MessageRole,
18};
19use headless_lms_models::chatbot_conversation_messages_citations::ChatbotConversationMessageCitation;
20use headless_lms_utils::ApplicationConfiguration;
21use pin_project::pin_project;
22use serde::{Deserialize, Serialize};
23use sqlx::PgPool;
24use tokio::{io::AsyncBufReadExt, sync::Mutex};
25use tokio_stream::wrappers::LinesStream;
26use tokio_util::io::StreamReader;
27use tracing::trace;
28use url::Url;
29
30use crate::chatbot_error::ChatbotResult;
31use crate::chatbot_tools::{
32 AzureLLMToolDefinition, ChatbotTool, get_chatbot_tool, get_chatbot_tool_definitions,
33};
34use crate::llm_utils::{
35 APIMessage, APIMessageKind, APIMessageText, APIMessageToolCall, APIMessageToolResponse,
36 APITool, APIToolCall, estimate_tokens, make_streaming_llm_request,
37};
38use headless_lms_utils::url_encoding::url_decode;
39
40use crate::prelude::*;
41use crate::search_filter::SearchFilter;
42
/// Separator that Azure search is configured to put between the concatenated
/// `content_fields` (`chunk_context`, `chunk`) of a citation. Used both in the
/// `FieldsMapping` sent to Azure and to split the context prefix off citation
/// content before it is stored.
const CONTENT_FIELD_SEPARATOR: &str = ",|||,";
44
/// Identifies the user and course a chatbot request belongs to; passed to
/// chatbot tools so they can scope their lookups.
pub struct ChatbotUserContext {
    pub user_id: Uuid,
    pub course_id: Uuid,
    pub course_name: String,
}
52
/// Per-category content-filter results attached to a streamed choice by Azure.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}
60
/// A single content-filter category verdict: whether the content was filtered
/// and the severity label reported by the service.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}
66
/// One choice within a streamed response chunk. During streaming the payload
/// arrives incrementally via `delta`; `finish_reason` is set on the final
/// chunk of a choice (e.g. `"tool_calls"`).
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    pub delta: Option<Delta>,
    pub finish_reason: Option<String>,
    pub index: i32,
}
75
/// Incremental payload of a streamed choice: a text fragment, citation
/// context, and/or tool-call fragments. Exactly which fields are present
/// determines how the stream is routed (see `make_request_and_stream`).
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
    pub context: Option<DeltaContext>,
    pub tool_calls: Option<Vec<ToolCallInDelta>>,
}
83
/// Citation context attached to a delta when Azure search data sources are in
/// use.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeltaContext {
    pub citations: Vec<Citation>,
}
88
/// A (possibly partial) tool-call fragment inside a delta. The first fragment
/// of a call carries `id` and the function name; later fragments only append
/// to the argument string.
#[derive(Deserialize, Serialize, Debug)]
pub struct ToolCallInDelta {
    pub id: Option<String>,
    pub function: DeltaTool,
    #[serde(rename = "type")]
    pub tool_type: Option<ToolCallType>,
}
97
/// Function part of a streamed tool call. `arguments` is a fragment of the
/// call's JSON argument string (defaults to empty when absent); `name` is
/// only present on the fragment that starts the call.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DeltaTool {
    #[serde(default)]
    pub arguments: String,
    pub name: Option<String>,
}
105
/// Kind of tool call; serialized as `"function"`, the only kind used here.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
    Function,
}
111
/// A citation returned by Azure search. `content` is the concatenated
/// content fields joined with `CONTENT_FIELD_SEPARATOR`; `filepath` ends in a
/// page id (see `parse_and_stream_to_user`).
#[derive(Deserialize, Serialize, Debug)]
pub struct Citation {
    pub content: String,
    pub title: String,
    pub url: String,
    pub filepath: String,
}
119
/// One server-sent-events chunk of a streaming chat completion, parsed from
/// the JSON after each `data: ` prefix.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}
130
/// `tool_choice` request parameter; serialized as `"auto"` (the model decides
/// whether to call tools).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolChoice {
    Auto,
}
136
/// Request parameters used for reasoning ("thinking") models. Flattened into
/// the top level of `LLMRequest` via `#[serde(flatten)]` on its `params`
/// field; `tools` is omitted from the payload when empty.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingParams {
    pub max_completion_tokens: Option<i32>,
    pub verbosity: Option<VerbosityLevel>,
    pub reasoning_effort: Option<ReasoningEffortLevel>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<AzureLLMToolDefinition>,
    pub tool_choice: Option<LLMToolChoice>,
}
146
/// Classic sampling parameters used for non-reasoning models; flattened into
/// the top level of `LLMRequest`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NonThinkingParams {
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}
155
/// Parameter set selected by the configured model family. `untagged`, so only
/// the inner struct's fields appear in the serialized request.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LLMRequestParams {
    Thinking(ThinkingParams),
    NonThinking(NonThinkingParams),
}
162
/// JSON type names used in response-format schemas; serialized in snake_case
/// (e.g. `json_schema`, `object`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JSONType {
    JsonSchema,
    Object,
    Array,
    String,
}
171
/// Named JSON schema for structured-output requests.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct JSONSchema {
    pub name: String,
    pub strict: bool,
    pub schema: Schema,
}
179
/// Object schema body. `rename_all = "camelCase"` makes
/// `additional_properties` serialize as `additionalProperties`; the `type`
/// keyword needs an explicit rename because it is a Rust keyword.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
    #[serde(rename = "type")]
    pub type_field: JSONType,
    pub properties: HashMap<String, ArrayProperty>,
    pub required: Vec<String>,
    pub additional_properties: bool,
}
194
/// Schema property describing an array and its item type.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayProperty {
    #[serde(rename = "type")]
    pub type_field: JSONType,
    pub items: ArrayItem,
}
201
/// Item type of an array schema property.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ArrayItem {
    #[serde(rename = "type")]
    pub type_field: JSONType,
}
207
208#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
209pub struct LLMRequestResponseFormatParam {
210 #[serde(rename = "type")]
211 pub format_type: JSONType, pub json_schema: JSONSchema,
213}
214
/// Full request body for a (streaming) chat completion. `params` is flattened
/// so the model-family-specific fields appear at the top level; empty
/// `data_sources` and absent `response_format` are omitted from the payload.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMRequest {
    pub messages: Vec<APIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub data_sources: Vec<DataSource>,
    #[serde(flatten)]
    pub params: LLMRequestParams,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<LLMRequestResponseFormatParam>,
    pub stop: Option<String>,
}
226
impl LLMRequest {
    /// Builds the complete LLM request for a conversation and persists the
    /// user's new message to the database.
    ///
    /// Loads the chatbot configuration and the conversation history, inserts
    /// `message` as a new `User` message (with the next order number), and
    /// assembles the message list (system prompt first, history, then the new
    /// message). If Azure search is enabled, attaches the corresponding data
    /// source and strips/flattens tool-call messages, which the Azure search
    /// extension does not accept.
    ///
    /// Returns the request, the inserted message's order number, and the
    /// estimated token count of the serialized message list.
    pub async fn build_and_insert_incoming_message_to_db(
        conn: &mut PgConnection,
        chatbot_configuration_id: Uuid,
        conversation_id: Uuid,
        message: &str,
        app_config: &ApplicationConfiguration,
    ) -> anyhow::Result<(Self, i32, i32)> {
        // Search index name is derived from the deployment's hostname,
        // e.g. "courses.example.com" -> "courses-example-com".
        let index_name = Url::parse(&app_config.base_url)?
            .host_str()
            .expect("BASE_URL must have a host")
            .replace(".", "-");

        let configuration =
            models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

        let conversation_messages =
            models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
                .await?;

        // Next order number: one past the highest existing one (1 for the
        // first message of a conversation).
        let new_order_number = conversation_messages
            .iter()
            .map(|m| m.order_number)
            .max()
            .unwrap_or(0)
            + 1;

        // Persist the incoming user message before calling the LLM.
        let new_message = models::chatbot_conversation_messages::insert(
            conn,
            ChatbotConversationMessage {
                id: Uuid::new_v4(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
                deleted_at: None,
                conversation_id,
                message: Some(message.to_string()),
                message_role: MessageRole::User,
                message_is_complete: true,
                used_tokens: estimate_tokens(message),
                order_number: new_order_number,
                tool_call_fields: vec![],
                tool_output: None,
            },
        )
        .await?;

        // Convert the stored history into API messages; fails if any stored
        // message cannot be represented.
        let mut api_chat_messages: Vec<APIMessage> = conversation_messages
            .into_iter()
            .map(APIMessage::try_from)
            .collect::<ChatbotResult<Vec<_>>>()?;

        api_chat_messages.push(new_message.clone().try_into()?);

        // System prompt always goes first.
        api_chat_messages.insert(
            0,
            APIMessage {
                role: MessageRole::System,
                fields: APIMessageKind::Text(APIMessageText {
                    content: configuration.prompt.clone(),
                }),
            },
        );

        let data_sources = if configuration.use_azure_search {
            let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
                anyhow::anyhow!("Azure configuration is missing from the application configuration")
            })?;

            let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
                anyhow::anyhow!(
                    "Azure search configuration is missing from the Azure configuration"
                )
            })?;

            let query_type = if configuration.use_semantic_reranking {
                "vector_semantic_hybrid"
            } else {
                "vector_simple_hybrid"
            };

            // The Azure search extension rejects tool-call messages: drop
            // assistant tool-call messages and downgrade tool responses to
            // plain assistant text.
            api_chat_messages = api_chat_messages
                .into_iter()
                .filter(|m| !matches!(m.fields, APIMessageKind::ToolCall(_)))
                .map(|m| match m.fields {
                    APIMessageKind::ToolResponse(r) => APIMessage {
                        role: MessageRole::Assistant,
                        fields: APIMessageKind::Text(APIMessageText { content: r.content }),
                    },
                    _ => m,
                })
                .collect();

            vec![DataSource {
                data_type: "azure_search".to_string(),
                parameters: DataSourceParameters {
                    endpoint: search_config.search_endpoint.to_string(),
                    authentication: DataSourceParametersAuthentication {
                        auth_type: "api_key".to_string(),
                        key: search_config.search_api_key.clone(),
                    },
                    index_name,
                    query_type: query_type.to_string(),
                    semantic_configuration: "default".to_string(),
                    embedding_dependency: EmbeddingDependency {
                        dep_type: "deployment_name".to_string(),
                        deployment_name: search_config.vectorizer_deployment_id.clone(),
                    },
                    in_scope: false,
                    top_n_documents: 15,
                    strictness: 3,
                    // Restrict retrieval to this course's documents.
                    filter: Some(
                        SearchFilter::eq("course_id", configuration.course_id.to_string())
                            .to_odata()?,
                    ),
                    fields_mapping: FieldsMapping {
                        content_fields_separator: CONTENT_FIELD_SEPARATOR.to_string(),
                        content_fields: vec!["chunk_context".to_string(), "chunk".to_string()],
                        filepath_field: "filepath".to_string(),
                        title_field: "title".to_string(),
                        url_field: "url".to_string(),
                        vector_fields: vec!["text_vector".to_string()],
                    },
                },
            }]
        } else {
            Vec::new()
        };

        let tools = if configuration.use_tools {
            get_chatbot_tool_definitions()
        } else {
            Vec::new()
        };

        // Token estimate covers the serialized message list only (not the
        // data-source or parameter payloads).
        let serialized_messages = serde_json::to_string(&api_chat_messages)?;
        let request_estimated_tokens = estimate_tokens(&serialized_messages);

        // Parameter set depends on the configured model family.
        let params = if configuration.thinking_model {
            LLMRequestParams::Thinking(ThinkingParams {
                max_completion_tokens: Some(configuration.max_completion_tokens),
                reasoning_effort: Some(configuration.reasoning_effort),
                verbosity: Some(configuration.verbosity),
                tools,
                tool_choice: if configuration.use_tools {
                    Some(LLMToolChoice::Auto)
                } else {
                    None
                },
            })
        } else {
            LLMRequestParams::NonThinking(NonThinkingParams {
                max_tokens: Some(configuration.response_max_tokens),
                temperature: Some(configuration.temperature),
                top_p: Some(configuration.top_p),
                frequency_penalty: Some(configuration.frequency_penalty),
                presence_penalty: Some(configuration.presence_penalty),
            })
        };

        Ok((
            Self {
                messages: api_chat_messages,
                data_sources,
                params,
                response_format: None,
                stop: None,
            },
            new_message.order_number,
            request_estimated_tokens,
        ))
    }

    /// Persists `new_msgs` (e.g. tool calls and their results) to the
    /// conversation, numbering them from `order_number`, and appends them to
    /// this request's message list.
    ///
    /// Returns the updated request and the next free order number.
    pub async fn update_messages_to_db(
        mut self,
        conn: &mut PgConnection,
        new_msgs: Vec<APIMessage>,
        conversation_id: Uuid,
        mut order_number: i32,
    ) -> anyhow::Result<(Self, i32)> {
        for m in new_msgs {
            let converted_msg = m.to_chatbot_conversation_message(conversation_id, order_number)?;
            chatbot_conversation_messages::insert(conn, converted_msg).await?;
            self.messages.push(m);
            order_number += 1;
        }
        Ok((self, order_number))
    }
}
420
/// One entry of the request's `data_sources` array; `data_type` is
/// `"azure_search"` here.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSource {
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}
427
/// Azure search data-source parameters: endpoint, credentials, index, query
/// strategy, retrieval limits, an optional OData filter, and the mapping of
/// index fields to citation fields.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}
443
/// Data-source authentication; `auth_type` is `"api_key"` with the key in
/// `key`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DataSourceParametersAuthentication {
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}
450
/// Vectorizer used for hybrid search, referenced by deployment name.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmbeddingDependency {
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}
457
/// Maps search-index fields to the fields Azure uses when building citations.
/// `content_fields` are concatenated with `content_fields_separator`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}
467
/// One newline-delimited JSON item streamed back to the client: a fragment of
/// the assistant's reply text.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}
472
/// Wraps a response stream together with a `RequestCancelledGuard` so that if
/// the stream is dropped before completion, the guard's `Drop` runs cleanup.
#[pin_project]
struct GuardedStream<S> {
    // Held only for its Drop behavior; never read.
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}
480
impl<S> GuardedStream<S> {
    /// Ties `guard`'s lifetime to `stream`'s.
    fn new(guard: RequestCancelledGuard, stream: S) -> Self {
        Self { guard, stream }
    }
}
486
/// Transparent `Stream` passthrough: delegates polling to the inner stream
/// via pin projection; the guard rides along untouched.
impl<S> Stream for GuardedStream<S>
where
    S: Stream<Item = anyhow::Result<Bytes>> + Send,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.stream.poll_next(cx)
    }
}
498
/// Pinned, peekable line stream over the raw byte stream of an LLM response;
/// peeking lets `make_request_and_stream` classify the response without
/// consuming the lines the downstream parser needs.
type PeekableLinesStream<'a> = Pin<
    Box<Peekable<LinesStream<StreamReader<BoxStream<'a, Result<Bytes, std::io::Error>>, Bytes>>>>,
>;
/// Classified LLM response stream: either a tool-call stream (fed to
/// `parse_tool`) or a text response (fed to `parse_and_stream_to_user`).
pub enum ResponseStreamType<'a> {
    Toolcall(PeekableLinesStream<'a>),
    TextResponse(PeekableLinesStream<'a>),
}
508
/// Drop guard that cleans up the placeholder response message if the client
/// disconnects mid-stream. `done` is set when the stream completes normally,
/// turning the guard into a no-op; `received_string` accumulates the text
/// fragments streamed so far.
struct RequestCancelledGuard {
    response_message_id: Uuid,
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    done: Arc<AtomicBool>,
    request_estimated_tokens: i32,
}
516
517impl Drop for RequestCancelledGuard {
518 fn drop(&mut self) {
519 if self.done.load(atomic::Ordering::Relaxed) {
520 return;
521 }
522 warn!("Request was not cancelled. Cleaning up.");
523 let response_message_id = self.response_message_id;
524 let received_string = self.received_string.clone();
525 let pool = self.pool.clone();
526 let request_estimated_tokens = self.request_estimated_tokens;
527 tokio::spawn(async move {
528 info!("Verifying the received message has been handled");
529 let mut conn = pool.acquire().await.expect("Could not acquire connection");
530 let full_response_text = received_string.lock().await;
531 if full_response_text.is_empty() {
532 info!("No response received. Deleting the response message");
533 models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
534 .await
535 .expect("Could not delete response message");
536 return;
537 }
538 info!("Response received but not completed. Saving the text received so far.");
539 let full_response_as_string = full_response_text.join("");
540 let estimated_cost = estimate_tokens(&full_response_as_string);
541 info!(
542 "End of chatbot response stream. Estimated cost: {}. Response: {}",
543 estimated_cost, full_response_as_string
544 );
545
546 models::chatbot_conversation_messages::update(
548 &mut conn,
549 response_message_id,
550 &full_response_as_string,
551 true,
552 request_estimated_tokens + estimated_cost,
553 )
554 .await
555 .expect("Could not update response message");
556 });
557 }
558}
559
560pub async fn make_request_and_stream<'a>(
561 chat_request: LLMRequest,
562 model_name: &str,
563 app_config: &ApplicationConfiguration,
564) -> anyhow::Result<ResponseStreamType<'a>> {
565 let response = make_streaming_llm_request(chat_request, model_name, app_config).await?;
566
567 trace!("Receiving chat response with {:?}", response.version());
568
569 if !response.status().is_success() {
570 let status = response.status();
571 let error_message = response.text().await?;
572 return Err(anyhow::anyhow!(
573 "Failed to send chat request. Status: {}. Error: {}",
574 status,
575 error_message
576 ));
577 }
578
579 let stream = response
580 .bytes_stream()
581 .map_err(std::io::Error::other)
582 .boxed();
583 let reader = StreamReader::new(stream);
584 let lines = reader.lines();
585 let lines_stream = LinesStream::new(lines);
586 let peekable_lines_stream = lines_stream.peekable();
587 let mut pinned_lines = Box::pin(peekable_lines_stream);
588
589 loop {
590 let line_res = pinned_lines.as_mut().peek().await;
591 match line_res {
592 None => {
593 break;
594 }
595 Some(Err(e)) => {
596 return Err(anyhow!(
597 "There was an error streaming response from Azure: {}",
598 e
599 ));
600 }
601 Some(Result::Ok(line)) => {
602 if !line.starts_with("data: ") {
603 pinned_lines.next().await;
604 continue;
605 }
606 let json_str = line.trim_start_matches("data: ");
607 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
608 .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {}", e))?;
609 for choice in &response_chunk.choices {
610 if let Some(d) = &choice.delta {
611 if d.content.is_some() || d.context.is_some() {
612 return Ok(ResponseStreamType::TextResponse(pinned_lines));
613 } else if let Some(_calls) = &d.tool_calls {
614 return Ok(ResponseStreamType::Toolcall(pinned_lines));
615 } else if d.content.is_none() {
616 pinned_lines.next().await;
617 continue;
618 }
619 }
620 }
621 pinned_lines.next().await;
622 }
623 }
624 }
625 Err(Error::msg(
626 "The response received from Azure had an unexpected shape and couldn't be parsed"
627 .to_string(),
628 ))
629}
630
631pub async fn parse_tool<'a>(
634 conn: &mut PgConnection,
635 mut lines: PeekableLinesStream<'a>,
636 user_context: &ChatbotUserContext,
637) -> anyhow::Result<Vec<APIMessage>> {
638 let mut function_name_id_args: Vec<(String, String, String)> = vec![];
639 let mut currently_streamed_function_name_id: Option<(String, String)> = None;
640 let mut currently_streamed_function_args = vec![];
641 let mut messages = vec![];
642
643 trace!("Parsing tool calls...");
644
645 while let Some(val) = lines.next().await {
646 let line = val?;
647 if !line.to_owned().starts_with("data: ") {
648 continue;
649 }
650 let json_str = line.trim_start_matches("data: ");
651 if json_str.trim() == "[DONE]" {
652 if function_name_id_args.is_empty() {
654 return Err(anyhow::anyhow!(
655 "The LLM response was supposed to contain function calls, but no function calls were found"
656 ));
657 }
658 let mut assistant_tool_calls = Vec::new();
659 let mut tool_result_msgs = Vec::new();
660
661 for (name, id, args) in function_name_id_args.iter() {
662 let tool = get_chatbot_tool(conn, name, args, user_context).await?;
663
664 assistant_tool_calls.push(APIToolCall {
665 function: APITool {
666 name: name.to_owned(),
667 arguments: serde_json::to_string(tool.get_arguments())?,
668 },
669 id: id.to_owned(),
670 tool_type: ToolCallType::Function,
671 });
672 tool_result_msgs.push(APIMessage {
673 role: MessageRole::Tool,
674 fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
675 content: tool.get_tool_output(),
676 name: name.to_owned(),
677 tool_call_id: id.to_owned(),
678 }),
679 })
680 }
681 messages.push(APIMessage {
683 role: MessageRole::Assistant,
684 fields: APIMessageKind::ToolCall(APIMessageToolCall {
685 tool_calls: assistant_tool_calls,
686 }),
687 });
688 messages.extend(tool_result_msgs);
690 break;
691 }
692 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str)
693 .map_err(|e| anyhow::anyhow!("Failed to parse response chunk: {} {}", e, json_str))?;
694 for choice in &response_chunk.choices {
695 if Some("tool_calls".to_string()) == choice.finish_reason {
696 if let Some((name, id)) = ¤tly_streamed_function_name_id {
699 let fn_args = currently_streamed_function_args.join("");
702 function_name_id_args.push((
703 name.to_owned(),
704 id.to_owned(),
705 fn_args.to_owned(),
706 ));
707 currently_streamed_function_args.clear();
708 currently_streamed_function_name_id = None;
709 }
712 }
713 if let Some(delta) = &choice.delta
714 && let Some(tool_calls) = &delta.tool_calls
715 {
716 for call in tool_calls {
718 if let (Some(name), Some(id)) = (&call.function.name, &call.id) {
719 if let Some((name_prev, id_prev)) = currently_streamed_function_name_id {
724 let fn_args = currently_streamed_function_args.join("");
725 function_name_id_args.push((
726 name_prev.to_owned(),
727 id_prev.to_owned(),
728 fn_args,
729 ));
730 currently_streamed_function_args.clear();
731 }
732 currently_streamed_function_name_id =
736 Some((name.to_owned(), id.to_owned()));
737 };
738 currently_streamed_function_args.push(call.function.arguments.clone());
741 }
742 }
743 }
744 }
745 Ok(messages)
746}
747
748pub async fn parse_and_stream_to_user<'a>(
750 conn: &mut PgConnection,
751 mut lines: PeekableLinesStream<'a>,
752 conversation_id: Uuid,
753 response_order_number: i32,
754 pool: PgPool,
755 request_estimated_tokens: i32,
756) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send + 'a>>> {
757 let response_message = models::chatbot_conversation_messages::insert(
759 conn,
760 ChatbotConversationMessage {
761 id: Uuid::new_v4(),
762 created_at: Utc::now(),
763 updated_at: Utc::now(),
764 deleted_at: None,
765 conversation_id,
766 message: Some("".to_string()),
767 message_role: MessageRole::Assistant,
768 message_is_complete: false,
769 used_tokens: request_estimated_tokens,
770 order_number: response_order_number,
771 tool_call_fields: vec![],
772 tool_output: None,
773 },
774 )
775 .await?;
776
777 let done = Arc::new(AtomicBool::new(false));
778 let full_response_text = Arc::new(Mutex::new(Vec::new()));
779 let guard = RequestCancelledGuard {
781 response_message_id: response_message.id,
782 received_string: full_response_text.clone(),
783 pool: pool.clone(),
784 done: done.clone(),
785 request_estimated_tokens,
786 };
787
788 trace!("Parsing stream to user...");
789
790 let response_stream = async_stream::try_stream! {
791 while let Some(val) = lines.next().await {
792 let line = val?;
793 if !line.starts_with("data: ") {
794 continue;
795 }
796 let mut full_response_text = full_response_text.lock().await;
797 let json_str = line.trim_start_matches("data: ");
798 if json_str.trim() == "[DONE]" {
799 let full_response_as_string = full_response_text.join("");
800 let estimated_cost = estimate_tokens(&full_response_as_string);
801 trace!(
802 "End of chatbot response stream. Estimated cost: {}. Response: {}",
803 estimated_cost, full_response_as_string
804 );
805 done.store(true, atomic::Ordering::Relaxed);
806 let mut conn = pool.acquire().await?;
807 models::chatbot_conversation_messages::update(
808 &mut conn,
809 response_message.id,
810 &full_response_as_string,
811 true,
812 request_estimated_tokens + estimated_cost,
813 ).await?;
814 break;
815 }
816 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
817 anyhow::anyhow!("Failed to parse response chunk: {}", e)
818 })?;
819
820 for choice in &response_chunk.choices {
821 if let Some(delta) = &choice.delta {
822 if let Some(content) = &delta.content {
823 full_response_text.push(content.clone());
824 let response = ChatResponse { text: content.clone() };
825 let response_as_string = serde_json::to_string(&response)?;
826 yield Bytes::from(response_as_string);
827 yield Bytes::from("\n");
828 }
829 if let Some(context) = &delta.context {
830 let mut conn = pool.acquire().await?;
831 for (idx, cit) in context.citations.iter().enumerate() {
832 let content = if cit.content.len() < 255 {cit.content.clone()} else {cit.content[0..255].to_string()};
833 let split = content.split_once(CONTENT_FIELD_SEPARATOR);
834 if split.is_none() {
835 error!("Chatbot citation doesn't have any content or is missing 'chunk_context'. Something is wrong with Azure.");
836 }
837 let cleaned_content: String = split.unwrap_or(("","")).1.to_string();
838
839 let decoded_title = url_decode(&cit.title)?;
843 let decoded_url = url_decode(&cit.url)?;
844
845 let mut page_path = PathBuf::from(&cit.filepath);
846 page_path.set_extension("");
847 let page_id_str = page_path.file_name();
848 let page_id = page_id_str.and_then(|id_str| Uuid::parse_str(id_str.to_string_lossy().as_ref()).ok());
849 let course_material_chapter_number = if let Some(id) = page_id {
850 let chapter = models::chapters::get_chapter_by_page_id(&mut conn, id).await.ok();
851 chapter.map(|c| c.chapter_number)
852 } else {
853 None
854 };
855
856 models::chatbot_conversation_messages_citations::insert(
857 &mut conn, ChatbotConversationMessageCitation {
858 id: Uuid::new_v4(),
859 created_at: Utc::now(),
860 updated_at: Utc::now(),
861 deleted_at: None,
862 conversation_message_id: response_message.id,
863 conversation_id: response_message.conversation_id,
864 course_material_chapter_number,
865 title: decoded_title,
866 content: cleaned_content,
867 document_url: decoded_url,
868 citation_number: (idx+1) as i32,
869 }
870 ).await?;
871 }
872 }
873 }
874 }
875 }
876
877 if !done.load(atomic::Ordering::Relaxed) {
878 Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
879 }
880 };
881
882 let guarded_stream = GuardedStream::new(guard, response_stream);
885
886 Ok(Box::pin(guarded_stream))
888}
889
/// Top-level entry point: persists the incoming user message, sends the chat
/// request, and drives the tool-call loop until the model produces a text
/// response, which is returned as a byte stream for the client.
///
/// Tool-call rounds append their messages to the conversation and to the
/// request, then re-send; a text response short-circuits out of the loop.
///
/// # Errors
/// Propagates build/request/parse failures, and fails if the model keeps
/// issuing tool calls past the iteration budget.
pub async fn send_chat_request_and_parse_stream(
    conn: &mut PgConnection,
    pool: PgPool,
    app_config: &ApplicationConfiguration,
    chatbot_configuration_id: Uuid,
    conversation_id: Uuid,
    message: &str,
    user_context: ChatbotUserContext,
) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
    let (mut chat_request, new_message_order_number, request_estimated_tokens) =
        LLMRequest::build_and_insert_incoming_message_to_db(
            conn,
            chatbot_configuration_id,
            conversation_id,
            message,
            app_config,
        )
        .await?;

    let model = models::chatbot_configurations_models::get_by_chatbot_configuration_id(
        conn,
        chatbot_configuration_id,
    )
    .await?;

    let mut next_message_order_number = new_message_order_number + 1;
    // Budget of 15, decremented before the check, so up to 14 tool-call
    // rounds actually run before the loop bails out.
    let mut max_iterations_left = 15;

    loop {
        max_iterations_left -= 1;
        if max_iterations_left == 0 {
            error!("Maximum tool call iterations exceeded");
            return Err(anyhow::anyhow!(
                "Maximum tool call iterations exceeded. The LLM may be stuck in a loop."
            ));
        }

        // The request is cloned because a tool-call round consumes it via
        // `update_messages_to_db` and the loop may re-send it.
        let response_type =
            make_request_and_stream(chat_request.clone(), &model.deployment_name, app_config)
                .await?;

        let new_tool_msgs = match response_type {
            ResponseStreamType::Toolcall(stream) => parse_tool(conn, stream, &user_context).await?,
            // A text response ends the loop: hand the stream to the client.
            ResponseStreamType::TextResponse(stream) => {
                return parse_and_stream_to_user(
                    conn,
                    stream,
                    conversation_id,
                    next_message_order_number,
                    pool,
                    request_estimated_tokens,
                )
                .await;
            }
        };
        // Persist the tool-call round and continue with the grown request.
        (chat_request, next_message_order_number) = chat_request
            .update_messages_to_db(
                conn,
                new_tool_msgs,
                conversation_id,
                next_message_order_number,
            )
            .await?;
    }
}