Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions applications/offline/SEPSIS_RNN_README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Privacy-Preserving Sepsis Prediction using Simple RNN

## 1. Project Overview
This project builds a secure Machine Learning pipeline using the **Sequre framework**. By using Multi-Party Computation (MPC), the code ensures complete data privacy while the model runs. While the specific test case focuses on predicting Sepsis, the main goal of this addition is to build and test Recurrent Neural Network (RNN) features inside Sequre's encrypted environment.

## 2. Dataset & Preprocessing
The model uses the **MIMIC-III** clinical database for testing.
* **Data Extraction:** Patient records were filtered into Sepsis and Non-Sepsis groups.
* **Feature Engineering:** The data was formatted into sequential time-steps to represent a 48-hour patient observation window.
* **Weight Generation:** Model weights were trained, extracted, and saved as CSV matrices (`Wx.csv`, `Wh.csv`, `b.csv`, `Wy.csv`, `by.csv`) so they could be loaded into Sequre.

## 3. Model Architecture: Simple RNN
A standard Simple RNN was built using the `@sequre` decorator, which allows math operations on securely shared tensors.
* **Sequence Length:** 48 time-steps.
* **Hidden State:** Starts as a 1x64 zeros tensor and updates sequentially.
* **Recurrence Logic:** At each time step `t`, the hidden state updates using the input data, the memory from the previous step, and the bias.

## 4. Sequre Framework Implementation Details
To manage the limits of fixed-point math inside the framework, specific Sequre standard library functions were added:
* **Data Encryption:** Input data and model weights are encrypted across the compute parties (CP0, CP1, CP2) right at initialization.
* **Overflow Management:** The `clip` function from `sequre.stdlib.builtin` is applied to the raw hidden state tensors. This stops the fixed-point numbers from overflowing and crashing during the 48-step loop.
* **Activation Function:** The `chebyshev_sigmoid` function is used as a secure replacement for the standard sigmoid activation.

## 5. Conclusion
This code shows how to integrate sequential models (RNNs) into the Sequre framework. It creates a baseline for processing sensitive, time-series data (like health records) while keeping strict cryptographic privacy between all computing parties.
83 changes: 83 additions & 0 deletions applications/offline/sequre_sepsis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import random
from numpy.create import array, zeros
from sequre import local, Sharetensor as Stensor, sequre

def load_csv(filepath: str):
data = []
with open(filepath, 'r') as f:
for line in f:
if line.strip():
row = [float(x) for x in line.strip().split(',')]
data.append(row)
return array(data)

# FIX: New function to flatten tall bias columns into flat rows
def load_bias_as_row(filepath: str):
single_row = []
with open(filepath, 'r') as f:
for line in f:
if line.strip():
# Take the first number and put it in our single row
single_row.append(float(line.strip().split(',')[0]))
# Wrap it in brackets so it becomes [1, 64] instead of [64, 1]
return array([single_row])

@sequre
def secure_rnn(mpc, x_enc_list, Wx_enc, Wh_enc, b_enc, Wy_enc, by_enc):
h = Wx_enc.zeros((1, 64))
for t in range(48):
xt = x_enc_list[t]

input_part = xt @ Wx_enc
mem_part = h @ Wh_enc
# Now h (1x64) and b_enc (1x64) match perfectly!
h = input_part + mem_part + b_enc

logits = h @ Wy_enc + by_enc
return logits.reveal(mpc)

@local
def run_sepsis_local(mpc):
print(f"CP{mpc.pid}: Loading Sepsis Data & Weights...")
try:
Wx_raw = load_csv("Wx.csv")
Wh_raw = load_csv("Wh.csv")
b_raw = load_bias_as_row("b_rnn.csv") # Used the new shape fix
Wy_raw = load_csv("Wy.csv")
by_raw = load_bias_as_row("by.csv") # Used the new shape fix
except Exception as e:
print(f"Error loading files: {e}")
return

try:
raw_data = []
with open("patient_data.csv", 'r') as f:
for line in f:
if line.strip():
raw_data.append([float(x) for x in line.strip().split(',')])
except:
print("Warning: patient_data.csv not found, using random patient.")
raw_data = [[random.random() for _ in range(6)] for _ in range(48)]

print(f"CP{mpc.pid}: Encrypting...")
Wx_enc = Stensor.enc(mpc, Wx_raw)
Wh_enc = Stensor.enc(mpc, Wh_raw)
b_enc = Stensor.enc(mpc, b_raw)
Wy_enc = Stensor.enc(mpc, Wy_raw)
by_enc = Stensor.enc(mpc, by_raw)

x_enc_list = []
for i in range(48):
row_2d = array([raw_data[i]])
x_enc_list.append(Stensor.enc(mpc, row_2d))

print(f"CP{mpc.pid}: Running Secure RNN...")
final_score = secure_rnn(mpc, x_enc_list, Wx_enc, Wh_enc, b_enc, Wy_enc, by_enc)

print(f"CP{mpc.pid}: FINAL SEPSIS SCORE: {final_score}")
if final_score[0][0] > 0.0:
print(f"CP{mpc.pid}: PREDICTION: SEPSIS DETECTED")
else:
print(f"CP{mpc.pid}: PREDICTION: HEALTHY")

run_sepsis_local()