1use crate::prelude::*;
2use headless_lms_base::config::ApplicationConfiguration;
3use headless_lms_utils::http::REQWEST_CLIENT;
4
/// Value of the `api-version` query parameter appended to the search endpoint URLs built by
/// the request functions in this module.
const API_VERSION: &str = "2024-07-01";
6
/// Definition of a search index, used as the JSON request body when creating one.
///
/// `create_search_index` builds a value of this type, serializes it with `serde_json` and POSTs
/// it to the `/indexes` path of the configured search endpoint. Keys are emitted in camelCase
/// (for example `scoringProfiles`, `corsOptions`, `vectorSearch`).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NewIndex {
    // Name of the index.
    name: String,
    // Field (column) definitions of the index.
    fields: Vec<Field>,
    scoring_profiles: Vec<ScoringProfile>,
    // `None` is serialized as JSON `null`.
    default_scoring_profile: Option<String>,
    suggesters: Vec<Suggester>,
    analyzers: Vec<Analyzer>,
    // The next three lists are kept as untyped JSON values; no schema is modelled for them here.
    tokenizers: Vec<serde_json::Value>,
    token_filters: Vec<serde_json::Value>,
    char_filters: Vec<serde_json::Value>,
    cors_options: CorsOptions,
    // `None` is serialized as JSON `null`.
    encryption_key: Option<EncryptionKey>,
    similarity: Similarity,
    semantic: Semantic,
    vector_search: VectorSearch,
}
25
/// One entry of the `analyzers` list of [`NewIndex`].
///
/// Keys are camelCase, except `odata_type`, which is serialized under the literal key
/// `@odata.type`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Analyzer {
    name: String,
    // Explicit rename: the key contains characters that a naming convention cannot produce.
    #[serde(rename = "@odata.type")]
    odata_type: String,
    char_filters: Vec<String>,
    tokenizer: String,
}
35
/// The `corsOptions` section of [`NewIndex`].
///
/// Serialized with the keys `allowedOrigins` and `maxAgeInSeconds`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CorsOptions {
    allowed_origins: Vec<String>,
    max_age_in_seconds: i64,
}
42
/// The optional `encryptionKey` section of [`NewIndex`].
///
/// Keys are camelCase (`keyVaultKeyName`, `keyVaultKeyVersion`, `keyVaultUri`,
/// `accessCredentials`).
// NOTE(review): the field names suggest a reference to a key stored in a key vault; the exact
// semantics are defined by the search service and are not visible in this file — confirm.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptionKey {
    key_vault_key_name: String,
    key_vault_key_version: String,
    key_vault_uri: String,
    access_credentials: AccessCredentials,
}
51
/// The `accessCredentials` object nested inside [`EncryptionKey`].
///
/// Both values, including `application_secret`, are written verbatim into the serialized JSON
/// (keys `applicationId` and `applicationSecret`).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AccessCredentials {
    application_id: String,
    application_secret: String,
}
58
/// One entry of the `fields` list of [`NewIndex`].
///
/// Keys are camelCase, except `field_type`, which is serialized as `type`. No
/// `skip_serializing_if` is configured, so every optional attribute left as `None` is emitted as
/// JSON `null` rather than omitted.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Field {
    name: String,
    // `type` is a Rust keyword, hence the differently named field with an explicit rename.
    // `create_search_index` uses the values `Edm.String` and `Collection(Edm.Single)`.
    #[serde(rename = "type")]
    field_type: String,
    key: Option<bool>,
    searchable: Option<bool>,
    filterable: Option<bool>,
    sortable: Option<bool>,
    facetable: Option<bool>,
    retrievable: Option<bool>,
    index_analyzer: Option<String>,
    search_analyzer: Option<String>,
    analyzer: Option<String>,
    synonym_maps: Option<Vec<String>>,
    // Only set (to 1536) for the `text_vector` field in `create_search_index`.
    dimensions: Option<i64>,
    // Name of an entry in `VectorSearch::profiles`; only set for the `text_vector` field.
    vector_search_profile: Option<String>,
    stored: Option<bool>,
    // Untyped JSON; always `None` in `create_search_index`.
    vector_encoding: Option<serde_json::Value>,
}
80
/// One entry of the `scoring_profiles` list of [`NewIndex`].
///
/// No `rename_all` is applied; every field name is a single word, so the JSON keys equal the
/// field names.
#[derive(Serialize, Deserialize)]
pub struct ScoringProfile {
    name: String,
    text: Text,
    functions: Vec<Function>,
}
87
/// One entry of the `functions` list of a [`ScoringProfile`].
///
/// Keys are camelCase, except `function_type`, which is serialized as `type`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    // `type` is a Rust keyword, hence the explicit rename.
    #[serde(rename = "type")]
    function_type: String,
    boost: i64,
    field_name: String,
    interpolation: String,
    // NOTE(review): `distance` is mandatory here, so only functions that carry a distance
    // object can be represented — confirm this is intended.
    distance: Distance,
}
98
/// The `distance` object nested inside a scoring [`Function`].
///
/// Serialized with the keys `referencePointParameter` and `boostingDistance`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Distance {
    reference_point_parameter: String,
    boosting_distance: i64,
}
105
/// The `text` object of a [`ScoringProfile`]; a wrapper around the per-field [`Weights`].
#[derive(Serialize, Deserialize)]
pub struct Text {
    weights: Weights,
}
110
/// The `weights` object nested inside [`Text`].
///
/// Its only field is serialized under the key `hotelName`.
// NOTE(review): `create_search_index` defines no field called `hotelName`, and it creates the
// index with an empty scoring-profile list, so this looks carried over from a sample schema.
// Confirm the intended field set before using scoring profiles.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Weights {
    hotel_name: i64,
}
116
/// The `semantic` section of [`NewIndex`].
///
/// Serialized with the keys `defaultConfiguration` and `configurations`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Semantic {
    // `create_search_index` sets this to the name of the single entry in `configurations`.
    default_configuration: String,
    configurations: Vec<SemanticConfiguration>,
}
123
/// One entry of the `configurations` list of [`Semantic`]: a name plus its prioritized fields.
///
/// Serialized with the keys `name` and `prioritizedFields`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SemanticConfiguration {
    name: String,
    prioritized_fields: SemanticConfigurationPrioritizedFields,
}
130
/// The `prioritizedFields` object of a [`SemanticConfiguration`].
///
/// Serialized with the keys `titleField`, `prioritizedContentFields` and
/// `prioritizedKeywordsFields`. `create_search_index` uses `title` as the title field, `chunk`
/// as the only content field and no keyword fields.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SemanticConfigurationPrioritizedFields {
    title_field: FieldDescriptor,
    prioritized_content_fields: Vec<FieldDescriptor>,
    prioritized_keywords_fields: Vec<FieldDescriptor>,
}
138
/// Reference to an index field by name, serialized as `{"fieldName": "..."}`.
///
/// Used inside [`SemanticConfigurationPrioritizedFields`].
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FieldDescriptor {
    field_name: String,
}
144
/// The `similarity` section of [`NewIndex`].
///
/// `odata_type` is serialized under the literal key `@odata.type`; `b` and `k1` keep their
/// names and are emitted as JSON `null` when `None`. `create_search_index` uses
/// `#Microsoft.Azure.Search.BM25Similarity` and leaves both parameters unset.
#[derive(Serialize, Deserialize)]
pub struct Similarity {
    #[serde(rename = "@odata.type")]
    odata_type: String,
    b: Option<f64>,
    k1: Option<f64>,
}
152
/// One entry of the `suggesters` list of [`NewIndex`].
///
/// Serialized with the keys `name`, `searchMode` and `sourceFields`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Suggester {
    name: String,
    search_mode: String,
    source_fields: Vec<String>,
}
160
/// The `vectorSearch` section of [`NewIndex`].
///
/// Groups the vector search profiles with the algorithms, compressions and vectorizers that
/// the profiles refer to by name. All field names are single words, so the JSON keys equal the
/// field names.
#[derive(Serialize, Deserialize)]
pub struct VectorSearch {
    profiles: Vec<Profile>,
    algorithms: Vec<Algorithm>,
    compressions: Vec<Compression>,
    vectorizers: Vec<Vectorizer>,
}
168
/// One entry of the `vectorizers` list of [`VectorSearch`].
///
/// `create_search_index` creates a single vectorizer of kind `azureOpenAI`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Vectorizer {
    name: String,
    kind: String,
    // Explicit rename: the camelCase conversion would produce `azureOpenAiParameters`, while the
    // key written here spells `AI` in upper case.
    #[serde(rename = "azureOpenAIParameters")]
    azure_open_ai_parameters: AzureOpenAiParameters,
    // Untyped JSON; `None` is serialized as JSON `null`.
    custom_web_api_parameters: Option<serde_json::Value>,
}
178
/// The `azureOpenAIParameters` object of a [`Vectorizer`].
///
/// `create_search_index` fills the four string fields from the `vectorizer_*` values of the
/// search configuration. Keys are camelCase (`resourceUri`, `deploymentId`, `apiKey`,
/// `modelName`, `authIdentity`); the API key is written verbatim into the serialized JSON.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AzureOpenAiParameters {
    resource_uri: String,
    deployment_id: String,
    api_key: String,
    model_name: String,
    // Untyped JSON; `None` is serialized as JSON `null`.
    auth_identity: Option<serde_json::Value>,
}
188
/// One entry of the `algorithms` list of [`VectorSearch`].
///
/// Carries a parameter object for each of the two modelled algorithm kinds; both are optional
/// and are emitted as JSON `null` when `None`. `create_search_index` uses kind `hnsw` with
/// `hnsw_parameters` set and `exhaustive_knn_parameters` unset.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Algorithm {
    name: String,
    kind: String,
    hnsw_parameters: Option<HnswParameters>,
    exhaustive_knn_parameters: Option<ExhaustiveKnnParameters>,
}
197
/// The `exhaustiveKnnParameters` object of an [`Algorithm`]; holds only the metric name.
#[derive(Serialize, Deserialize)]
pub struct ExhaustiveKnnParameters {
    metric: String,
}
202
/// The `hnswParameters` object of an [`Algorithm`].
///
/// Serialized with the keys `m`, `metric`, `efConstruction` and `efSearch`.
/// `create_search_index` uses `m = 4`, metric `cosine`, `efConstruction = 400` and
/// `efSearch = 500`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HnswParameters {
    m: i64,
    metric: String,
    ef_construction: i64,
    ef_search: i64,
}
211
/// One entry of the `compressions` list of [`VectorSearch`].
///
/// Keys are camelCase (`scalarQuantizationParameters`, `rerankWithOriginalVectors`,
/// `defaultOversampling`). `create_search_index` creates the index with no compressions.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Compression {
    name: String,
    kind: String,
    // `None` is serialized as JSON `null`.
    scalar_quantization_parameters: Option<ScalarQuantizationParameters>,
    rerank_with_original_vectors: bool,
    default_oversampling: i64,
}
221
/// The `scalarQuantizationParameters` object of a [`Compression`], serialized as
/// `{"quantizedDataType": "..."}`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ScalarQuantizationParameters {
    quantized_data_type: String,
}
227
/// One entry of the `profiles` list of [`VectorSearch`].
///
/// A profile refers to an algorithm, and optionally a compression and a vectorizer, by name.
/// In `create_search_index` the `text_vector` field refers to the profile through its
/// `vector_search_profile` attribute.
#[derive(Serialize, Deserialize)]
pub struct Profile {
    name: String,
    // Name of an entry in `VectorSearch::algorithms`.
    algorithm: String,
    // `None` is serialized as JSON `null`.
    compression: Option<String>,
    // `None` is serialized as JSON `null`.
    vectorizer: Option<String>,
}
235
236pub async fn does_search_index_exist(
237 index_name: &str,
238 app_config: &ApplicationConfiguration,
239) -> anyhow::Result<bool> {
240 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
241 anyhow::anyhow!("Azure configuration is missing from the application configuration")
242 })?;
243
244 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
245 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
246 })?;
247
248 let mut url = search_config.search_endpoint.clone();
249 url.set_path(&format!("indexes('{}')", index_name));
250 url.set_query(Some(&format!("api-version={}", API_VERSION)));
251
252 let response = REQWEST_CLIENT
253 .get(url)
254 .header("Content-Type", "application/json")
255 .header("api-key", search_config.search_api_key.clone())
256 .send()
257 .await?;
258
259 if response.status().is_success() {
260 Ok(true)
261 } else if response.status() == 404 {
262 Ok(false)
263 } else {
264 let status = response.status();
265 let error_text = response.text().await?;
266 Err(anyhow::anyhow!(
267 "Error checking if index exists. Status: {}. Error: {}",
268 status,
269 error_text
270 ))
271 }
272}
273
274pub async fn create_search_index(
275 index_name: String,
276 app_config: &ApplicationConfiguration,
277) -> anyhow::Result<()> {
278 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
279 anyhow::anyhow!("Azure configuration is missing from the application configuration")
280 })?;
281
282 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
283 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
284 })?;
285
286 let fields = vec![
287 Field {
288 name: "chunk_id".to_string(),
289 field_type: "Edm.String".to_string(),
290 key: Some(true),
291 searchable: Some(true),
292 filterable: Some(true),
293 retrievable: Some(true),
294 stored: Some(true),
295 sortable: Some(false),
296 facetable: Some(false),
297 analyzer: Some("keyword".to_string()),
298 index_analyzer: None,
299 search_analyzer: None,
300 synonym_maps: Some(vec![]),
301 dimensions: None,
302 vector_search_profile: None,
303 vector_encoding: None,
304 },
305 Field {
306 name: "language".to_string(),
307 field_type: "Edm.String".to_string(),
308 key: Some(false),
309 searchable: Some(true),
310 filterable: Some(true),
311 retrievable: Some(true),
312 stored: Some(true),
313 sortable: Some(false),
314 facetable: Some(false),
315 analyzer: None,
316 index_analyzer: None,
317 search_analyzer: None,
318 synonym_maps: Some(vec![]),
319 dimensions: None,
320 vector_search_profile: None,
321 vector_encoding: None,
322 },
323 Field {
324 name: "parent_id".to_string(),
325 field_type: "Edm.String".to_string(),
326 key: Some(false),
327 searchable: Some(true),
328 filterable: Some(true),
329 retrievable: Some(true),
330 stored: Some(true),
331 sortable: Some(true),
332 facetable: Some(true),
333 analyzer: None,
334 index_analyzer: None,
335 search_analyzer: None,
336 synonym_maps: Some(vec![]),
337 dimensions: None,
338 vector_search_profile: None,
339 vector_encoding: None,
340 },
341 Field {
342 name: "chunk".to_string(),
343 field_type: "Edm.String".to_string(),
344 key: Some(false),
345 searchable: Some(true),
346 filterable: Some(false),
347 retrievable: Some(true),
348 stored: Some(true),
349 sortable: Some(false),
350 facetable: Some(false),
351 analyzer: Some("keyword".to_string()),
352 index_analyzer: None,
353 search_analyzer: None,
354 synonym_maps: Some(vec![]),
355 dimensions: None,
356 vector_search_profile: None,
357 vector_encoding: None,
358 },
359 Field {
360 name: "title".to_string(),
361 field_type: "Edm.String".to_string(),
362 key: Some(false),
363 searchable: Some(true),
364 filterable: Some(true),
365 retrievable: Some(true),
366 stored: Some(true),
367 sortable: Some(false),
368 facetable: Some(false),
369 analyzer: None,
370 index_analyzer: None,
371 search_analyzer: None,
372 synonym_maps: Some(vec![]),
373 dimensions: None,
374 vector_search_profile: None,
375 vector_encoding: None,
376 },
377 Field {
378 name: "url".to_string(),
379 field_type: "Edm.String".to_string(),
380 key: Some(false),
381 searchable: Some(false),
382 filterable: Some(true),
383 retrievable: Some(true),
384 stored: Some(true),
385 sortable: Some(false),
386 facetable: Some(false),
387 analyzer: None,
388 index_analyzer: None,
389 search_analyzer: None,
390 synonym_maps: Some(vec![]),
391 dimensions: None,
392 vector_search_profile: None,
393 vector_encoding: None,
394 },
395 Field {
396 name: "course_id".to_string(),
397 field_type: "Edm.String".to_string(),
398 key: Some(false),
399 searchable: Some(false),
400 filterable: Some(true),
401 retrievable: Some(true),
402 stored: Some(true),
403 sortable: Some(false),
404 facetable: Some(false),
405 analyzer: None,
406 index_analyzer: None,
407 search_analyzer: None,
408 synonym_maps: Some(vec![]),
409 dimensions: None,
410 vector_search_profile: None,
411 vector_encoding: None,
412 },
413 Field {
414 name: "text_vector".to_string(),
415 field_type: "Collection(Edm.Single)".to_string(),
416 key: Some(false),
417 searchable: Some(true),
418 filterable: Some(false),
419 retrievable: Some(true),
420 stored: Some(true),
421 sortable: Some(false),
422 facetable: Some(false),
423 analyzer: None,
424 index_analyzer: None,
425 search_analyzer: None,
426 synonym_maps: Some(vec![]),
427 dimensions: Some(1536),
428 vector_search_profile: Some(format!("{}-azureOpenAi-text-profile", index_name)),
429 vector_encoding: None,
430 },
431 Field {
432 name: "filepath".to_string(),
433 field_type: "Edm.String".to_string(),
434 key: Some(false),
435 searchable: Some(true),
436 filterable: Some(true),
437 retrievable: Some(true),
438 stored: Some(true),
439 sortable: Some(false),
440 facetable: Some(false),
441 analyzer: None,
442 index_analyzer: None,
443 search_analyzer: None,
444 synonym_maps: Some(vec![]),
445 dimensions: None,
446 vector_search_profile: None,
447 vector_encoding: None,
448 },
449 Field {
450 name: "chunk_context".to_string(),
451 field_type: "Edm.String".to_string(),
452 key: Some(false),
453 searchable: Some(true),
454 filterable: Some(false),
455 retrievable: Some(true),
456 stored: Some(true),
457 sortable: Some(false),
458 facetable: Some(false),
459 analyzer: None,
460 index_analyzer: None,
461 search_analyzer: None,
462 synonym_maps: Some(vec![]),
463 dimensions: None,
464 vector_search_profile: None,
465 vector_encoding: None,
466 },
467 ];
468
469 let index = NewIndex {
470 name: index_name.clone(),
471 fields,
472 scoring_profiles: vec![],
473 default_scoring_profile: None,
474 suggesters: vec![],
475 analyzers: vec![],
476 tokenizers: vec![],
477 token_filters: vec![],
478 char_filters: vec![],
479 cors_options: CorsOptions {
480 allowed_origins: vec!["*".to_string()],
481 max_age_in_seconds: 300,
482 },
483 encryption_key: None,
484 similarity: Similarity {
485 odata_type: "#Microsoft.Azure.Search.BM25Similarity".to_string(),
486 b: None,
487 k1: None,
488 },
489 semantic: Semantic {
490 default_configuration: format!("{}-semantic-configuration", index_name),
491 configurations: vec![SemanticConfiguration {
492 name: format!("{}-semantic-configuration", index_name),
493 prioritized_fields: SemanticConfigurationPrioritizedFields {
494 title_field: FieldDescriptor {
495 field_name: "title".to_string(),
496 },
497 prioritized_content_fields: vec![FieldDescriptor {
498 field_name: "chunk".to_string(),
499 }],
500 prioritized_keywords_fields: vec![],
501 },
502 }],
503 },
504 vector_search: VectorSearch {
505 profiles: vec![Profile {
506 name: format!("{}-azureOpenAi-text-profile", index_name),
507 algorithm: format!("{}-algorithm", index_name),
508 vectorizer: Some(format!("{}-azureOpenAi-text-vectorizer", index_name)),
509 compression: None,
510 }],
511 algorithms: vec![Algorithm {
512 name: format!("{}-algorithm", index_name),
513 kind: "hnsw".to_string(),
514 hnsw_parameters: Some(HnswParameters {
515 m: 4,
516 metric: "cosine".to_string(),
517 ef_construction: 400,
518 ef_search: 500,
519 }),
520 exhaustive_knn_parameters: None,
521 }],
522 vectorizers: vec![Vectorizer {
523 name: format!("{}-azureOpenAi-text-vectorizer", index_name),
524 kind: "azureOpenAI".to_string(),
525 azure_open_ai_parameters: AzureOpenAiParameters {
526 resource_uri: search_config.vectorizer_resource_uri.clone(),
527 deployment_id: search_config.vectorizer_deployment_id.clone(),
528 api_key: search_config.vectorizer_api_key.clone(),
529 model_name: search_config.vectorizer_model_name.clone(),
530 auth_identity: None,
531 },
532 custom_web_api_parameters: None,
533 }],
534 compressions: vec![],
535 },
536 };
537
538 let index_json = serde_json::to_string(&index)?;
539
540 let mut url = search_config.search_endpoint.clone();
541 url.set_path("/indexes");
542 url.set_query(Some(&format!("api-version={}", API_VERSION)));
543
544 let response = REQWEST_CLIENT
545 .post(url)
546 .header("Content-Type", "application/json")
547 .header("api-key", search_config.search_api_key.clone())
548 .body(index_json)
549 .send()
550 .await?;
551
552 if response.status().is_success() {
554 println!("Index created successfully: {}", index_name);
555 Ok(())
556 } else {
557 let status = response.status();
558 let error_text = response.text().await?;
559 Err(anyhow::anyhow!(
560 "Failed to create index. Status: {}. Error: {}",
561 status,
562 error_text
563 ))
564 }
565}
566
/// A single indexing action: the action name plus the document it applies to.
///
/// Serializes as an object with the keys `@search.action` and `document`.
// NOTE(review): the document is nested under a `document` key. The Azure documents-index REST
// operation describes the document's fields as siblings of `@search.action`, not as a nested
// value — confirm whether `#[serde(flatten)]` is wanted here.
#[derive(Serialize, Deserialize)]
pub struct IndexAction<T> {
    #[serde(rename = "@search.action")]
    pub search_action: String,
    pub document: T,
}
573
/// A batch of indexing actions, serialized as `{"value": [...]}`.
#[derive(Serialize, Deserialize)]
pub struct IndexBatch<T> {
    pub value: Vec<IndexAction<T>>,
}
578
579pub async fn add_documents_to_index<T>(
580 index_name: &str,
581 documents: Vec<T>,
582 app_config: &ApplicationConfiguration,
583) -> anyhow::Result<()>
584where
585 T: Serialize,
586{
587 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
588 anyhow::anyhow!("Azure configuration is missing from the application configuration")
589 })?;
590
591 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
592 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
593 })?;
594
595 let mut url = search_config.search_endpoint.clone();
596 url.set_path(&format!("indexes('{}')/docs/index", index_name));
597 url.set_query(Some(&format!("api-version={}", API_VERSION)));
598
599 let index_actions: anyhow::Result<Vec<IndexAction<String>>> = documents
600 .into_iter()
601 .map(|doc| {
602 serde_json::to_string(&doc)
603 .map(|document| IndexAction {
604 search_action: "upload".to_string(),
605 document,
606 })
607 .map_err(|e| anyhow::anyhow!("Failed to serialize document: {}", e))
608 })
609 .collect();
610 let index_actions = index_actions?;
611
612 let batch = IndexBatch {
613 value: index_actions,
614 };
615
616 let batch_json = serde_json::to_string(&batch)?;
617
618 let response = REQWEST_CLIENT
619 .post(url)
620 .header("Content-Type", "application/json")
621 .header("api-key", search_config.search_api_key.clone())
622 .body(batch_json)
623 .send()
624 .await?;
625
626 if response.status().is_success() {
627 println!("Documents added successfully to index: {}", index_name);
628 Ok(())
629 } else {
630 let status = response.status();
631 let error_text = response.text().await?;
632 Err(anyhow::anyhow!(
633 "Failed to add documents to index. Status: {}. Error: {}",
634 status,
635 error_text
636 ))
637 }
638}