use crate::prelude::*;
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace};

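/// The `api-version` query string value sent with every Azure OpenAI request.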
pub const LLM_API_VERSION: &str = "2024-06-01";

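/// The role of a chat message, serialized with the lowercase names used by
/// OpenAI-compatible chat completion APIs.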
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum MessageRole {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
}

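/// A single chat message with its role and text content.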
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
}

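/// The request fields shared by streaming and blocking completion requests.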
#[derive(Serialize, Deserialize, Debug)]
pub struct BaseLlmRequest {
    pub messages: Vec<Message>,
    pub temperature: f32,
    pub max_tokens: Option<i32>,
}

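/// A completion request for the Azure OpenAI API. The base fields are
/// flattened so the serialized JSON matches the shape the API expects.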
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    #[serde(flatten)]
    pub base: BaseLlmRequest,
    pub stream: bool,
}

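/// The parsed body of a successful, non-streaming completion response.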
#[derive(Deserialize, Debug)]
pub struct LlmCompletionResponse {
    pub choices: Vec<LlmChoice>,
}

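/// One candidate completion returned by the model.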
#[derive(Deserialize, Debug)]
pub struct LlmChoice {
    pub message: Message,
}

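/// Builds the headers required by the Azure OpenAI API: the `api-key`
/// authentication header and a JSON content type. Only the key's length is
/// recorded in the tracing span, never the key itself.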
#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
    trace!("Building LLM request headers");
    let mut headers = HeaderMap::new();
    headers.insert(
        "api-key",
        api_key.parse().map_err(|_e| {
            error!("Failed to parse API key");
            anyhow::anyhow!("Invalid API key")
        })?,
    );
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_e| {
            error!("Failed to parse content-type header");
            anyhow::anyhow!("Internal error")
        })?,
    );
    trace!("Successfully built headers");
    Ok(headers)
}

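/// Sets the `api-version` query parameter on the endpoint URL, replacing any
/// existing query string.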
#[instrument(skip(endpoint))]
pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
    trace!(
        "Preparing Azure endpoint with API version {}",
        LLM_API_VERSION
    );
    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
    trace!("Endpoint prepared: {}", endpoint);
    endpoint
}

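/// Roughly estimates how many tokens a text consumes, without calling a real
/// tokenizer. Each character is weighted by its UTF-8 byte length, doubled
/// for multi-byte characters and doubled again for ASCII punctuation (both
/// tend to tokenize less efficiently), and the weighted total is divided by
/// four.
///
/// For example, `"Hello, world!"` weighs 11 + 2 + 2 = 15, which estimates to
/// 15 / 4 = 3 tokens.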
#[instrument(skip(text), fields(text_length = text.len()))]
pub fn estimate_tokens(text: &str) -> i32 {
    trace!("Estimating tokens for text");
    let text_length = text.chars().fold(0, |acc, c| {
        let mut len = c.len_utf8() as i32;
        if len > 1 {
            len *= 2;
        }
        if c.is_ascii_punctuation() {
            len *= 2;
        }
        acc + len
    });
    let estimated_tokens = text_length / 4;
    trace!("Estimated {} tokens for text", estimated_tokens);
    estimated_tokens
}

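/// Sends a non-streaming completion request to the given endpoint and parses
/// the response.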
#[instrument(skip(messages, endpoint, api_key), fields(
    num_messages = messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: Option<i32>,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LlmCompletionResponse> {
    debug!("Preparing LLM request with {} messages", messages.len());
    let base_request = BaseLlmRequest {
        messages,
        temperature,
        max_tokens,
    };

    trace!("Base request prepared: {:?}", base_request);

    let request = AzureCompletionRequest {
        base: base_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    process_llm_response(response).await
}

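/// Turns an HTTP response into a parsed completion, converting non-success
/// statuses into errors that include the response body.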
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LlmCompletionResponse> {
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    let completion: LlmCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}

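/// Starts a streaming completion request and returns the raw [`Response`] so
/// the caller can consume the event stream. Fails early if the Azure or
/// chatbot configuration is missing, or if the API responds with an error
/// status.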
#[instrument(skip(messages, app_config), fields(
    num_messages = messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: Option<i32>,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    let base_request = BaseLlmRequest {
        messages,
        temperature,
        max_tokens,
    };

    trace!("Base request prepared: {:?}", base_request);

    let request = AzureCompletionRequest {
        base: base_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        chatbot_config.api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(chatbot_config.api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}

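/// Makes a blocking (non-streaming) LLM request using the chatbot
/// configuration from [`ApplicationConfiguration`] and returns the parsed
/// completion.
///
/// A minimal usage sketch (assumes `app_config` has the Azure and chatbot
/// configuration populated; not compiled as a doctest):
///
/// ```ignore
/// let completion = make_blocking_llm_request(
///     vec![Message {
///         role: MessageRole::User,
///         content: "Hello!".to_string(),
///     }],
///     0.7,
///     Some(256),
///     &app_config,
/// )
/// .await?;
/// let answer = &completion.choices[0].message.content;
/// ```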
#[instrument(skip(messages, app_config), fields(
    num_messages = messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: Option<i32>,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<LlmCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!(
        "Making LLM request to endpoint: {}",
        chatbot_config.api_endpoint
    );
    make_llm_request(
        messages,
        temperature,
        max_tokens,
        &chatbot_config.api_endpoint,
        &chatbot_config.api_key,
    )
    .await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // Plain ASCII, with punctuation weighted double.
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        // Multi-byte characters are weighted double their UTF-8 byte length.
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
}