headless_lms_chatbot/
azure_search_index.rs

1use crate::prelude::*;
2use headless_lms_base::config::ApplicationConfiguration;
3use headless_lms_utils::http::REQWEST_CLIENT;
4
/// Azure AI Search REST API version appended as the `api-version` query parameter to every
/// request made by this module.
const API_VERSION: &str = "2024-07-01";
6
7#[derive(Serialize, Deserialize)]
8#[serde(rename_all = "camelCase")]
9pub struct NewIndex {
10    name: String,
11    fields: Vec<Field>,
12    scoring_profiles: Vec<ScoringProfile>,
13    default_scoring_profile: Option<String>,
14    suggesters: Vec<Suggester>,
15    analyzers: Vec<Analyzer>,
16    tokenizers: Vec<serde_json::Value>,
17    token_filters: Vec<serde_json::Value>,
18    char_filters: Vec<serde_json::Value>,
19    cors_options: CorsOptions,
20    encryption_key: Option<EncryptionKey>,
21    similarity: Similarity,
22    semantic: Semantic,
23    vector_search: VectorSearch,
24}
25
26#[derive(Serialize, Deserialize)]
27#[serde(rename_all = "camelCase")]
28pub struct Analyzer {
29    name: String,
30    #[serde(rename = "@odata.type")]
31    odata_type: String,
32    char_filters: Vec<String>,
33    tokenizer: String,
34}
35
36#[derive(Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct CorsOptions {
39    allowed_origins: Vec<String>,
40    max_age_in_seconds: i64,
41}
42
43#[derive(Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct EncryptionKey {
46    key_vault_key_name: String,
47    key_vault_key_version: String,
48    key_vault_uri: String,
49    access_credentials: AccessCredentials,
50}
51
52#[derive(Serialize, Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct AccessCredentials {
55    application_id: String,
56    application_secret: String,
57}
58
59#[derive(Serialize, Deserialize)]
60#[serde(rename_all = "camelCase")]
61pub struct Field {
62    name: String,
63    #[serde(rename = "type")]
64    field_type: String,
65    key: Option<bool>,
66    searchable: Option<bool>,
67    filterable: Option<bool>,
68    sortable: Option<bool>,
69    facetable: Option<bool>,
70    retrievable: Option<bool>,
71    index_analyzer: Option<String>,
72    search_analyzer: Option<String>,
73    analyzer: Option<String>,
74    synonym_maps: Option<Vec<String>>,
75    dimensions: Option<i64>,
76    vector_search_profile: Option<String>,
77    stored: Option<bool>,
78    vector_encoding: Option<serde_json::Value>,
79}
80
81#[derive(Serialize, Deserialize)]
82pub struct ScoringProfile {
83    name: String,
84    text: Text,
85    functions: Vec<Function>,
86}
87
88#[derive(Serialize, Deserialize)]
89#[serde(rename_all = "camelCase")]
90pub struct Function {
91    #[serde(rename = "type")]
92    function_type: String,
93    boost: i64,
94    field_name: String,
95    interpolation: String,
96    distance: Distance,
97}
98
99#[derive(Serialize, Deserialize)]
100#[serde(rename_all = "camelCase")]
101pub struct Distance {
102    reference_point_parameter: String,
103    boosting_distance: i64,
104}
105
106#[derive(Serialize, Deserialize)]
107pub struct Text {
108    weights: Weights,
109}
110
111#[derive(Serialize, Deserialize)]
112#[serde(rename_all = "camelCase")]
113pub struct Weights {
114    hotel_name: i64,
115}
116
117#[derive(Serialize, Deserialize)]
118#[serde(rename_all = "camelCase")]
119pub struct Semantic {
120    default_configuration: String,
121    configurations: Vec<SemanticConfiguration>,
122}
123
124#[derive(Serialize, Deserialize)]
125#[serde(rename_all = "camelCase")]
126pub struct SemanticConfiguration {
127    name: String,
128    prioritized_fields: SemanticConfigurationPrioritizedFields,
129}
130
131#[derive(Serialize, Deserialize)]
132#[serde(rename_all = "camelCase")]
133pub struct SemanticConfigurationPrioritizedFields {
134    title_field: FieldDescriptor,
135    prioritized_content_fields: Vec<FieldDescriptor>,
136    prioritized_keywords_fields: Vec<FieldDescriptor>,
137}
138
139#[derive(Serialize, Deserialize)]
140#[serde(rename_all = "camelCase")]
141pub struct FieldDescriptor {
142    field_name: String,
143}
144
145#[derive(Serialize, Deserialize)]
146pub struct Similarity {
147    #[serde(rename = "@odata.type")]
148    odata_type: String,
149    b: Option<f64>,
150    k1: Option<f64>,
151}
152
153#[derive(Serialize, Deserialize)]
154#[serde(rename_all = "camelCase")]
155pub struct Suggester {
156    name: String,
157    search_mode: String,
158    source_fields: Vec<String>,
159}
160
161#[derive(Serialize, Deserialize)]
162pub struct VectorSearch {
163    profiles: Vec<Profile>,
164    algorithms: Vec<Algorithm>,
165    compressions: Vec<Compression>,
166    vectorizers: Vec<Vectorizer>,
167}
168
169#[derive(Serialize, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct Vectorizer {
172    name: String,
173    kind: String,
174    #[serde(rename = "azureOpenAIParameters")]
175    azure_open_ai_parameters: AzureOpenAiParameters,
176    custom_web_api_parameters: Option<serde_json::Value>,
177}
178
179#[derive(Serialize, Deserialize)]
180#[serde(rename_all = "camelCase")]
181pub struct AzureOpenAiParameters {
182    resource_uri: String,
183    deployment_id: String,
184    api_key: String,
185    model_name: String,
186    auth_identity: Option<serde_json::Value>,
187}
188
189#[derive(Serialize, Deserialize)]
190#[serde(rename_all = "camelCase")]
191pub struct Algorithm {
192    name: String,
193    kind: String,
194    hnsw_parameters: Option<HnswParameters>,
195    exhaustive_knn_parameters: Option<ExhaustiveKnnParameters>,
196}
197
198#[derive(Serialize, Deserialize)]
199pub struct ExhaustiveKnnParameters {
200    metric: String,
201}
202
203#[derive(Serialize, Deserialize)]
204#[serde(rename_all = "camelCase")]
205pub struct HnswParameters {
206    m: i64,
207    metric: String,
208    ef_construction: i64,
209    ef_search: i64,
210}
211
212#[derive(Serialize, Deserialize)]
213#[serde(rename_all = "camelCase")]
214pub struct Compression {
215    name: String,
216    kind: String,
217    scalar_quantization_parameters: Option<ScalarQuantizationParameters>,
218    rerank_with_original_vectors: bool,
219    default_oversampling: i64,
220}
221
222#[derive(Serialize, Deserialize)]
223#[serde(rename_all = "camelCase")]
224pub struct ScalarQuantizationParameters {
225    quantized_data_type: String,
226}
227
228#[derive(Serialize, Deserialize)]
229pub struct Profile {
230    name: String,
231    algorithm: String,
232    compression: Option<String>,
233    vectorizer: Option<String>,
234}
235
236pub async fn does_search_index_exist(
237    index_name: &str,
238    app_config: &ApplicationConfiguration,
239) -> anyhow::Result<bool> {
240    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
241        anyhow::anyhow!("Azure configuration is missing from the application configuration")
242    })?;
243
244    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
245        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
246    })?;
247
248    let mut url = search_config.search_endpoint.clone();
249    url.set_path(&format!("indexes('{}')", index_name));
250    url.set_query(Some(&format!("api-version={}", API_VERSION)));
251
252    let response = REQWEST_CLIENT
253        .get(url)
254        .header("Content-Type", "application/json")
255        .header("api-key", search_config.search_api_key.clone())
256        .send()
257        .await?;
258
259    if response.status().is_success() {
260        Ok(true)
261    } else if response.status() == 404 {
262        Ok(false)
263    } else {
264        let status = response.status();
265        let error_text = response.text().await?;
266        Err(anyhow::anyhow!(
267            "Error checking if index exists. Status: {}. Error: {}",
268            status,
269            error_text
270        ))
271    }
272}
273
274pub async fn create_search_index(
275    index_name: String,
276    app_config: &ApplicationConfiguration,
277) -> anyhow::Result<()> {
278    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
279        anyhow::anyhow!("Azure configuration is missing from the application configuration")
280    })?;
281
282    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
283        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
284    })?;
285
286    let fields = vec![
287        Field {
288            name: "chunk_id".to_string(),
289            field_type: "Edm.String".to_string(),
290            key: Some(true),
291            searchable: Some(true),
292            filterable: Some(true),
293            retrievable: Some(true),
294            stored: Some(true),
295            sortable: Some(false),
296            facetable: Some(false),
297            analyzer: Some("keyword".to_string()),
298            index_analyzer: None,
299            search_analyzer: None,
300            synonym_maps: Some(vec![]),
301            dimensions: None,
302            vector_search_profile: None,
303            vector_encoding: None,
304        },
305        Field {
306            name: "language".to_string(),
307            field_type: "Edm.String".to_string(),
308            key: Some(false),
309            searchable: Some(true),
310            filterable: Some(true),
311            retrievable: Some(true),
312            stored: Some(true),
313            sortable: Some(false),
314            facetable: Some(false),
315            analyzer: None,
316            index_analyzer: None,
317            search_analyzer: None,
318            synonym_maps: Some(vec![]),
319            dimensions: None,
320            vector_search_profile: None,
321            vector_encoding: None,
322        },
323        Field {
324            name: "parent_id".to_string(),
325            field_type: "Edm.String".to_string(),
326            key: Some(false),
327            searchable: Some(true),
328            filterable: Some(true),
329            retrievable: Some(true),
330            stored: Some(true),
331            sortable: Some(true),
332            facetable: Some(true),
333            analyzer: None,
334            index_analyzer: None,
335            search_analyzer: None,
336            synonym_maps: Some(vec![]),
337            dimensions: None,
338            vector_search_profile: None,
339            vector_encoding: None,
340        },
341        Field {
342            name: "chunk".to_string(),
343            field_type: "Edm.String".to_string(),
344            key: Some(false),
345            searchable: Some(true),
346            filterable: Some(false),
347            retrievable: Some(true),
348            stored: Some(true),
349            sortable: Some(false),
350            facetable: Some(false),
351            analyzer: Some("keyword".to_string()),
352            index_analyzer: None,
353            search_analyzer: None,
354            synonym_maps: Some(vec![]),
355            dimensions: None,
356            vector_search_profile: None,
357            vector_encoding: None,
358        },
359        Field {
360            name: "title".to_string(),
361            field_type: "Edm.String".to_string(),
362            key: Some(false),
363            searchable: Some(true),
364            filterable: Some(true),
365            retrievable: Some(true),
366            stored: Some(true),
367            sortable: Some(false),
368            facetable: Some(false),
369            analyzer: None,
370            index_analyzer: None,
371            search_analyzer: None,
372            synonym_maps: Some(vec![]),
373            dimensions: None,
374            vector_search_profile: None,
375            vector_encoding: None,
376        },
377        Field {
378            name: "url".to_string(),
379            field_type: "Edm.String".to_string(),
380            key: Some(false),
381            searchable: Some(false),
382            filterable: Some(true),
383            retrievable: Some(true),
384            stored: Some(true),
385            sortable: Some(false),
386            facetable: Some(false),
387            analyzer: None,
388            index_analyzer: None,
389            search_analyzer: None,
390            synonym_maps: Some(vec![]),
391            dimensions: None,
392            vector_search_profile: None,
393            vector_encoding: None,
394        },
395        Field {
396            name: "course_id".to_string(),
397            field_type: "Edm.String".to_string(),
398            key: Some(false),
399            searchable: Some(false),
400            filterable: Some(true),
401            retrievable: Some(true),
402            stored: Some(true),
403            sortable: Some(false),
404            facetable: Some(false),
405            analyzer: None,
406            index_analyzer: None,
407            search_analyzer: None,
408            synonym_maps: Some(vec![]),
409            dimensions: None,
410            vector_search_profile: None,
411            vector_encoding: None,
412        },
413        Field {
414            name: "text_vector".to_string(),
415            field_type: "Collection(Edm.Single)".to_string(),
416            key: Some(false),
417            searchable: Some(true),
418            filterable: Some(false),
419            retrievable: Some(true),
420            stored: Some(true),
421            sortable: Some(false),
422            facetable: Some(false),
423            analyzer: None,
424            index_analyzer: None,
425            search_analyzer: None,
426            synonym_maps: Some(vec![]),
427            dimensions: Some(1536),
428            vector_search_profile: Some(format!("{}-azureOpenAi-text-profile", index_name)),
429            vector_encoding: None,
430        },
431        Field {
432            name: "filepath".to_string(),
433            field_type: "Edm.String".to_string(),
434            key: Some(false),
435            searchable: Some(true),
436            filterable: Some(true),
437            retrievable: Some(true),
438            stored: Some(true),
439            sortable: Some(false),
440            facetable: Some(false),
441            analyzer: None,
442            index_analyzer: None,
443            search_analyzer: None,
444            synonym_maps: Some(vec![]),
445            dimensions: None,
446            vector_search_profile: None,
447            vector_encoding: None,
448        },
449        Field {
450            name: "chunk_context".to_string(),
451            field_type: "Edm.String".to_string(),
452            key: Some(false),
453            searchable: Some(true),
454            filterable: Some(false),
455            retrievable: Some(true),
456            stored: Some(true),
457            sortable: Some(false),
458            facetable: Some(false),
459            analyzer: None,
460            index_analyzer: None,
461            search_analyzer: None,
462            synonym_maps: Some(vec![]),
463            dimensions: None,
464            vector_search_profile: None,
465            vector_encoding: None,
466        },
467    ];
468
469    let index = NewIndex {
470        name: index_name.clone(),
471        fields,
472        scoring_profiles: vec![],
473        default_scoring_profile: None,
474        suggesters: vec![],
475        analyzers: vec![],
476        tokenizers: vec![],
477        token_filters: vec![],
478        char_filters: vec![],
479        cors_options: CorsOptions {
480            allowed_origins: vec!["*".to_string()],
481            max_age_in_seconds: 300,
482        },
483        encryption_key: None,
484        similarity: Similarity {
485            odata_type: "#Microsoft.Azure.Search.BM25Similarity".to_string(),
486            b: None,
487            k1: None,
488        },
489        semantic: Semantic {
490            default_configuration: format!("{}-semantic-configuration", index_name),
491            configurations: vec![SemanticConfiguration {
492                name: format!("{}-semantic-configuration", index_name),
493                prioritized_fields: SemanticConfigurationPrioritizedFields {
494                    title_field: FieldDescriptor {
495                        field_name: "title".to_string(),
496                    },
497                    prioritized_content_fields: vec![FieldDescriptor {
498                        field_name: "chunk".to_string(),
499                    }],
500                    prioritized_keywords_fields: vec![],
501                },
502            }],
503        },
504        vector_search: VectorSearch {
505            profiles: vec![Profile {
506                name: format!("{}-azureOpenAi-text-profile", index_name),
507                algorithm: format!("{}-algorithm", index_name),
508                vectorizer: Some(format!("{}-azureOpenAi-text-vectorizer", index_name)),
509                compression: None,
510            }],
511            algorithms: vec![Algorithm {
512                name: format!("{}-algorithm", index_name),
513                kind: "hnsw".to_string(),
514                hnsw_parameters: Some(HnswParameters {
515                    m: 4,
516                    metric: "cosine".to_string(),
517                    ef_construction: 400,
518                    ef_search: 500,
519                }),
520                exhaustive_knn_parameters: None,
521            }],
522            vectorizers: vec![Vectorizer {
523                name: format!("{}-azureOpenAi-text-vectorizer", index_name),
524                kind: "azureOpenAI".to_string(),
525                azure_open_ai_parameters: AzureOpenAiParameters {
526                    resource_uri: search_config.vectorizer_resource_uri.clone(),
527                    deployment_id: search_config.vectorizer_deployment_id.clone(),
528                    api_key: search_config.vectorizer_api_key.clone(),
529                    model_name: search_config.vectorizer_model_name.clone(),
530                    auth_identity: None,
531                },
532                custom_web_api_parameters: None,
533            }],
534            compressions: vec![],
535        },
536    };
537
538    let index_json = serde_json::to_string(&index)?;
539
540    let mut url = search_config.search_endpoint.clone();
541    url.set_path("/indexes");
542    url.set_query(Some(&format!("api-version={}", API_VERSION)));
543
544    let response = REQWEST_CLIENT
545        .post(url)
546        .header("Content-Type", "application/json")
547        .header("api-key", search_config.search_api_key.clone())
548        .body(index_json)
549        .send()
550        .await?;
551
552    // Check for a successful response
553    if response.status().is_success() {
554        println!("Index created successfully: {}", index_name);
555        Ok(())
556    } else {
557        let status = response.status();
558        let error_text = response.text().await?;
559        Err(anyhow::anyhow!(
560            "Failed to create index. Status: {}. Error: {}",
561            status,
562            error_text
563        ))
564    }
565}
566
567#[derive(Serialize, Deserialize)]
568pub struct IndexAction<T> {
569    #[serde(rename = "@search.action")]
570    pub search_action: String,
571    pub document: T,
572}
573
574#[derive(Serialize, Deserialize)]
575pub struct IndexBatch<T> {
576    pub value: Vec<IndexAction<T>>,
577}
578
579pub async fn add_documents_to_index<T>(
580    index_name: &str,
581    documents: Vec<T>,
582    app_config: &ApplicationConfiguration,
583) -> anyhow::Result<()>
584where
585    T: Serialize,
586{
587    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
588        anyhow::anyhow!("Azure configuration is missing from the application configuration")
589    })?;
590
591    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
592        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
593    })?;
594
595    let mut url = search_config.search_endpoint.clone();
596    url.set_path(&format!("indexes('{}')/docs/index", index_name));
597    url.set_query(Some(&format!("api-version={}", API_VERSION)));
598
599    let index_actions: anyhow::Result<Vec<IndexAction<String>>> = documents
600        .into_iter()
601        .map(|doc| {
602            serde_json::to_string(&doc)
603                .map(|document| IndexAction {
604                    search_action: "upload".to_string(),
605                    document,
606                })
607                .map_err(|e| anyhow::anyhow!("Failed to serialize document: {}", e))
608        })
609        .collect();
610    let index_actions = index_actions?;
611
612    let batch = IndexBatch {
613        value: index_actions,
614    };
615
616    let batch_json = serde_json::to_string(&batch)?;
617
618    let response = REQWEST_CLIENT
619        .post(url)
620        .header("Content-Type", "application/json")
621        .header("api-key", search_config.search_api_key.clone())
622        .body(batch_json)
623        .send()
624        .await?;
625
626    if response.status().is_success() {
627        println!("Documents added successfully to index: {}", index_name);
628        Ok(())
629    } else {
630        let status = response.status();
631        let error_text = response.text().await?;
632        Err(anyhow::anyhow!(
633            "Failed to add documents to index. Status: {}. Error: {}",
634            status,
635            error_text
636        ))
637    }
638}