1use secrecy::ExposeSecret;
2
3use crate::prelude::*;
4use headless_lms_base::config::ApplicationConfiguration;
5use headless_lms_utils::http::REQWEST_CLIENT;
6
/// Value sent as the `api-version` query parameter on every search service request
/// made by the functions in this file.
const API_VERSION: &str = "2024-07-01";
8
/// Index definition that `create_search_index` serializes to JSON and posts to the
/// search service.
///
/// Keys are emitted in camelCase (e.g. `scoring_profiles` becomes `scoringProfiles`).
/// None of the `Option` fields carry a skip attribute, so `None` is written as JSON
/// `null` rather than omitted.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NewIndex {
    /// Name of the index.
    name: String,
    /// Field definitions of the index.
    fields: Vec<Field>,
    scoring_profiles: Vec<ScoringProfile>,
    default_scoring_profile: Option<String>,
    suggesters: Vec<Suggester>,
    analyzers: Vec<Analyzer>,
    // The three lists below are kept as untyped JSON; this file only ever sends them empty.
    tokenizers: Vec<serde_json::Value>,
    token_filters: Vec<serde_json::Value>,
    char_filters: Vec<serde_json::Value>,
    cors_options: CorsOptions,
    encryption_key: Option<EncryptionKey>,
    similarity: Similarity,
    semantic: Semantic,
    vector_search: VectorSearch,
}
27
/// Analyzer entry of a [`NewIndex`], serialized with camelCase keys.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Analyzer {
    name: String,
    /// Written under the literal JSON key `@odata.type` (the explicit rename overrides
    /// the camelCase rule).
    #[serde(rename = "@odata.type")]
    odata_type: String,
    char_filters: Vec<String>,
    tokenizer: String,
}
37
/// CORS settings of a [`NewIndex`], serialized as `allowedOrigins` and `maxAgeInSeconds`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CorsOptions {
    allowed_origins: Vec<String>,
    max_age_in_seconds: i64,
}
44
/// Optional encryption key reference of a [`NewIndex`], serialized with camelCase keys.
///
/// `create_search_index` always sends `None` for this.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptionKey {
    key_vault_key_name: String,
    key_vault_key_version: String,
    key_vault_uri: String,
    access_credentials: AccessCredentials,
}
53
/// Credentials nested inside an [`EncryptionKey`], serialized as `applicationId` and
/// `applicationSecret`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AccessCredentials {
    application_id: String,
    // Held as a plain `String` and serialized verbatim; avoid logging values of this type.
    application_secret: String,
}
60
/// Definition of a single index field inside a [`NewIndex`].
///
/// Keys are emitted in camelCase. All `Option` members are serialized as JSON `null`
/// when `None`, because no skip attribute is set.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Field {
    name: String,
    /// Written under the JSON key `type`; `create_search_index` uses `Edm.String` and
    /// `Collection(Edm.Single)`.
    #[serde(rename = "type")]
    field_type: String,
    key: Option<bool>,
    searchable: Option<bool>,
    filterable: Option<bool>,
    sortable: Option<bool>,
    facetable: Option<bool>,
    retrievable: Option<bool>,
    index_analyzer: Option<String>,
    search_analyzer: Option<String>,
    analyzer: Option<String>,
    synonym_maps: Option<Vec<String>>,
    /// Only set by `create_search_index` for the `text_vector` field.
    dimensions: Option<i64>,
    /// Name of a [`Profile`] in the index's [`VectorSearch`] section; only set by
    /// `create_search_index` for the `text_vector` field.
    vector_search_profile: Option<String>,
    stored: Option<bool>,
    vector_encoding: Option<serde_json::Value>,
}
82
/// Scoring profile entry of a [`NewIndex`].
///
/// No rename rule is applied; the field names are single words and are serialized as
/// written. `create_search_index` sends an empty list of these.
#[derive(Serialize, Deserialize)]
pub struct ScoringProfile {
    name: String,
    text: Text,
    functions: Vec<Function>,
}
89
/// Scoring function entry of a [`ScoringProfile`], serialized with camelCase keys.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    /// Written under the JSON key `type`.
    #[serde(rename = "type")]
    function_type: String,
    boost: i64,
    field_name: String,
    interpolation: String,
    // NOTE(review): `distance` is mandatory here, so only function shapes that carry
    // a distance block can be represented — confirm this is intended.
    distance: Distance,
}
100
/// Distance block of a scoring [`Function`], serialized as `referencePointParameter`
/// and `boostingDistance`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Distance {
    reference_point_parameter: String,
    boosting_distance: i64,
}
107
/// `text` section of a [`ScoringProfile`]; wraps the per-field [`Weights`].
#[derive(Serialize, Deserialize)]
pub struct Text {
    weights: Weights,
}
112
/// Weights of a scoring profile's [`Text`] section, serialized as `{"hotelName": <n>}`.
///
/// NOTE(review): the only key is `hotel_name`, and `create_search_index` defines no
/// field of that name — this looks carried over from a sample schema. Confirm before
/// relying on scoring profiles.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Weights {
    hotel_name: i64,
}
118
/// `semantic` section of a [`NewIndex`], serialized with camelCase keys.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Semantic {
    /// Name of the configuration to use by default; `create_search_index` sets it to
    /// the name of the single entry it puts in `configurations`.
    default_configuration: String,
    configurations: Vec<SemanticConfiguration>,
}
125
/// One named entry of [`Semantic::configurations`], serialized with camelCase keys.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SemanticConfiguration {
    name: String,
    prioritized_fields: SemanticConfigurationPrioritizedFields,
}
132
/// Field references used by a [`SemanticConfiguration`], serialized with camelCase keys.
///
/// `create_search_index` points `title_field` at `title`, the content fields at
/// `chunk`, and leaves the keywords list empty.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SemanticConfigurationPrioritizedFields {
    title_field: FieldDescriptor,
    prioritized_content_fields: Vec<FieldDescriptor>,
    prioritized_keywords_fields: Vec<FieldDescriptor>,
}
140
/// Reference to an index field by name, serialized as `{"fieldName": "<name>"}`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FieldDescriptor {
    field_name: String,
}
146
/// `similarity` section of a [`NewIndex`].
///
/// `create_search_index` sends `#Microsoft.Azure.Search.BM25Similarity` with both
/// `b` and `k1` left as `None` (serialized as `null`).
#[derive(Serialize, Deserialize)]
pub struct Similarity {
    /// Written under the literal JSON key `@odata.type`.
    #[serde(rename = "@odata.type")]
    odata_type: String,
    b: Option<f64>,
    k1: Option<f64>,
}
154
/// Suggester entry of a [`NewIndex`], serialized as `name`, `searchMode` and
/// `sourceFields`. `create_search_index` sends an empty list of these.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Suggester {
    name: String,
    search_mode: String,
    source_fields: Vec<String>,
}
162
/// `vectorSearch` section of a [`NewIndex`].
///
/// The entries reference each other by name: a [`Profile`] names an [`Algorithm`] and
/// optionally a [`Compression`] and a [`Vectorizer`].
#[derive(Serialize, Deserialize)]
pub struct VectorSearch {
    profiles: Vec<Profile>,
    algorithms: Vec<Algorithm>,
    compressions: Vec<Compression>,
    vectorizers: Vec<Vectorizer>,
}
170
/// Vectorizer entry of a [`VectorSearch`] section, serialized with camelCase keys.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Vectorizer {
    name: String,
    /// `create_search_index` sets this to `azureOpenAI`.
    kind: String,
    /// Explicitly renamed because the camelCase rule would produce
    /// `azureOpenAiParameters` instead of `azureOpenAIParameters`.
    #[serde(rename = "azureOpenAIParameters")]
    azure_open_ai_parameters: AzureOpenAiParameters,
    custom_web_api_parameters: Option<serde_json::Value>,
}
180
/// Connection parameters of an `azureOpenAI` [`Vectorizer`], serialized with camelCase keys.
///
/// `create_search_index` fills these from the `vectorizer_*` values of the search
/// configuration.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AzureOpenAiParameters {
    resource_uri: String,
    deployment_id: String,
    // Held as a plain `String` and serialized verbatim into the request body; avoid
    // logging values of this type.
    api_key: String,
    model_name: String,
    auth_identity: Option<serde_json::Value>,
}
190
/// Algorithm entry of a [`VectorSearch`] section, serialized with camelCase keys.
///
/// `create_search_index` sends kind `hnsw` with `hnsw_parameters` set and
/// `exhaustive_knn_parameters` left as `None`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Algorithm {
    name: String,
    kind: String,
    hnsw_parameters: Option<HnswParameters>,
    exhaustive_knn_parameters: Option<ExhaustiveKnnParameters>,
}
199
/// Value type of [`Algorithm`]'s `exhaustive_knn_parameters` field; carries only the metric name.
#[derive(Serialize, Deserialize)]
pub struct ExhaustiveKnnParameters {
    metric: String,
}
204
/// Value type of [`Algorithm`]'s `hnsw_parameters` field, serialized as `m`, `metric`,
/// `efConstruction` and `efSearch`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HnswParameters {
    m: i64,
    metric: String,
    ef_construction: i64,
    ef_search: i64,
}
213
/// Compression entry of a [`VectorSearch`] section, serialized with camelCase keys.
/// `create_search_index` sends an empty list of these.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Compression {
    name: String,
    kind: String,
    scalar_quantization_parameters: Option<ScalarQuantizationParameters>,
    rerank_with_original_vectors: bool,
    // NOTE(review): modelled as an integer; confirm the service never returns a
    // fractional value here if this type is ever deserialized from a response.
    default_oversampling: i64,
}
223
/// Parameters nested in a [`Compression`] entry, serialized as `{"quantizedDataType": "<type>"}`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ScalarQuantizationParameters {
    quantized_data_type: String,
}
229
/// Profile entry of a [`VectorSearch`] section.
///
/// `algorithm`, `compression` and `vectorizer` hold the names of the corresponding
/// entries in the same [`VectorSearch`]; `create_search_index` sets `algorithm` and
/// `vectorizer` and leaves `compression` as `None`.
#[derive(Serialize, Deserialize)]
pub struct Profile {
    name: String,
    algorithm: String,
    compression: Option<String>,
    vectorizer: Option<String>,
}
237
238pub async fn does_search_index_exist(
239 index_name: &str,
240 app_config: &ApplicationConfiguration,
241) -> anyhow::Result<bool> {
242 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
243 anyhow::anyhow!("Azure configuration is missing from the application configuration")
244 })?;
245
246 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
247 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
248 })?;
249
250 let mut url = search_config.search_endpoint.clone();
251 url.set_path(&format!("indexes('{}')", index_name));
252 url.set_query(Some(&format!("api-version={}", API_VERSION)));
253
254 let response = REQWEST_CLIENT
255 .get(url)
256 .header("Content-Type", "application/json")
257 .header("api-key", search_config.search_api_key.expose_secret())
258 .send()
259 .await?;
260
261 if response.status().is_success() {
262 Ok(true)
263 } else if response.status() == 404 {
264 Ok(false)
265 } else {
266 let status = response.status();
267 let error_text = response.text().await?;
268 Err(anyhow::anyhow!(
269 "Error checking if index exists. Status: {}. Error: {}",
270 status,
271 error_text
272 ))
273 }
274}
275
276pub async fn create_search_index(
277 index_name: String,
278 app_config: &ApplicationConfiguration,
279) -> anyhow::Result<()> {
280 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
281 anyhow::anyhow!("Azure configuration is missing from the application configuration")
282 })?;
283
284 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
285 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
286 })?;
287
288 let fields = vec![
289 Field {
290 name: "chunk_id".to_string(),
291 field_type: "Edm.String".to_string(),
292 key: Some(true),
293 searchable: Some(true),
294 filterable: Some(true),
295 retrievable: Some(true),
296 stored: Some(true),
297 sortable: Some(false),
298 facetable: Some(false),
299 analyzer: Some("keyword".to_string()),
300 index_analyzer: None,
301 search_analyzer: None,
302 synonym_maps: Some(vec![]),
303 dimensions: None,
304 vector_search_profile: None,
305 vector_encoding: None,
306 },
307 Field {
308 name: "language".to_string(),
309 field_type: "Edm.String".to_string(),
310 key: Some(false),
311 searchable: Some(true),
312 filterable: Some(true),
313 retrievable: Some(true),
314 stored: Some(true),
315 sortable: Some(false),
316 facetable: Some(false),
317 analyzer: None,
318 index_analyzer: None,
319 search_analyzer: None,
320 synonym_maps: Some(vec![]),
321 dimensions: None,
322 vector_search_profile: None,
323 vector_encoding: None,
324 },
325 Field {
326 name: "parent_id".to_string(),
327 field_type: "Edm.String".to_string(),
328 key: Some(false),
329 searchable: Some(true),
330 filterable: Some(true),
331 retrievable: Some(true),
332 stored: Some(true),
333 sortable: Some(true),
334 facetable: Some(true),
335 analyzer: None,
336 index_analyzer: None,
337 search_analyzer: None,
338 synonym_maps: Some(vec![]),
339 dimensions: None,
340 vector_search_profile: None,
341 vector_encoding: None,
342 },
343 Field {
344 name: "chunk".to_string(),
345 field_type: "Edm.String".to_string(),
346 key: Some(false),
347 searchable: Some(true),
348 filterable: Some(false),
349 retrievable: Some(true),
350 stored: Some(true),
351 sortable: Some(false),
352 facetable: Some(false),
353 analyzer: Some("keyword".to_string()),
354 index_analyzer: None,
355 search_analyzer: None,
356 synonym_maps: Some(vec![]),
357 dimensions: None,
358 vector_search_profile: None,
359 vector_encoding: None,
360 },
361 Field {
362 name: "title".to_string(),
363 field_type: "Edm.String".to_string(),
364 key: Some(false),
365 searchable: Some(true),
366 filterable: Some(true),
367 retrievable: Some(true),
368 stored: Some(true),
369 sortable: Some(false),
370 facetable: Some(false),
371 analyzer: None,
372 index_analyzer: None,
373 search_analyzer: None,
374 synonym_maps: Some(vec![]),
375 dimensions: None,
376 vector_search_profile: None,
377 vector_encoding: None,
378 },
379 Field {
380 name: "url".to_string(),
381 field_type: "Edm.String".to_string(),
382 key: Some(false),
383 searchable: Some(false),
384 filterable: Some(true),
385 retrievable: Some(true),
386 stored: Some(true),
387 sortable: Some(false),
388 facetable: Some(false),
389 analyzer: None,
390 index_analyzer: None,
391 search_analyzer: None,
392 synonym_maps: Some(vec![]),
393 dimensions: None,
394 vector_search_profile: None,
395 vector_encoding: None,
396 },
397 Field {
398 name: "course_id".to_string(),
399 field_type: "Edm.String".to_string(),
400 key: Some(false),
401 searchable: Some(false),
402 filterable: Some(true),
403 retrievable: Some(true),
404 stored: Some(true),
405 sortable: Some(false),
406 facetable: Some(false),
407 analyzer: None,
408 index_analyzer: None,
409 search_analyzer: None,
410 synonym_maps: Some(vec![]),
411 dimensions: None,
412 vector_search_profile: None,
413 vector_encoding: None,
414 },
415 Field {
416 name: "text_vector".to_string(),
417 field_type: "Collection(Edm.Single)".to_string(),
418 key: Some(false),
419 searchable: Some(true),
420 filterable: Some(false),
421 retrievable: Some(true),
422 stored: Some(true),
423 sortable: Some(false),
424 facetable: Some(false),
425 analyzer: None,
426 index_analyzer: None,
427 search_analyzer: None,
428 synonym_maps: Some(vec![]),
429 dimensions: Some(1536),
430 vector_search_profile: Some(format!("{}-azureOpenAi-text-profile", index_name)),
431 vector_encoding: None,
432 },
433 Field {
434 name: "filepath".to_string(),
435 field_type: "Edm.String".to_string(),
436 key: Some(false),
437 searchable: Some(true),
438 filterable: Some(true),
439 retrievable: Some(true),
440 stored: Some(true),
441 sortable: Some(false),
442 facetable: Some(false),
443 analyzer: None,
444 index_analyzer: None,
445 search_analyzer: None,
446 synonym_maps: Some(vec![]),
447 dimensions: None,
448 vector_search_profile: None,
449 vector_encoding: None,
450 },
451 Field {
452 name: "chunk_context".to_string(),
453 field_type: "Edm.String".to_string(),
454 key: Some(false),
455 searchable: Some(true),
456 filterable: Some(false),
457 retrievable: Some(true),
458 stored: Some(true),
459 sortable: Some(false),
460 facetable: Some(false),
461 analyzer: None,
462 index_analyzer: None,
463 search_analyzer: None,
464 synonym_maps: Some(vec![]),
465 dimensions: None,
466 vector_search_profile: None,
467 vector_encoding: None,
468 },
469 ];
470
471 let index = NewIndex {
472 name: index_name.clone(),
473 fields,
474 scoring_profiles: vec![],
475 default_scoring_profile: None,
476 suggesters: vec![],
477 analyzers: vec![],
478 tokenizers: vec![],
479 token_filters: vec![],
480 char_filters: vec![],
481 cors_options: CorsOptions {
482 allowed_origins: vec!["*".to_string()],
483 max_age_in_seconds: 300,
484 },
485 encryption_key: None,
486 similarity: Similarity {
487 odata_type: "#Microsoft.Azure.Search.BM25Similarity".to_string(),
488 b: None,
489 k1: None,
490 },
491 semantic: Semantic {
492 default_configuration: format!("{}-semantic-configuration", index_name),
493 configurations: vec![SemanticConfiguration {
494 name: format!("{}-semantic-configuration", index_name),
495 prioritized_fields: SemanticConfigurationPrioritizedFields {
496 title_field: FieldDescriptor {
497 field_name: "title".to_string(),
498 },
499 prioritized_content_fields: vec![FieldDescriptor {
500 field_name: "chunk".to_string(),
501 }],
502 prioritized_keywords_fields: vec![],
503 },
504 }],
505 },
506 vector_search: VectorSearch {
507 profiles: vec![Profile {
508 name: format!("{}-azureOpenAi-text-profile", index_name),
509 algorithm: format!("{}-algorithm", index_name),
510 vectorizer: Some(format!("{}-azureOpenAi-text-vectorizer", index_name)),
511 compression: None,
512 }],
513 algorithms: vec![Algorithm {
514 name: format!("{}-algorithm", index_name),
515 kind: "hnsw".to_string(),
516 hnsw_parameters: Some(HnswParameters {
517 m: 4,
518 metric: "cosine".to_string(),
519 ef_construction: 400,
520 ef_search: 500,
521 }),
522 exhaustive_knn_parameters: None,
523 }],
524 vectorizers: vec![Vectorizer {
525 name: format!("{}-azureOpenAi-text-vectorizer", index_name),
526 kind: "azureOpenAI".to_string(),
527 azure_open_ai_parameters: AzureOpenAiParameters {
528 resource_uri: search_config.vectorizer_resource_uri.clone(),
529 deployment_id: search_config.vectorizer_deployment_id.clone(),
530 api_key: search_config.vectorizer_api_key.expose_secret().to_string(),
531 model_name: search_config.vectorizer_model_name.clone(),
532 auth_identity: None,
533 },
534 custom_web_api_parameters: None,
535 }],
536 compressions: vec![],
537 },
538 };
539
540 let index_json = serde_json::to_string(&index)?;
541
542 let mut url = search_config.search_endpoint.clone();
543 url.set_path("/indexes");
544 url.set_query(Some(&format!("api-version={}", API_VERSION)));
545
546 let response = REQWEST_CLIENT
547 .post(url)
548 .header("Content-Type", "application/json")
549 .header("api-key", search_config.search_api_key.expose_secret())
550 .body(index_json)
551 .send()
552 .await?;
553
554 if response.status().is_success() {
556 println!("Index created successfully: {}", index_name);
557 Ok(())
558 } else {
559 let status = response.status();
560 let error_text = response.text().await?;
561 Err(anyhow::anyhow!(
562 "Failed to create index. Status: {}. Error: {}",
563 status,
564 error_text
565 ))
566 }
567}
568
569#[derive(Serialize, Deserialize)]
570pub struct IndexAction<T> {
571 #[serde(rename = "@search.action")]
572 pub search_action: String,
573 pub document: T,
574}
575
576#[derive(Serialize, Deserialize)]
577pub struct IndexBatch<T> {
578 pub value: Vec<IndexAction<T>>,
579}
580
581pub async fn add_documents_to_index<T>(
582 index_name: &str,
583 documents: Vec<T>,
584 app_config: &ApplicationConfiguration,
585) -> anyhow::Result<()>
586where
587 T: Serialize,
588{
589 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
590 anyhow::anyhow!("Azure configuration is missing from the application configuration")
591 })?;
592
593 let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
594 anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
595 })?;
596
597 let mut url = search_config.search_endpoint.clone();
598 url.set_path(&format!("indexes('{}')/docs/index", index_name));
599 url.set_query(Some(&format!("api-version={}", API_VERSION)));
600
601 let index_actions: anyhow::Result<Vec<IndexAction<String>>> = documents
602 .into_iter()
603 .map(|doc| {
604 serde_json::to_string(&doc)
605 .map(|document| IndexAction {
606 search_action: "upload".to_string(),
607 document,
608 })
609 .map_err(|e| anyhow::anyhow!("Failed to serialize document: {}", e))
610 })
611 .collect();
612 let index_actions = index_actions?;
613
614 let batch = IndexBatch {
615 value: index_actions,
616 };
617
618 let batch_json = serde_json::to_string(&batch)?;
619
620 let response = REQWEST_CLIENT
621 .post(url)
622 .header("Content-Type", "application/json")
623 .header("api-key", search_config.search_api_key.expose_secret())
624 .body(batch_json)
625 .send()
626 .await?;
627
628 if response.status().is_success() {
629 println!("Documents added successfully to index: {}", index_name);
630 Ok(())
631 } else {
632 let status = response.status();
633 let error_text = response.text().await?;
634 Err(anyhow::anyhow!(
635 "Failed to add documents to index. Status: {}. Error: {}",
636 status,
637 error_text
638 ))
639 }
640}