-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathclt_plot.py
More file actions
25 lines (20 loc) · 785 Bytes
/
clt_plot.py
File metadata and controls
25 lines (20 loc) · 785 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np

try:
    from sklearn.datasets import load_boston
except ImportError:  # load_boston was removed in scikit-learn 1.2
    load_boston = None

import deeprob.spn.structure as spn
# Raw Boston housing data on StatLib. scikit-learn's deprecation notice for
# ``load_boston`` points to this file as the replacement data source.
BOSTON_DATA_URL = 'http://lib.stat.cmu.edu/datasets/boston'


def load_boston_features():
    """Return the Boston housing feature matrix (without the target column).

    Uses ``sklearn.datasets.load_boston`` when it could be imported
    (scikit-learn < 1.2). Otherwise downloads the raw StatLib file and
    reassembles it the way scikit-learn's deprecation notice recommends.
    The fallback path requires pandas and network access.

    :return: The feature matrix as a NumPy array.
    """
    if load_boston is not None:
        data, _ = load_boston(return_X_y=True)
        return data
    import pandas as pd  # Only needed on the fallback path.
    raw_df = pd.read_csv(BOSTON_DATA_URL, sep=r'\s+', skiprows=22, header=None)
    # In the StatLib file each record is wrapped over two physical lines;
    # stitch them back together, dropping the trailing target value.
    return np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])


def binarize_below_mean(data):
    """Binarize every feature against its own column mean.

    :param data: A 2D array-like of shape (n_samples, n_features).
    :return: A float32 array of the same shape, holding 1 where the value is
             strictly below its column mean and 0 otherwise.
    """
    data = np.asarray(data)
    return (data < np.mean(data, axis=0)).astype(np.float32)


def binary_domain(n_features):
    """Return one independent ``[0, 1]`` domain list per feature.

    A comprehension is used instead of ``[[0, 1]] * n_features`` so that the
    inner lists are distinct objects. Any in-place change to one feature's
    domain therefore cannot leak into the others.

    :param n_features: The number of features.
    :return: A list of ``n_features`` lists, each equal to ``[0, 1]``.
    """
    return [[0, 1] for _ in range(n_features)]


def main():
    """Fit a binary CLT on the binarized Boston data and plot it to an SVG file."""
    # Load the Boston features and binarize them against the per-feature mean.
    data = binarize_below_mean(load_boston_features())
    n_features = data.shape[1]

    # Fixed seed, passed to BinaryCLT.fit, for reproducible runs.
    random_state = np.random.RandomState(42)

    # Fit a binary CLT over all features, each with domain {0, 1}.
    scope = list(range(n_features))
    domain = binary_domain(n_features)
    clt = spn.BinaryCLT(scope)
    clt.fit(data, domain, alpha=0.1, random_state=random_state)

    # Plot the CLT.
    clt_filename = 'clt-bboston.svg'
    print("Plotting the learnt CLT to {} ...".format(clt_filename))
    spn.plot_binary_clt(clt, clt_filename)


if __name__ == '__main__':
    main()