Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions gwlearn/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,72 @@ def test_against_mgwr():
assert_almost_equal(gwlr.aicc_, res.aicc, decimal=0)
assert_almost_equal(gwlr.effective_df_, res.ENP)
assert_almost_equal(gwlr.log_likelihood_, res.llf)


def test_gwlogistic_predict(sample_data):
# Unpack sample dataset (features, target, spatial geometry)
X, y, geometry = sample_data

# Initialize model with keep_models=True (required for prediction)
model = GWLogisticRegression(
bandwidth=10,
fixed=False,
keep_models=True,
max_iter=1000,
)

# Fit geographically weighted logistic model
model.fit(X, y, geometry=geometry)

# Generate predictions using fitted local models
preds = model.predict(X, geometry=geometry)

# Assert: number of predictions matches number of input samples
assert len(preds) == len(X)

# Assert: predictions are valid binary outputs (0/1 or True/False)
assert set(np.unique(preds.dropna())).issubset({0, 1})

# Ensure predictions are not all NaN
assert not preds.isna().all()


def test_gwlogistic_predict_proba(sample_data):
# Unpack sample dataset
X, y, geometry = sample_data

# Initialize model with keep_models=True for probability prediction
model = GWLogisticRegression(
bandwidth=10,
fixed=False,
keep_models=True,
max_iter=1000,
)

# Fit model
model.fit(X, y, geometry=geometry)

# Get class probability predictions
proba = model.predict_proba(X, geometry=geometry)

# Assert: number of rows equals number of samples
assert proba.shape[0] == len(X)

# Assert: binary classification → exactly 2 probability columns
assert proba.shape[1] == 2


def test_predict_requires_keep_models(sample_data):
# Unpack sample dataset
X, y, geometry = sample_data

# Initialize model WITHOUT storing local models
model = GWLogisticRegression(
bandwidth=10, fixed=False, keep_models=False, max_iter=1000
)
# Fit model
model.fit(X, y, geometry=geometry)

# prediction requires stored local models → should fail
with pytest.raises(AttributeError, match="_local_models"):
model.predict(X, geometry=geometry)