use crate::{
    azure_chatbot::{LLMRequest, ToolCallType},
    chatbot_error::ChatbotResult,
    prelude::*,
};
use core::default::Default;
use headless_lms_models::{
    chatbot_conversation_message_tool_calls::ChatbotConversationMessageToolCall,
    chatbot_conversation_message_tool_outputs::ChatbotConversationMessageToolOutput,
    chatbot_conversation_messages::{ChatbotConversationMessage, MessageRole},
};
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace, warn};

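/// The Azure OpenAI REST API version sent with every chat completion request.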
pub const LLM_API_VERSION: &str = "2024-06-01";

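/// One message in the Azure OpenAI chat completions wire format. `fields` is
/// flattened during (de)serialization, so the JSON carries `role` next to the
/// variant's own fields. A plain text message, for example, serializes roughly
/// as `{"role": "user", "content": "Hello"}` (a sketch; exact field order may
/// vary).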
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessage {
    pub role: MessageRole,
    #[serde(flatten)]
    pub fields: APIMessageKind,
}

impl APIMessage {
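    /// Converts this API message into a `ChatbotConversationMessage` row for
    /// the given conversation and position. Token usage is estimated from the
    /// text content or the tool call arguments; tool responses are recorded
    /// with zero tokens.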
    pub fn to_chatbot_conversation_message(
        &self,
        conversation_id: Uuid,
        order_number: i32,
    ) -> ChatbotResult<ChatbotConversationMessage> {
        let res = match self.fields.clone() {
            APIMessageKind::Text(msg) => ChatbotConversationMessage {
                message_role: self.role,
                conversation_id,
                order_number,
                message_is_complete: true,
                used_tokens: estimate_tokens(&msg.content),
                message: Some(msg.content),
                ..Default::default()
            },
            APIMessageKind::ToolCall(msg) => {
                let tool_call_fields = msg
                    .tool_calls
                    .iter()
                    .map(|x| ChatbotConversationMessageToolCall::try_from(x.to_owned()))
                    .collect::<ChatbotResult<Vec<_>>>()?;
                let estimated_tokens: i32 = msg
                    .tool_calls
                    .iter()
                    .map(|x| estimate_tokens(&x.function.arguments))
                    .sum();
                ChatbotConversationMessage {
                    message_role: self.role,
                    conversation_id,
                    order_number,
                    message_is_complete: true,
                    message: None,
                    tool_call_fields,
                    used_tokens: estimated_tokens,
                    ..Default::default()
                }
            }
            APIMessageKind::ToolResponse(msg) => ChatbotConversationMessage {
                message_role: self.role,
                conversation_id,
                order_number,
                message_is_complete: true,
                message: None,
                used_tokens: 0,
                tool_output: Some(ChatbotConversationMessageToolOutput::from(msg)),
                ..Default::default()
            },
        };
        Ok(res)
    }
}

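/// Maps a stored conversation message back into the wire format. Assistant
/// messages become tool calls when `tool_call_fields` is non-empty and text
/// otherwise; tool messages require a `tool_output`; system messages are
/// rejected because they are never persisted in the database.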
impl TryFrom<ChatbotConversationMessage> for APIMessage {
    type Error = ChatbotError;
    fn try_from(message: ChatbotConversationMessage) -> ChatbotResult<Self> {
        let res = match message.message_role {
            MessageRole::Assistant => {
                if !message.tool_call_fields.is_empty() {
                    APIMessage {
                        role: message.message_role,
                        fields: APIMessageKind::ToolCall(APIMessageToolCall {
                            tool_calls: message
                                .tool_call_fields
                                .iter()
                                .map(|f| APIToolCall::from(f.clone()))
                                .collect(),
                        }),
                    }
                } else if let Some(msg) = message.message {
                    APIMessage {
                        role: message.message_role,
                        fields: APIMessageKind::Text(APIMessageText { content: msg }),
                    }
                } else {
                    return Err(ChatbotError::new(
                        ChatbotErrorType::InvalidMessageShape,
                        "A 'role: assistant' type ChatbotConversationMessage must have either tool_call_fields or a text message.",
                        None,
                    ));
                }
            }
            MessageRole::Tool => {
                if let Some(tool_output) = message.tool_output {
                    APIMessage {
                        role: message.message_role,
                        fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
                            tool_call_id: tool_output.tool_call_id,
                            name: tool_output.tool_name,
                            content: tool_output.tool_output,
                        }),
                    }
                } else {
                    return Err(ChatbotError::new(
                        ChatbotErrorType::InvalidMessageShape,
                        "A 'role: tool' type ChatbotConversationMessage must have field tool_output.",
                        None,
                    ));
                }
            }
            MessageRole::User => APIMessage {
                role: message.message_role,
                fields: APIMessageKind::Text(APIMessageText {
                    content: message.message.unwrap_or_default(),
                }),
            },
            MessageRole::System => {
                return Err(ChatbotError::new(
                    ChatbotErrorType::InvalidMessageShape,
                    "A 'role: system' type ChatbotConversationMessage cannot be saved into the database.",
                    None,
                ));
            }
        };
        Ok(res)
    }
}

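/// The role-specific payload of an [`APIMessage`]. With `#[serde(untagged)]`,
/// serde tries the variants in declaration order when deserializing, so JSON
/// carrying a `content` field matches `Text` first and only messages with
/// `tool_calls` fall through to `ToolCall`.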
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum APIMessageKind {
    Text(APIMessageText),
    ToolCall(APIMessageToolCall),
    ToolResponse(APIMessageToolResponse),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageText {
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolCall {
    pub tool_calls: Vec<APIToolCall>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolResponse {
    pub tool_call_id: String,
    pub name: String,
    pub content: String,
}

impl From<ChatbotConversationMessageToolOutput> for APIMessageToolResponse {
    fn from(value: ChatbotConversationMessageToolOutput) -> Self {
        APIMessageToolResponse {
            tool_call_id: value.tool_call_id,
            name: value.tool_name,
            content: value.tool_output,
        }
    }
}

impl From<APIMessageToolResponse> for ChatbotConversationMessageToolOutput {
    fn from(value: APIMessageToolResponse) -> Self {
        ChatbotConversationMessageToolOutput {
            tool_name: value.name,
            tool_output: value.content,
            tool_call_id: value.tool_call_id,
            ..Default::default()
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIToolCall {
    pub function: APITool,
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: ToolCallType,
}

impl From<ChatbotConversationMessageToolCall> for APIToolCall {
    fn from(value: ChatbotConversationMessageToolCall) -> Self {
        APIToolCall {
            function: APITool {
                arguments: value.tool_arguments.to_string(),
                name: value.tool_name,
            },
            id: value.tool_call_id,
            tool_type: ToolCallType::Function,
        }
    }
}

impl TryFrom<APIToolCall> for ChatbotConversationMessageToolCall {
    type Error = ChatbotError;
    fn try_from(value: APIToolCall) -> ChatbotResult<Self> {
        Ok(ChatbotConversationMessageToolCall {
            tool_name: value.function.name,
            tool_arguments: serde_json::from_str(&value.function.arguments)?,
            tool_call_id: value.id,
            ..Default::default()
        })
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct APITool {
    pub arguments: String,
    pub name: String,
}

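/// The request body sent to the chat completions endpoint: the shared
/// [`LLMRequest`] fields flattened to the top level, plus the `stream` flag.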
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    #[serde(flatten)]
    pub base: LLMRequest,
    pub stream: bool,
}

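/// The subset of a non-streaming completion response that we deserialize:
/// each choice carries one [`APIMessage`].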
#[derive(Deserialize, Debug)]
pub struct LLMCompletionResponse {
    pub choices: Vec<LLMChoice>,
}

#[derive(Deserialize, Debug)]
pub struct LLMChoice {
    pub message: APIMessage,
}

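/// Builds the headers required by the Azure OpenAI API: the `api-key`
/// credential and a JSON `content-type`. Fails if the key contains bytes that
/// are invalid in a header value.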
#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
    trace!("Building LLM request headers");
    let mut headers = HeaderMap::new();
    headers.insert(
        "api-key",
        api_key.parse().map_err(|_e| {
            error!("Failed to parse API key");
            anyhow::anyhow!("Invalid API key")
        })?,
    );
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_e| {
            error!("Failed to parse content-type header");
            anyhow::anyhow!("Internal error")
        })?,
    );
    trace!("Successfully built headers");
    Ok(headers)
}

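/// Sets the `api-version` query parameter on the endpoint URL. Note that
/// `Url::set_query` replaces the entire query string, so any pre-existing
/// query parameters are dropped.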
#[instrument(skip(endpoint))]
pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
    trace!(
        "Preparing Azure endpoint with API version {}",
        LLM_API_VERSION
    );
    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
    trace!("Endpoint prepared: {}", endpoint);
    endpoint
}

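/// A rough heuristic token estimate, not an exact tokenizer. Each character is
/// weighted by its UTF-8 byte length, with multi-byte characters and ASCII
/// punctuation counted double, and the total is divided by four. For example,
/// "Hello, world!" has 11 characters of weight 1 plus ',' and '!' at weight 2
/// each, so 15 / 4 = 3 tokens.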
#[instrument(skip(text), fields(text_length = text.len()))]
pub fn estimate_tokens(text: &str) -> i32 {
    trace!("Estimating tokens for text");
    let text_length = text.chars().fold(0, |acc, c| {
        let mut len = c.len_utf8() as i32;
        if len > 1 {
            len *= 2;
        }
        if c.is_ascii_punctuation() {
            len *= 2;
        }
        acc + len
    });
    let estimated_tokens = text_length / 4;
    trace!("Estimated {} tokens for text", estimated_tokens);
    estimated_tokens
}

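/// Sends a non-streaming completion request to the given endpoint and
/// deserializes the response.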
#[instrument(skip(chat_request, endpoint, api_key), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    chat_request: LLMRequest,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing LLM request with {} messages",
        chat_request.messages.len()
    );

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    process_llm_response(response).await
}

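/// Turns a raw HTTP response into an [`LLMCompletionResponse`], surfacing the
/// response body in the error when the API returns a non-success status.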
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LLMCompletionResponse> {
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    let completion: LLMCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}

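/// Sends a streaming completion request (`stream: true`) to the deployment
/// named by `model_deployment_name` and returns the raw [`Response`] so the
/// caller can consume the event stream. Errors if the Azure or chatbot
/// configuration is missing, or if the API responds with a non-success status.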
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    chat_request: LLMRequest,
    model_deployment_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
    let api_endpoint = chatbot_config
        .api_endpoint
        .join(&(model_deployment_name.to_owned() + "/chat/completions"))?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}

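/// Sends a blocking (non-streaming) completion request and returns the parsed
/// response. Note that the deployment name is currently hardcoded to `gpt-4o`.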
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    chat_request: LLMRequest,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    let api_endpoint = chatbot_config
        .api_endpoint
        .join("gpt-4o/chat/completions")?;

    trace!("Making LLM request to endpoint: {}", api_endpoint);
    make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
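        // Added edge case: integer division floors, so short strings without
        // punctuation or multi-byte characters round down to zero tokens
        // (weight 3 / 4 = 0).
        assert_eq!(estimate_tokens("abc"), 0);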
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
}