1use crate::prelude::*;
2use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
3
/// Value sent as the `api-version` query parameter on the search service requests built in this
/// file.
const API_VERSION: &str = "2024-07-01";
5
/// JSON request body describing a search index to create.
///
/// `create_search_index` fills in one of these, serializes it with `serde_json` and POSTs it to
/// the `/indexes` path of the configured search endpoint. Keys are emitted in camelCase, and
/// `Option` fields that are `None` are serialized as JSON `null` (no `skip_serializing_if` is
/// used).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NewIndex {
    /// Name of the index.
    name: String,
    /// Field definitions of the index.
    fields: Vec<Field>,
    /// Scoring profiles; `create_search_index` sends an empty list.
    scoring_profiles: Vec<ScoringProfile>,
    /// Name of the scoring profile to use by default; `create_search_index` sends `None`.
    default_scoring_profile: Option<String>,
    /// Suggesters; `create_search_index` sends an empty list.
    suggesters: Vec<Suggester>,
    /// Custom analyzers; `create_search_index` sends an empty list.
    analyzers: Vec<Analyzer>,
    /// Untyped JSON entries; `create_search_index` sends an empty list.
    tokenizers: Vec<serde_json::Value>,
    /// Untyped JSON entries; `create_search_index` sends an empty list.
    token_filters: Vec<serde_json::Value>,
    /// Untyped JSON entries; `create_search_index` sends an empty list.
    char_filters: Vec<serde_json::Value>,
    /// CORS settings for the index.
    cors_options: CorsOptions,
    /// Optional encryption key description; `create_search_index` sends `None`.
    encryption_key: Option<EncryptionKey>,
    /// Similarity algorithm selection.
    similarity: Similarity,
    /// Semantic configurations and the name of the default one.
    semantic: Semantic,
    /// Vector search profiles, algorithms, compressions and vectorizers.
    vector_search: VectorSearch,
}
24
/// Entry type of `NewIndex::analyzers`.
///
/// No value of this type is constructed in the visible code: `create_search_index` sends an
/// empty analyzer list.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Analyzer {
    name: String,
    /// Serialized under the literal key `@odata.type` instead of the camelCase field name.
    #[serde(rename = "@odata.type")]
    odata_type: String,
    /// Serialized as `charFilters`.
    char_filters: Vec<String>,
    tokenizer: String,
}
34
/// CORS settings embedded in `NewIndex` (serialized as `corsOptions`).
///
/// `create_search_index` sends `allowedOrigins: ["*"]` and `maxAgeInSeconds: 300`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CorsOptions {
    allowed_origins: Vec<String>,
    max_age_in_seconds: i64,
}
41
/// Optional `encryptionKey` section of `NewIndex`, serialized with camelCase keys.
///
/// Not constructed in the visible code: `create_search_index` sends `None` for it.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptionKey {
    key_vault_key_name: String,
    key_vault_key_version: String,
    key_vault_uri: String,
    /// Credentials nested inside the encryption key section.
    access_credentials: AccessCredentials,
}
50
/// Credentials nested inside `EncryptionKey` (serialized as `accessCredentials`).
///
/// Both values are written verbatim into the JSON, including the secret.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AccessCredentials {
    application_id: String,
    application_secret: String,
}
57
/// Definition of a single index field, as listed in `NewIndex::fields`.
///
/// Keys are serialized in camelCase. Every `Option` left as `None` is emitted as JSON `null`
/// rather than omitted.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Field {
    /// Field name, e.g. `chunk_id` or `text_vector` in `create_search_index`.
    name: String,
    /// Serialized as `type`; `create_search_index` uses `Edm.String` and
    /// `Collection(Edm.Single)`.
    #[serde(rename = "type")]
    field_type: String,
    /// `create_search_index` sets this to `true` only for `chunk_id`.
    key: Option<bool>,
    searchable: Option<bool>,
    filterable: Option<bool>,
    sortable: Option<bool>,
    facetable: Option<bool>,
    retrievable: Option<bool>,
    index_analyzer: Option<String>,
    search_analyzer: Option<String>,
    /// `create_search_index` sets `"keyword"` on `chunk_id` and leaves it unset elsewhere.
    analyzer: Option<String>,
    synonym_maps: Option<Vec<String>>,
    /// Set (to 1536) only on the `text_vector` field in `create_search_index`.
    dimensions: Option<i64>,
    /// Name of a `Profile` from `VectorSearch::profiles`; set only on `text_vector`.
    vector_search_profile: Option<String>,
    stored: Option<bool>,
    /// Untyped JSON; always `None` in the visible code.
    vector_encoding: Option<serde_json::Value>,
}
79
/// Entry type of `NewIndex::scoring_profiles`.
///
/// Not constructed in the visible code: `create_search_index` sends an empty list. All field
/// names are single words, so no `rename_all` is needed for the JSON keys.
#[derive(Serialize, Deserialize)]
pub struct ScoringProfile {
    name: String,
    /// Text weights of the profile.
    text: Text,
    /// Scoring functions of the profile.
    functions: Vec<Function>,
}
86
/// Scoring function entry of `ScoringProfile::functions`, serialized with camelCase keys.
///
/// NOTE(review): `distance` is mandatory here, so every serialized function carries a distance
/// section — presumably only distance-type functions were modelled; confirm before using this
/// for other function types.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    function_type: String,
    boost: i64,
    field_name: String,
    interpolation: String,
    distance: Distance,
}
97
/// `distance` section of a scoring `Function`, serialized as `referencePointParameter` and
/// `boostingDistance`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Distance {
    reference_point_parameter: String,
    boosting_distance: i64,
}
104
/// `text` section of a `ScoringProfile`; a wrapper around the `weights` object.
#[derive(Serialize, Deserialize)]
pub struct Text {
    weights: Weights,
}
109
/// `weights` object of a scoring profile's `Text` section.
///
/// NOTE(review): the only key this can express is `hotelName`, which does not match any field
/// that `create_search_index` defines — it looks carried over from an API sample. Confirm whether
/// this should instead be a map keyed by field name before relying on it.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Weights {
    /// Serialized as `hotelName`.
    hotel_name: i64,
}
115
/// `semantic` section of `NewIndex`.
///
/// `create_search_index` sends exactly one configuration and names that same configuration as
/// the default.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Semantic {
    /// Serialized as `defaultConfiguration`; the name of one entry in `configurations`.
    default_configuration: String,
    configurations: Vec<SemanticConfiguration>,
}
122
/// A named semantic configuration listed in `Semantic::configurations`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SemanticConfiguration {
    /// `create_search_index` uses `<index name>-semantic-configuration`.
    name: String,
    /// Serialized as `prioritizedFields`.
    prioritized_fields: SemanticConfigurationPrioritizedFields,
}
129
/// `prioritizedFields` section of a `SemanticConfiguration`.
///
/// `create_search_index` points the title at the `title` field, the content at the `chunk` field
/// and sends no keywords fields.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SemanticConfigurationPrioritizedFields {
    title_field: FieldDescriptor,
    prioritized_content_fields: Vec<FieldDescriptor>,
    prioritized_keywords_fields: Vec<FieldDescriptor>,
}
137
/// Reference to an index field by name, serialized as `{"fieldName": "..."}`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FieldDescriptor {
    field_name: String,
}
143
/// `similarity` section of `NewIndex`.
///
/// `create_search_index` sends `#Microsoft.Azure.Search.BM25Similarity` as the type and leaves
/// both parameters as `None` (serialized as `null`).
#[derive(Serialize, Deserialize)]
pub struct Similarity {
    /// Serialized under the literal key `@odata.type`.
    #[serde(rename = "@odata.type")]
    odata_type: String,
    b: Option<f64>,
    k1: Option<f64>,
}
151
/// Entry type of `NewIndex::suggesters`, serialized with camelCase keys.
///
/// Not constructed in the visible code: `create_search_index` sends an empty list.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Suggester {
    name: String,
    search_mode: String,
    source_fields: Vec<String>,
}
159
/// `vectorSearch` section of `NewIndex`.
///
/// `create_search_index` sends one profile, one algorithm, one vectorizer and no compressions.
/// All field names are single words, so the JSON keys equal the Rust names.
#[derive(Serialize, Deserialize)]
pub struct VectorSearch {
    profiles: Vec<Profile>,
    algorithms: Vec<Algorithm>,
    compressions: Vec<Compression>,
    vectorizers: Vec<Vectorizer>,
}
167
/// Entry of `VectorSearch::vectorizers`.
///
/// `create_search_index` sends a single vectorizer of kind `azureOpenAI`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Vectorizer {
    name: String,
    kind: String,
    /// Renamed explicitly because the camelCase conversion would produce
    /// `azureOpenAiParameters` rather than `azureOpenAIParameters`.
    #[serde(rename = "azureOpenAIParameters")]
    azure_open_ai_parameters: AzureOpenAiParameters,
    /// Untyped JSON; `create_search_index` sends `None`.
    custom_web_api_parameters: Option<serde_json::Value>,
}
177
/// Connection details of an `azureOpenAI` vectorizer, serialized with camelCase keys.
///
/// `create_search_index` copies the four string values from the search configuration
/// (`vectorizer_resource_uri`, `vectorizer_deployment_id`, `vectorizer_api_key`,
/// `vectorizer_model_name`). The API key is written verbatim into the request body.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AzureOpenAiParameters {
    resource_uri: String,
    deployment_id: String,
    api_key: String,
    model_name: String,
    /// Untyped JSON; `create_search_index` sends `None`.
    auth_identity: Option<serde_json::Value>,
}
187
/// Entry of `VectorSearch::algorithms`.
///
/// Carries one optional parameter section per algorithm kind. `create_search_index` sends kind
/// `hnsw` with `hnsw_parameters` set and `exhaustive_knn_parameters` left as `None`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Algorithm {
    /// Referenced by name from `Profile::algorithm`.
    name: String,
    kind: String,
    hnsw_parameters: Option<HnswParameters>,
    exhaustive_knn_parameters: Option<ExhaustiveKnnParameters>,
}
196
/// Parameter section stored in `Algorithm::exhaustive_knn_parameters`.
///
/// Not constructed in the visible code.
#[derive(Serialize, Deserialize)]
pub struct ExhaustiveKnnParameters {
    metric: String,
}
201
/// Parameter section stored in `Algorithm::hnsw_parameters`, serialized with camelCase keys.
///
/// `create_search_index` sends `m: 4`, `metric: "cosine"`, `efConstruction: 400` and
/// `efSearch: 500`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HnswParameters {
    m: i64,
    metric: String,
    ef_construction: i64,
    ef_search: i64,
}
210
/// Entry of `VectorSearch::compressions`, serialized with camelCase keys.
///
/// Not constructed in the visible code: `create_search_index` sends an empty list.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Compression {
    /// Presumably the name a `Profile::compression` refers to — confirm against the API docs.
    name: String,
    kind: String,
    scalar_quantization_parameters: Option<ScalarQuantizationParameters>,
    rerank_with_original_vectors: bool,
    default_oversampling: i64,
}
220
/// Parameter section stored in `Compression::scalar_quantization_parameters`, serialized as
/// `{"quantizedDataType": "..."}`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ScalarQuantizationParameters {
    quantized_data_type: String,
}
226
/// Entry of `VectorSearch::profiles`, tying an algorithm, and optionally a compression and a
/// vectorizer, together by name.
///
/// `create_search_index` creates one profile whose name is also stored in the `text_vector`
/// field's `vector_search_profile`.
#[derive(Serialize, Deserialize)]
pub struct Profile {
    name: String,
    /// Name of an entry in `VectorSearch::algorithms`.
    algorithm: String,
    /// `create_search_index` sends `None`.
    compression: Option<String>,
    /// Name of an entry in `VectorSearch::vectorizers`.
    vectorizer: Option<String>,
}
234
235pub async fn does_search_index_exist(
236 index_name: &str,
237 app_config: &ApplicationConfiguration,
238) -> anyhow::Result<bool> {
239 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
240 anyhow::anyhow!("Azure configuration is missing from the application configuration")
241 })?;
242
243 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
244 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
245 })?;
246
247 let mut url = search_config.search_endpoint.clone();
248 url.set_path(&format!("indexes('{}')", index_name));
249 url.set_query(Some(&format!("api-version={}", API_VERSION)));
250
251 let response = REQWEST_CLIENT
252 .get(url)
253 .header("Content-Type", "application/json")
254 .header("api-key", search_config.search_api_key.clone())
255 .send()
256 .await?;
257
258 if response.status().is_success() {
259 Ok(true)
260 } else if response.status() == 404 {
261 Ok(false)
262 } else {
263 let status = response.status();
264 let error_text = response.text().await?;
265 Err(anyhow::anyhow!(
266 "Error checking if index exists. Status: {}. Error: {}",
267 status,
268 error_text
269 ))
270 }
271}
272
273pub async fn create_search_index(
274 index_name: String,
275 app_config: &ApplicationConfiguration,
276) -> anyhow::Result<()> {
277 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
278 anyhow::anyhow!("Azure configuration is missing from the application configuration")
279 })?;
280
281 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
282 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
283 })?;
284
285 let fields = vec![
286 Field {
287 name: "chunk_id".to_string(),
288 field_type: "Edm.String".to_string(),
289 key: Some(true),
290 searchable: Some(true),
291 filterable: Some(true),
292 retrievable: Some(true),
293 stored: Some(true),
294 sortable: Some(true),
295 facetable: Some(true),
296 analyzer: Some("keyword".to_string()),
297 index_analyzer: None,
298 search_analyzer: None,
299 synonym_maps: Some(vec![]),
300 dimensions: None,
301 vector_search_profile: None,
302 vector_encoding: None,
303 },
304 Field {
305 name: "language".to_string(),
306 field_type: "Edm.String".to_string(),
307 key: Some(false),
308 searchable: Some(true),
309 filterable: Some(true),
310 retrievable: Some(true),
311 stored: Some(true),
312 sortable: Some(false),
313 facetable: Some(false),
314 analyzer: None,
315 index_analyzer: None,
316 search_analyzer: None,
317 synonym_maps: Some(vec![]),
318 dimensions: None,
319 vector_search_profile: None,
320 vector_encoding: None,
321 },
322 Field {
323 name: "parent_id".to_string(),
324 field_type: "Edm.String".to_string(),
325 key: Some(false),
326 searchable: Some(true),
327 filterable: Some(true),
328 retrievable: Some(true),
329 stored: Some(true),
330 sortable: Some(true),
331 facetable: Some(true),
332 analyzer: None,
333 index_analyzer: None,
334 search_analyzer: None,
335 synonym_maps: Some(vec![]),
336 dimensions: None,
337 vector_search_profile: None,
338 vector_encoding: None,
339 },
340 Field {
341 name: "chunk".to_string(),
342 field_type: "Edm.String".to_string(),
343 key: Some(false),
344 searchable: Some(true),
345 filterable: Some(false),
346 retrievable: Some(true),
347 stored: Some(true),
348 sortable: Some(false),
349 facetable: Some(false),
350 analyzer: None,
351 index_analyzer: None,
352 search_analyzer: None,
353 synonym_maps: Some(vec![]),
354 dimensions: None,
355 vector_search_profile: None,
356 vector_encoding: None,
357 },
358 Field {
359 name: "title".to_string(),
360 field_type: "Edm.String".to_string(),
361 key: Some(false),
362 searchable: Some(true),
363 filterable: Some(true),
364 retrievable: Some(true),
365 stored: Some(true),
366 sortable: Some(false),
367 facetable: Some(false),
368 analyzer: None,
369 index_analyzer: None,
370 search_analyzer: None,
371 synonym_maps: Some(vec![]),
372 dimensions: None,
373 vector_search_profile: None,
374 vector_encoding: None,
375 },
376 Field {
377 name: "url".to_string(),
378 field_type: "Edm.String".to_string(),
379 key: Some(false),
380 searchable: Some(false),
381 filterable: Some(true),
382 retrievable: Some(true),
383 stored: Some(true),
384 sortable: Some(false),
385 facetable: Some(false),
386 analyzer: None,
387 index_analyzer: None,
388 search_analyzer: None,
389 synonym_maps: Some(vec![]),
390 dimensions: None,
391 vector_search_profile: None,
392 vector_encoding: None,
393 },
394 Field {
395 name: "course_id".to_string(),
396 field_type: "Edm.String".to_string(),
397 key: Some(false),
398 searchable: Some(false),
399 filterable: Some(true),
400 retrievable: Some(true),
401 stored: Some(true),
402 sortable: Some(false),
403 facetable: Some(false),
404 analyzer: None,
405 index_analyzer: None,
406 search_analyzer: None,
407 synonym_maps: Some(vec![]),
408 dimensions: None,
409 vector_search_profile: None,
410 vector_encoding: None,
411 },
412 Field {
413 name: "text_vector".to_string(),
414 field_type: "Collection(Edm.Single)".to_string(),
415 key: Some(false),
416 searchable: Some(true),
417 filterable: Some(false),
418 retrievable: Some(true),
419 stored: Some(true),
420 sortable: Some(false),
421 facetable: Some(false),
422 analyzer: None,
423 index_analyzer: None,
424 search_analyzer: None,
425 synonym_maps: Some(vec![]),
426 dimensions: Some(1536),
427 vector_search_profile: Some(format!("{}-azureOpenAi-text-profile", index_name)),
428 vector_encoding: None,
429 },
430 ];
431
432 let index = NewIndex {
433 name: index_name.clone(),
434 fields,
435 scoring_profiles: vec![],
436 default_scoring_profile: None,
437 suggesters: vec![],
438 analyzers: vec![],
439 tokenizers: vec![],
440 token_filters: vec![],
441 char_filters: vec![],
442 cors_options: CorsOptions {
443 allowed_origins: vec!["*".to_string()],
444 max_age_in_seconds: 300,
445 },
446 encryption_key: None,
447 similarity: Similarity {
448 odata_type: "#Microsoft.Azure.Search.BM25Similarity".to_string(),
449 b: None,
450 k1: None,
451 },
452 semantic: Semantic {
453 default_configuration: format!("{}-semantic-configuration", index_name),
454 configurations: vec![SemanticConfiguration {
455 name: format!("{}-semantic-configuration", index_name),
456 prioritized_fields: SemanticConfigurationPrioritizedFields {
457 title_field: FieldDescriptor {
458 field_name: "title".to_string(),
459 },
460 prioritized_content_fields: vec![FieldDescriptor {
461 field_name: "chunk".to_string(),
462 }],
463 prioritized_keywords_fields: vec![],
464 },
465 }],
466 },
467 vector_search: VectorSearch {
468 profiles: vec![Profile {
469 name: format!("{}-azureOpenAi-text-profile", index_name),
470 algorithm: format!("{}-algorithm", index_name),
471 vectorizer: Some(format!("{}-azureOpenAi-text-vectorizer", index_name)),
472 compression: None,
473 }],
474 algorithms: vec![Algorithm {
475 name: format!("{}-algorithm", index_name),
476 kind: "hnsw".to_string(),
477 hnsw_parameters: Some(HnswParameters {
478 m: 4,
479 metric: "cosine".to_string(),
480 ef_construction: 400,
481 ef_search: 500,
482 }),
483 exhaustive_knn_parameters: None,
484 }],
485 vectorizers: vec![Vectorizer {
486 name: format!("{}-azureOpenAi-text-vectorizer", index_name),
487 kind: "azureOpenAI".to_string(),
488 azure_open_ai_parameters: AzureOpenAiParameters {
489 resource_uri: search_config.vectorizer_resource_uri.clone(),
490 deployment_id: search_config.vectorizer_deployment_id.clone(),
491 api_key: search_config.vectorizer_api_key.clone(),
492 model_name: search_config.vectorizer_model_name.clone(),
493 auth_identity: None,
494 },
495 custom_web_api_parameters: None,
496 }],
497 compressions: vec![],
498 },
499 };
500
501 let index_json = serde_json::to_string(&index)?;
502
503 let mut url = search_config.search_endpoint.clone();
504 url.set_path("/indexes");
505 url.set_query(Some(&format!("api-version={}", API_VERSION)));
506
507 let response = REQWEST_CLIENT
508 .post(url)
509 .header("Content-Type", "application/json")
510 .header("api-key", search_config.search_api_key.clone())
511 .body(index_json)
512 .send()
513 .await?;
514
515 if response.status().is_success() {
517 println!("Index created successfully: {}", index_name);
518 Ok(())
519 } else {
520 let status = response.status();
521 let error_text = response.text().await?;
522 Err(anyhow::anyhow!(
523 "Failed to create index. Status: {}. Error: {}",
524 status,
525 error_text
526 ))
527 }
528}
529
530#[derive(Serialize, Deserialize)]
531pub struct IndexAction<T> {
532 #[serde(rename = "@search.action")]
533 pub search_action: String,
534 pub document: T,
535}
536
537#[derive(Serialize, Deserialize)]
538pub struct IndexBatch<T> {
539 pub value: Vec<IndexAction<T>>,
540}
541
542pub async fn add_documents_to_index<T>(
543 index_name: &str,
544 documents: Vec<T>,
545 app_config: &ApplicationConfiguration,
546) -> anyhow::Result<()>
547where
548 T: Serialize,
549{
550 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
551 anyhow::anyhow!("Azure configuration is missing from the application configuration")
552 })?;
553
554 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
555 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
556 })?;
557
558 let mut url = search_config.search_endpoint.clone();
559 url.set_path(&format!("indexes('{}')/docs/index", index_name));
560 url.set_query(Some(&format!("api-version={}", API_VERSION)));
561
562 let index_actions: Vec<IndexAction<String>> = documents
563 .into_iter()
564 .map(|doc| IndexAction {
565 search_action: "upload".to_string(),
566 document: serde_json::to_string(&doc).unwrap(),
567 })
568 .collect();
569
570 let batch = IndexBatch {
571 value: index_actions,
572 };
573
574 let batch_json = serde_json::to_string(&batch)?;
575
576 let response = REQWEST_CLIENT
577 .post(url)
578 .header("Content-Type", "application/json")
579 .header("api-key", search_config.search_api_key.clone())
580 .body(batch_json)
581 .send()
582 .await?;
583
584 if response.status().is_success() {
585 println!("Documents added successfully to index: {}", index_name);
586 Ok(())
587 } else {
588 let status = response.status();
589 let error_text = response.text().await?;
590 Err(anyhow::anyhow!(
591 "Failed to add documents to index. Status: {}. Error: {}",
592 status,
593 error_text
594 ))
595 }
596}