use crate::{azure_chatbot::LLMRequest, prelude::*};
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace};

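/// The Azure OpenAI REST API version sent with every request as the
/// `api-version` query parameter.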
pub const LLM_API_VERSION: &str = "2024-06-01";

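/// Who authored a chat message, serialized in the lowercase form expected by
/// the chat completions API.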
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum MessageRole {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
}

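/// A single chat message as sent to and received from the LLM API.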
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessage {
    pub role: MessageRole,
    pub content: String,
}

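/// Request body for the Azure chat completions endpoint: the shared
/// [`LLMRequest`] fields flattened alongside the `stream` flag.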
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    #[serde(flatten)]
    pub base: LLMRequest,
    pub stream: bool,
}

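/// Non-streaming completion response; only the `choices` field is deserialized.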
#[derive(Deserialize, Debug)]
pub struct LLMCompletionResponse {
    pub choices: Vec<LLMChoice>,
}

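/// One generated completion choice within a response.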
#[derive(Deserialize, Debug)]
pub struct LLMChoice {
    pub message: APIMessage,
}

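/// Builds the headers required by the Azure LLM API: the `api-key` header and
/// a JSON `content-type`. The key itself is skipped from the tracing span;
/// only its length is recorded.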
#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
    trace!("Building LLM request headers");
    let mut headers = HeaderMap::new();
    headers.insert(
        "api-key",
        api_key.parse().map_err(|_e| {
            error!("Failed to parse API key");
            anyhow::anyhow!("Invalid API key")
        })?,
    );
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_e| {
            error!("Failed to parse content-type header");
            anyhow::anyhow!("Internal error")
        })?,
    );
    trace!("Successfully built headers");
    Ok(headers)
}

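/// Appends the pinned `api-version` query parameter to the endpoint. Note
/// that `Url::set_query` replaces any query string already present.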
#[instrument(skip(endpoint))]
pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
    trace!(
        "Preparing Azure endpoint with API version {}",
        LLM_API_VERSION
    );
    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
    trace!("Endpoint prepared: {}", endpoint);
    endpoint
}

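/// Estimates the number of LLM tokens in `text` with a cheap heuristic: each
/// character contributes its UTF-8 byte length, doubled for multi-byte
/// characters and for ASCII punctuation, and the weighted sum is divided by
/// four. For example, "Hello, world!" scores 11 plain characters plus 2 × 2
/// for the punctuation, giving 15 / 4 = 3 tokens. This is a rough budgeting
/// aid, not an exact tokenizer.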
#[instrument(skip(text), fields(text_length = text.len()))]
pub fn estimate_tokens(text: &str) -> i32 {
    trace!("Estimating tokens for text");
    let text_length = text.chars().fold(0, |acc, c| {
        let mut len = c.len_utf8() as i32;
        if len > 1 {
            len *= 2;
        }
        if c.is_ascii_punctuation() {
            len *= 2;
        }
        acc + len
    });
    let estimated_tokens = text_length / 4;
    trace!("Estimated {} tokens for text", estimated_tokens);
    estimated_tokens
}

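/// Sends a non-streaming completion request to `endpoint` and parses the JSON
/// body into an [`LLMCompletionResponse`].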
#[instrument(skip(chat_request, endpoint, api_key), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    chat_request: LLMRequest,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing LLM request with {} messages",
        chat_request.messages.len()
    );

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    process_llm_response(response).await
}

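/// Converts an HTTP response from the LLM API into an
/// [`LLMCompletionResponse`], turning non-success statuses into errors that
/// include the response body.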
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LLMCompletionResponse> {
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    let completion: LLMCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}

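/// Starts a streaming completion request against the deployment named by
/// `model_deployment_name` and returns the raw [`Response`] so the caller can
/// consume the event stream. Fails early if the Azure or chatbot
/// configuration is missing.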
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    chat_request: LLMRequest,
    model_deployment_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
    let api_endpoint = chatbot_config
        .api_endpoint
        .join(&(model_deployment_name.to_owned() + "/chat/completions"))?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}

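/// Sends a blocking (non-streaming) completion request using the configured
/// chatbot credentials. The deployment name is currently hard-coded to
/// `gpt-4o`.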
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    chat_request: LLMRequest,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    let api_endpoint = chatbot_config
        .api_endpoint
        .join("gpt-4o/chat/completions")?;

    trace!("Making LLM request to endpoint: {}", api_endpoint);
    make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
}

#[cfg(test)]
mod tests {
    use super::*;

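    // Expected counts follow the heuristic above: single-byte characters weigh
    // one, multi-byte characters weigh twice their UTF-8 length, ASCII
    // punctuation weighs double, and the weighted sum is divided by four.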
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
}