from audioop import add
from typing import Any, List, Union
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import sklearn.metrics as metrics
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from .utils import *
class ML:
    """Supervised classification and clustering on top of DTRs.

    The instance stores the DTR matrix and the image files it belongs to.
    ``fit_eval`` trains and evaluates a classifier with a case-grouped
    train/test split; ``clustering`` groups the DTRs without labels.
    """

    def __init__(self,
                 dtrs: np.ndarray,
                 imgfiles: List[str],
                 ) -> None:
        """Initialize machine learning analysis using DTRs.

        Args:
            dtrs (np.ndarray): M-dimensional DTRs for N images (NxM array).
            imgfiles (List[str]): List of N image files.
        """
        self.dtrs = dtrs
        self.imgfiles = imgfiles

    def fit_eval(self,
                 y: Union[list, np.ndarray],
                 cases: Union[list, np.ndarray],
                 additional_features: Union[np.ndarray, None] = None,
                 min_samples: int = 5,
                 show: bool = True,
                 test_size: Union[float, int] = 0.25,
                 ) -> Any:
        """Train a classifier on the DTRs and evaluate it on held-out cases.

        Two classes are handled with logistic regression, more than two with
        a linear SVM. The train/test split is done on case IDs (stratified by
        the target of each case), so all images of a case end up on the same
        side of the split. The fitted model and the split data are stored on
        the instance (see ``get_model`` / ``get_result``).

        Args:
            y (Union[list, np.ndarray]): Target variable (one entry per image).
            cases (Union[list, np.ndarray]): Case IDs (used as group).
            additional_features (np.ndarray, optional): Additional features
                used for the classification, one row per image in the same
                order as ``y``. It MUST be numerical; encode categorical
                variables first. Defaults to None.
            min_samples (int, optional): Minimum number of cases analyzed in
                a target. Targets below the value will be removed.
                Defaults to 5.
            show (bool, optional): Show confusion matrix or ROC curve.
                Defaults to True.
            test_size (Union[int, float], optional): If float, should be
                between 0.0 and 1.0 and represent the proportion of the cases
                to include in the test split. If int, represents the absolute
                number of test cases. Defaults to 0.25.

        Returns:
            Any: A tuple ``(metric, details)``. ``metric`` is the AUROC
            (binary classification) or the confusion matrix as a DataFrame
            (multiclass classification). ``details`` is a dict with the keys
            ``'imgfiles_test'``, ``'y_pred'`` and ``'y_test'``.

        Raises:
            ValueError: If fewer than two classes remain after filtering.
        """
        y_all = np.asarray(y)
        cases_all = np.asarray(cases)

        # Count distinct cases for each class and keep only the classes that
        # have enough of them.
        used_index = []
        for label in np.unique(y_all):
            in_class = y_all == label
            count = len(np.unique(cases_all[in_class]))
            if count < min_samples:
                print(f'class {label} is not analyzed (only {count} cases)')
            else:
                used_index.extend(np.where(in_class)[0])
        used_index = np.asarray(used_index, dtype=int)

        dtrs2 = self.dtrs[used_index, :]
        imgfiles2 = np.array(self.imgfiles)[used_index]
        if additional_features is not None:
            additional_features = np.asarray(additional_features)
            if additional_features.ndim == 1:
                additional_features = np.expand_dims(additional_features, axis=1)
            # used_index both drops rows and regroups them by class, so the
            # extra features need the same selection to stay aligned.
            dtrs2 = np.concatenate(
                [dtrs2, additional_features[used_index]], axis=1)
        y = y_all[used_index]
        cases = cases_all[used_index]

        labels = list(np.unique(y))
        print(f'labels: {labels}')
        if len(labels) > 2:
            mode = 'multi'
            print("Muticlass => SVM")
            model = SVC(kernel='linear', C=1)
        elif len(labels) == 2:
            mode = 'binary'
            print("Binary => Logistic Regression")
            model = LogisticRegression(solver='liblinear')
        else:
            raise ValueError(f'invalid number of classes ({labels})')

        # Split on unique cases; each case is represented by the target of
        # its first image for stratification.
        u_cases = np.unique(cases)
        y_ucases = [y[cases == c][0] for c in u_cases]
        train_cases, test_cases = train_test_split(u_cases,
                                                   test_size=test_size,
                                                   stratify=y_ucases,
                                                   random_state=0)
        train_mask = np.isin(cases, train_cases)
        test_mask = np.isin(cases, test_cases)
        X_train = dtrs2[train_mask]
        X_test = dtrs2[test_mask]
        y_train = list(y[train_mask])
        y_test = list(y[test_mask])
        img_test = list(imgfiles2[test_mask])

        model.fit(X_train, y_train)
        self.model = model
        self.ml_data = {
            'train_case': train_cases,
            'test_case': test_cases,
            'X_train': X_train,
            'X_test': X_test,
            'y_train': y_train,
            'y_test': y_test,
        }

        if mode == 'multi':
            y_pred = self.model.predict(X_test)
            self.ml_data['y_pred'] = y_pred
            # Pin the label order so the matrix always matches the
            # index/columns of the DataFrame below.
            conf_mat = confusion_matrix(y_test, y_pred, labels=labels)
            conf_mat_df = pd.DataFrame(data=conf_mat,
                                       index=labels,
                                       columns=labels)
            if show:
                fig, ax = plt.subplots()
                im = ax.matshow(conf_mat, cmap=plt.cm.Blues, alpha=0.3)
                row_totals = conf_mat.sum(axis=1)
                for i in range(conf_mat.shape[0]):
                    for j in range(conf_mat.shape[1]):
                        # Avoid dividing by zero for a class with no test
                        # samples.
                        total = row_totals[i]
                        pct = conf_mat[i, j] * 100 / total if total else 0.0
                        ax.text(x=j, y=i,
                                s=f"{conf_mat[i, j]}({pct:.1f}%)",
                                va='center', ha='center')
                fig.colorbar(im)
                tick_marks = np.arange(len(labels))
                plt.tick_params(axis="x", bottom=False, top=False,
                                labelbottom=True, labeltop=False)
                plt.tick_params(axis="y", left=False)
                plt.xticks(tick_marks, labels, rotation=45)
                plt.yticks(tick_marks, labels)
                plt.xlabel("Prediction", fontsize=26, rotation=0)
                plt.ylabel("Ground Truth", fontsize=26)
                # NOTE(review): unlike the binary branch, plt.show() is not
                # called here - confirm whether that is intentional.
            return conf_mat_df, {'imgfiles_test': img_test,
                                 'y_pred': y_pred,
                                 'y_test': y_test}

        # Binary case: score with the probability of the second class in
        # model.classes_ and tell roc_curve which label that is.
        pos_index = 1
        probs = self.model.predict_proba(X_test)
        preds = probs[:, pos_index]
        self.ml_data['y_pred'] = preds
        fpr, tpr, _ = metrics.roc_curve(
            y_test, preds, pos_label=self.model.classes_[pos_index])
        roc_auc = metrics.auc(fpr, tpr)
        if show:
            plt.title('Receiver Operating Characteristic')
            plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
            plt.legend(loc='lower right')
            plt.plot([0, 1], [0, 1], 'r--')
            plt.xlim([0, 1])
            plt.ylim([0, 1])
            plt.ylabel('True Positive Rate')
            plt.xlabel('False Positive Rate')
            plt.show()
        return roc_auc, {'imgfiles_test': img_test,
                         'y_pred': preds,
                         'y_test': y_test}

    def get_model(self,
                  ) -> Any:
        """Return the trained supervised model.

        The model is only set by ``fit_eval``.

        Returns:
            Any: the trained model.
        """
        return self.model

    def get_result(self,
                   ) -> Any:
        """Return the training and test data and prediction.

        The data is only set by ``fit_eval``.

        Returns:
            Any: Dictionary containing train/test case/X/y and the prediction.
        """
        return self.ml_data

    def clustering(self,
                   method: str = 'bayes_gmm',
                   n_components: int = 10,
                   show: bool = False,
                   ) -> List[int]:
        """Clustering of dtrs.

        Args:
            method (str, optional): Clustering algorithm. Only 'bayes_gmm'
                is supported. Defaults to 'bayes_gmm'.
            n_components (int, optional): Number of (maximum) clusters.
                Defaults to 10.
            show (bool, optional): Show representative images.
                Defaults to False.

        Returns:
            List[int]: Cluster labels.

        Raises:
            ValueError: If ``method`` is not a supported algorithm.
        """
        if method == 'bayes_gmm':
            from sklearn.mixture import BayesianGaussianMixture
            model = BayesianGaussianMixture(n_components=n_components,
                                            random_state=42)
        else:
            raise ValueError(f'invalid clustering algorithm: {method}')

        cluster_label = model.fit_predict(self.dtrs)
        if show:
            cluster_ids = sorted(np.unique(cluster_label))
            # presumably get_medoid maps each cluster label to the index of
            # its medoid sample - verify against utils
            medoid_dict = get_medoid(self.dtrs, cluster_label)
            imgfiles_medoid = [self.imgfiles[medoid_dict[c]]
                               for c in cluster_ids]
            imgcats(imgfiles_medoid,
                    labels=cluster_ids,
                    nrows=3)
        return cluster_label