diff --git a/applications/offline/SEPSIS_RNN_README.md b/applications/offline/SEPSIS_RNN_README.md new file mode 100755 index 00000000..d266c8ec --- /dev/null +++ b/applications/offline/SEPSIS_RNN_README.md @@ -0,0 +1,25 @@ +# Privacy-Preserving Sepsis Prediction using Simple RNN + +## 1. Project Overview +This project builds a secure Machine Learning pipeline using the **Sequre framework**. By using Multi-Party Computation (MPC), the code ensures complete data privacy while the model runs. While the specific test case focuses on predicting Sepsis, the main goal of this addition is to build and test Recurrent Neural Network (RNN) features inside Sequre's encrypted environment. + +## 2. Dataset & Preprocessing +The model uses the **MIMIC-III** clinical database for testing. +* **Data Extraction:** Patient records were filtered into Sepsis and Non-Sepsis groups. +* **Feature Engineering:** The data was formatted into sequential time-steps to represent a 48-hour patient observation window. +* **Weight Generation:** Model weights were trained, extracted, and saved as CSV matrices (`Wx.csv`, `Wh.csv`, `b.csv`, `Wy.csv`, `by.csv`) so they could be loaded into Sequre. + +## 3. Model Architecture: Simple RNN +A standard Simple RNN was built using the `@sequre` decorator, which allows math operations on securely shared tensors. +* **Sequence Length:** 48 time-steps. +* **Hidden State:** Starts as a 1x64 zeros tensor and updates sequentially. +* **Recurrence Logic:** At each time step `t`, the hidden state updates using the input data, the memory from the previous step, and the bias. + +## 4. Sequre Framework Implementation Details +To manage the limits of fixed-point math inside the framework, specific Sequre standard library functions were added: +* **Data Encryption:** Input data and model weights are encrypted across the compute parties (CP0, CP1, CP2) right at initialization. +* **Overflow Management:** The `clip` function from `sequre.stdlib.builtin` is applied to the raw hidden state tensors. This stops the fixed-point numbers from overflowing and crashing during the 48-step loop. +* **Activation Function:** The `chebyshev_sigmoid` function is used as a secure replacement for the standard sigmoid activation. + +## 5. Conclusion +This code shows how to integrate sequential models (RNNs) into the Sequre framework. It creates a baseline for processing sensitive, time-series data (like health records) while keeping strict cryptographic privacy between all computing parties. \ No newline at end of file diff --git a/applications/offline/sequre_sepsis.py b/applications/offline/sequre_sepsis.py new file mode 100644 index 00000000..4304773b --- /dev/null +++ b/applications/offline/sequre_sepsis.py @@ -0,0 +1,83 @@ +import random +from numpy.create import array, zeros +from sequre import local, Sharetensor as Stensor, sequre + +def load_csv(filepath: str): + data = [] + with open(filepath, 'r') as f: + for line in f: + if line.strip(): + row = [float(x) for x in line.strip().split(',')] + data.append(row) + return array(data) + +# FIX: New function to flatten tall bias columns into flat rows +def load_bias_as_row(filepath: str): + single_row = [] + with open(filepath, 'r') as f: + for line in f: + if line.strip(): + # Take the first number and put it in our single row + single_row.append(float(line.strip().split(',')[0])) + # Wrap it in brackets so it becomes [1, 64] instead of [64, 1] + return array([single_row]) + +@sequre +def secure_rnn(mpc, x_enc_list, Wx_enc, Wh_enc, b_enc, Wy_enc, by_enc): + h = Wx_enc.zeros((1, 64)) + for t in range(48): + xt = x_enc_list[t] + + input_part = xt @ Wx_enc + mem_part = h @ Wh_enc + # Now h (1x64) and b_enc (1x64) match perfectly! + h = input_part + mem_part + b_enc + + logits = h @ Wy_enc + by_enc + return logits.reveal(mpc) + +@local +def run_sepsis_local(mpc): + print(f"CP{mpc.pid}: Loading Sepsis Data & Weights...") + try: + Wx_raw = load_csv("Wx.csv") + Wh_raw = load_csv("Wh.csv") + b_raw = load_bias_as_row("b_rnn.csv") # Used the new shape fix + Wy_raw = load_csv("Wy.csv") + by_raw = load_bias_as_row("by.csv") # Used the new shape fix + except Exception as e: + print(f"Error loading files: {e}") + return + + try: + raw_data = [] + with open("patient_data.csv", 'r') as f: + for line in f: + if line.strip(): + raw_data.append([float(x) for x in line.strip().split(',')]) + except: + print("Warning: patient_data.csv not found, using random patient.") + raw_data = [[random.random() for _ in range(6)] for _ in range(48)] + + print(f"CP{mpc.pid}: Encrypting...") + Wx_enc = Stensor.enc(mpc, Wx_raw) + Wh_enc = Stensor.enc(mpc, Wh_raw) + b_enc = Stensor.enc(mpc, b_raw) + Wy_enc = Stensor.enc(mpc, Wy_raw) + by_enc = Stensor.enc(mpc, by_raw) + + x_enc_list = [] + for i in range(48): + row_2d = array([raw_data[i]]) + x_enc_list.append(Stensor.enc(mpc, row_2d)) + + print(f"CP{mpc.pid}: Running Secure RNN...") + final_score = secure_rnn(mpc, x_enc_list, Wx_enc, Wh_enc, b_enc, Wy_enc, by_enc) + + print(f"CP{mpc.pid}: FINAL SEPSIS SCORE: {final_score}") + if final_score[0][0] > 0.0: + print(f"CP{mpc.pid}: PREDICTION: SEPSIS DETECTED") + else: + print(f"CP{mpc.pid}: PREDICTION: HEALTHY") + +run_sepsis_local()