1use std::pin::Pin;
2use std::sync::{
3 Arc,
4 atomic::{self, AtomicBool},
5};
6use std::task::{Context, Poll};
7
8use bytes::Bytes;
9use chrono::Utc;
10use futures::{Stream, TryStreamExt};
11use headless_lms_models::chatbot_conversation_messages::ChatbotConversationMessage;
12use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
13use pin_project::pin_project;
14use serde::{Deserialize, Serialize};
15use sqlx::PgPool;
16use tokio::{io::AsyncBufReadExt, sync::Mutex};
17use tokio_util::io::StreamReader;
18use url::Url;
19
20use crate::llm_utils::{LLM_API_VERSION, build_llm_headers, estimate_tokens};
21use crate::prelude::*;
22use crate::search_filter::SearchFilter;
23
/// Per-category content moderation results attached to a streamed completion
/// choice. Each category is optional because the upstream API does not
/// always include all of them in every chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilterResults {
    pub hate: Option<ContentFilter>,
    pub self_harm: Option<ContentFilter>,
    pub sexual: Option<ContentFilter>,
    pub violence: Option<ContentFilter>,
}
31
/// A single content moderation verdict: whether the content was filtered,
/// and the severity level the API reported (kept as a free-form string).
#[derive(Deserialize, Serialize, Debug)]
pub struct ContentFilter {
    pub filtered: bool,
    pub severity: String,
}
37
/// One alternative in a streamed chat completion chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct Choice {
    pub content_filter_results: Option<ContentFilterResults>,
    /// Incremental message content; may be absent (e.g. on terminal chunks).
    pub delta: Option<Delta>,
    /// Set on the chunk that ends this choice; `None` while streaming.
    pub finish_reason: Option<String>,
    pub index: i32,
}
45
/// The incremental piece of assistant output carried by one stream chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct Delta {
    pub content: Option<String>,
}
50
/// One parsed server-sent-event payload from the streaming chat completion
/// endpoint (the JSON that follows a `data: ` line).
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseChunk {
    pub choices: Vec<Choice>,
    /// Creation time as a unix timestamp, as sent by the API.
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: Option<String>,
}
60
/// A single chat message in the wire format the completion API expects.
/// `role` is one of "system", "user" or "assistant" (see the `From` impl
/// below and the system-prompt insertion in `ChatRequest`).
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiChatMessage {
    pub role: String,
    pub content: String,
}
67
68impl From<ChatbotConversationMessage> for ApiChatMessage {
69 fn from(message: ChatbotConversationMessage) -> Self {
70 ApiChatMessage {
71 role: if message.is_from_chatbot {
72 "assistant".to_string()
73 } else {
74 "user".to_string()
75 },
76 content: message.message.unwrap_or_default(),
77 }
78 }
79}
80
/// Request body for the streaming chat completion endpoint. Field names
/// mirror the API's JSON schema.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatRequest {
    pub messages: Vec<ApiChatMessage>,
    /// Retrieval-augmentation sources; omitted from the JSON entirely when
    /// empty rather than serialized as `[]`.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub data_sources: Vec<DataSource>,
    pub temperature: f32,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    pub max_tokens: i32,
    pub stop: Option<String>,
    /// Always set to `true` by the constructor below.
    pub stream: bool,
}
94
95impl ChatRequest {
96 pub async fn build_and_insert_incoming_message_to_db(
97 conn: &mut PgConnection,
98 chatbot_configuration_id: Uuid,
99 conversation_id: Uuid,
100 message: &str,
101 app_config: &ApplicationConfiguration,
102 ) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
103 let index_name = Url::parse(&app_config.base_url)?
104 .host_str()
105 .expect("BASE_URL must have a host")
106 .replace(".", "-");
107
108 let configuration =
109 models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;
110
111 let conversation_messages =
112 models::chatbot_conversation_messages::get_by_conversation_id(conn, conversation_id)
113 .await?;
114
115 let new_order_number = conversation_messages
116 .iter()
117 .map(|m| m.order_number)
118 .max()
119 .unwrap_or(0)
120 + 1;
121
122 let new_message = models::chatbot_conversation_messages::insert(
123 conn,
124 ChatbotConversationMessage {
125 id: Uuid::new_v4(),
126 created_at: Utc::now(),
127 updated_at: Utc::now(),
128 deleted_at: None,
129 conversation_id,
130 message: Some(message.to_string()),
131 is_from_chatbot: false,
132 message_is_complete: true,
133 used_tokens: estimate_tokens(message),
134 order_number: new_order_number,
135 },
136 )
137 .await?;
138
139 let mut api_chat_messages: Vec<ApiChatMessage> =
140 conversation_messages.into_iter().map(Into::into).collect();
141
142 api_chat_messages.push(new_message.clone().into());
143
144 api_chat_messages.insert(
145 0,
146 ApiChatMessage {
147 role: "system".to_string(),
148 content: configuration.prompt.clone(),
149 },
150 );
151
152 let data_sources = if configuration.use_azure_search {
153 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
154 anyhow::anyhow!("Azure configuration is missing from the application configuration")
155 })?;
156
157 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
158 anyhow::anyhow!(
159 "Azure search configuration is missing from the Azure configuration"
160 )
161 })?;
162
163 let query_type = if configuration.use_semantic_reranking {
164 "vector_semantic_hybrid"
165 } else {
166 "vector_simple_hybrid"
167 };
168
169 vec![DataSource {
170 data_type: "azure_search".to_string(),
171 parameters: DataSourceParameters {
172 endpoint: search_config.search_endpoint.to_string(),
173 authentication: DataSourceParametersAuthentication {
174 auth_type: "api_key".to_string(),
175 key: search_config.search_api_key.clone(),
176 },
177 index_name,
178 query_type: query_type.to_string(),
179 semantic_configuration: "default".to_string(),
180 embedding_dependency: EmbeddingDependency {
181 dep_type: "deployment_name".to_string(),
182 deployment_name: search_config.vectorizer_deployment_id.clone(),
183 },
184 in_scope: false,
185 top_n_documents: 5,
186 strictness: 3,
187 filter: Some(
188 SearchFilter::eq("course_id", configuration.course_id.to_string())
189 .to_odata()?,
190 ),
191 fields_mapping: FieldsMapping {
192 content_fields_separator: ",".to_string(),
193 content_fields: vec!["chunk".to_string()],
194 filepath_field: "filepath".to_string(),
195 title_field: "title".to_string(),
196 url_field: "url".to_string(),
197 vector_fields: vec!["text_vector".to_string()],
198 },
199 },
200 }]
201 } else {
202 Vec::new()
203 };
204
205 let serialized_messages = serde_json::to_string(&api_chat_messages)?;
206 let request_estimated_tokens = estimate_tokens(&serialized_messages);
207
208 Ok((
209 Self {
210 messages: api_chat_messages,
211 data_sources,
212 temperature: configuration.temperature,
213 top_p: configuration.top_p,
214 frequency_penalty: configuration.frequency_penalty,
215 presence_penalty: configuration.presence_penalty,
216 max_tokens: configuration.response_max_tokens,
217 stop: None,
218 stream: true,
219 },
220 new_message,
221 request_estimated_tokens,
222 ))
223 }
224}
225
/// A retrieval-augmentation data source entry in the request body.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSource {
    /// Serialized as `type`; this code always uses "azure_search".
    #[serde(rename = "type")]
    pub data_type: String,
    pub parameters: DataSourceParameters,
}
232
/// Parameters for an Azure Search data source, mirroring the API's schema.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParameters {
    pub endpoint: String,
    pub authentication: DataSourceParametersAuthentication,
    pub index_name: String,
    pub query_type: String,
    pub embedding_dependency: EmbeddingDependency,
    pub in_scope: bool,
    pub top_n_documents: i32,
    pub strictness: i32,
    /// OData filter expression; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields_mapping: FieldsMapping,
    pub semantic_configuration: String,
}
248
/// Authentication block for a data source; this code always uses the
/// "api_key" scheme.
#[derive(Serialize, Deserialize, Debug)]
pub struct DataSourceParametersAuthentication {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    pub auth_type: String,
    pub key: String,
}
255
/// Reference to the embedding deployment used for vector queries.
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingDependency {
    /// Serialized as `type`; this code always uses "deployment_name".
    #[serde(rename = "type")]
    pub dep_type: String,
    pub deployment_name: String,
}
262
/// Maps search index fields to the roles the data source expects
/// (content, title, URL, vectors, …).
#[derive(Serialize, Deserialize, Debug)]
pub struct FieldsMapping {
    pub content_fields_separator: String,
    pub content_fields: Vec<String>,
    pub filepath_field: String,
    pub title_field: String,
    pub url_field: String,
    pub vector_fields: Vec<String>,
}
272
/// One chunk of chatbot output forwarded to our client, serialized as a
/// single JSON line in the response stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub text: String,
}
277
/// Wraps a response stream together with a `RequestCancelledGuard` so that
/// the guard's `Drop` cleanup runs exactly when the stream is dropped
/// (e.g. when the client disconnects mid-response).
#[pin_project]
struct GuardedStream<S> {
    guard: RequestCancelledGuard,
    #[pin]
    stream: S,
}
285
286impl<S> GuardedStream<S> {
287 fn new(guard: RequestCancelledGuard, stream: S) -> Self {
288 Self { guard, stream }
289 }
290}
291
292impl<S> Stream for GuardedStream<S>
293where
294 S: Stream<Item = anyhow::Result<Bytes>> + Send,
295{
296 type Item = S::Item;
297
298 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
299 let this = self.project();
300 this.stream.poll_next(cx)
301 }
302}
303
/// Reconciles database state for an in-flight chatbot response if the
/// response stream is dropped before it completes.
struct RequestCancelledGuard {
    /// Id of the placeholder chatbot response message row.
    response_message_id: Uuid,
    /// Response text chunks received so far; shared with the stream.
    received_string: Arc<Mutex<Vec<String>>>,
    pool: PgPool,
    /// Set when the stream finishes normally; suppresses cleanup in `Drop`.
    done: Arc<AtomicBool>,
    /// Token estimate of the request, added to the response's cost on save.
    request_estimated_tokens: i32,
}
311
312impl Drop for RequestCancelledGuard {
313 fn drop(&mut self) {
314 if self.done.load(atomic::Ordering::Relaxed) {
315 return;
316 }
317 warn!("Request was not cancelled. Cleaning up.");
318 let response_message_id = self.response_message_id;
319 let received_string = self.received_string.clone();
320 let pool = self.pool.clone();
321 let request_estimated_tokens = self.request_estimated_tokens;
322 tokio::spawn(async move {
323 info!("Verifying the received message has been handled");
324 let mut conn = pool.acquire().await.expect("Could not acquire connection");
325 let full_response_text = received_string.lock().await;
326 if full_response_text.is_empty() {
327 info!("No response received. Deleting the response message");
328 models::chatbot_conversation_messages::delete(&mut conn, response_message_id)
329 .await
330 .expect("Could not delete response message");
331 return;
332 }
333 info!("Response received but not completed. Saving the text received so far.");
334 let full_response_as_string = full_response_text.join("");
335 let estimated_cost = estimate_tokens(&full_response_as_string);
336 info!(
337 "End of chatbot response stream. Estimated cost: {}. Response: {}",
338 estimated_cost, full_response_as_string
339 );
340
341 models::chatbot_conversation_messages::update(
343 &mut conn,
344 response_message_id,
345 &full_response_as_string,
346 true,
347 request_estimated_tokens + estimated_cost,
348 )
349 .await
350 .expect("Could not update response message");
351 });
352 }
353}
354
355pub async fn send_chat_request_and_parse_stream(
356 conn: &mut PgConnection,
357 pool: PgPool,
358 app_config: &ApplicationConfiguration,
359 chatbot_configuration_id: Uuid,
360 conversation_id: Uuid,
361 message: &str,
362) -> anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>> {
363 let (chat_request, new_message, request_estimated_tokens) =
364 ChatRequest::build_and_insert_incoming_message_to_db(
365 conn,
366 chatbot_configuration_id,
367 conversation_id,
368 message,
369 app_config,
370 )
371 .await?;
372
373 let full_response_text = Arc::new(Mutex::new(Vec::new()));
374 let done = Arc::new(AtomicBool::new(false));
375
376 let azure_config = app_config
377 .azure_configuration
378 .as_ref()
379 .ok_or_else(|| anyhow::anyhow!("Azure configuration not found"))?;
380
381 let chatbot_config = azure_config
382 .chatbot_config
383 .as_ref()
384 .ok_or_else(|| anyhow::anyhow!("Chatbot configuration not found"))?;
385
386 let api_key = chatbot_config.api_key.clone();
387 let mut url = chatbot_config.api_endpoint.clone();
388
389 url.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
391
392 let headers = build_llm_headers(&api_key)?;
393
394 let response_order_number = new_message.order_number + 1;
395
396 let response_message = models::chatbot_conversation_messages::insert(
397 conn,
398 ChatbotConversationMessage {
399 id: Uuid::new_v4(),
400 created_at: Utc::now(),
401 updated_at: Utc::now(),
402 deleted_at: None,
403 conversation_id,
404 message: None,
405 is_from_chatbot: true,
406 message_is_complete: false,
407 used_tokens: request_estimated_tokens,
408 order_number: response_order_number,
409 },
410 )
411 .await?;
412
413 let guard = RequestCancelledGuard {
415 response_message_id: response_message.id,
416 received_string: full_response_text.clone(),
417 pool: pool.clone(),
418 done: done.clone(),
419 request_estimated_tokens,
420 };
421
422 let request = REQWEST_CLIENT
423 .post(url)
424 .headers(headers)
425 .json(&chat_request)
426 .send();
427
428 let response = request.await?;
429
430 info!("Receiving chat response with {:?}", response.version());
431
432 if !response.status().is_success() {
433 let status = response.status();
434 let error_message = response.text().await?;
435 return Err(anyhow::anyhow!(
436 "Failed to send chat request. Status: {}. Error: {}",
437 status,
438 error_message
439 ));
440 }
441
442 let stream = response.bytes_stream().map_err(std::io::Error::other);
443 let reader = StreamReader::new(stream);
444 let mut lines = reader.lines();
445
446 let response_stream = async_stream::try_stream! {
447 while let Some(line) = lines.next_line().await? {
448 if !line.starts_with("data: ") {
449 continue;
450 }
451 let json_str = line.trim_start_matches("data: ");
452 let mut full_response_text = full_response_text.lock().await;
453 if json_str.trim() == "[DONE]" {
454 let full_response_as_string = full_response_text.join("");
455 let estimated_cost = estimate_tokens(&full_response_as_string);
456 info!(
457 "End of chatbot response stream. Estimated cost: {}. Response: {}",
458 estimated_cost, full_response_as_string
459 );
460 done.store(true, atomic::Ordering::Relaxed);
461 let mut conn = pool.acquire().await?;
462 models::chatbot_conversation_messages::update(
463 &mut conn,
464 response_message.id,
465 &full_response_as_string,
466 true,
467 request_estimated_tokens + estimated_cost,
468 ).await?;
469 break;
470 }
471 let response_chunk = serde_json::from_str::<ResponseChunk>(json_str).map_err(|e| {
472 anyhow::anyhow!("Failed to parse response chunk: {}", e)
473 })?;
474 for choice in &response_chunk.choices {
475 if let Some(delta) = &choice.delta {
476 if let Some(content) = &delta.content {
477 full_response_text.push(content.clone());
478 let response = ChatResponse { text: content.clone() };
479 let response_as_string = serde_json::to_string(&response)?;
480 yield Bytes::from(response_as_string);
481 yield Bytes::from("\n");
482 }
483
484 }
485 }
486 }
487
488 if !done.load(atomic::Ordering::Relaxed) {
489 Err(anyhow::anyhow!("Stream ended unexpectedly"))?;
490 }
491 };
492
493 let guarded_stream = GuardedStream::new(guard, response_stream);
496
497 Ok(Box::pin(guarded_stream))
499}