AI Module (stx.ai)
Machine learning utilities for training, classification, and metrics with PyTorch and scikit-learn.
Quick Reference
import scitex as stx
# Training utilities
from scitex.ai import LearningCurveLogger, EarlyStopping
logger = LearningCurveLogger()
stopper = EarlyStopping(patience=10, direction="minimize")
for epoch in range(100):
    # ... training loop ...
    logger({"loss": loss, "acc": acc}, step="Training")
    if stopper(val_loss, {"model": model_path}, epoch):
        break
logger.plot_learning_curves(spath="curves.png")
# Classification
from scitex.ai import ClassificationReporter, Classifier
clf = Classifier()("SVC")
reporter = ClassificationReporter(output_dir="./results")
reporter.calculate_metrics(y_true, y_pred, y_proba)
reporter.save_summary()
Training
LearningCurveLogger – Track and visualize training/validation/test metrics across epochs
EarlyStopping – Monitor validation metrics and stop when improvement plateaus
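To make the "stop when improvement plateaus" idea concrete, here is a minimal, library-independent sketch of patience-based early stopping. The class name `PatienceStopper` and its two-argument call are hypothetical and chosen for illustration only; this is not the implementation of `EarlyStopping` in scitex, whose call also takes a dictionary of model paths as shown in the Quick Reference.

```python
class PatienceStopper:
    """Hypothetical stand-in illustrating patience-based stopping."""

    def __init__(self, patience=3, direction="minimize"):
        if direction not in ("minimize", "maximize"):
            raise ValueError("direction must be 'minimize' or 'maximize'")
        self.patience = patience
        # Flip the sign so that "smaller is better" holds internally.
        self.sign = 1.0 if direction == "minimize" else -1.0
        self.best = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def __call__(self, score, epoch):
        """Return True when training should stop."""
        value = self.sign * score
        if value < self.best:
            # New best score: remember it and reset the counter.
            self.best = value
            self.best_epoch = epoch
            self.bad_epochs = 0
            return False
        # No improvement this epoch.
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, for illustration only.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.67, 0.64]
stopper = PatienceStopper(patience=3, direction="minimize")
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper(val_loss, epoch):
        stopped_at = epoch
        break
```

With these losses the best value occurs at epoch 2, and three consecutive epochs without improvement trigger the stop before the final value is ever seen. That is the usual trade-off when choosing a small patience.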
Classification
ClassificationReporter – Unified reporter for single/multi-task classification (balanced accuracy, MCC, ROC-AUC, confusion matrices)
Classifier – Factory for scikit-learn classifiers (SVC, KNN, Logistic Regression, AdaBoost, …)
CrossValidationExperiment – Cross-validation framework
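The factory pattern behind a name-to-estimator lookup can be sketched with scikit-learn alone. The helper `make_classifier`, the `_REGISTRY` mapping, and the registry keys are assumptions made for this example; the names accepted by scitex's `Classifier` (other than `"SVC"`, which appears in the Quick Reference) are not documented here.

```python
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

# Hypothetical registry: string name -> scikit-learn estimator class.
_REGISTRY = {
    "SVC": SVC,
    "KNeighborsClassifier": KNeighborsClassifier,
    "LogisticRegression": LogisticRegression,
    "AdaBoostClassifier": AdaBoostClassifier,
}


def make_classifier(name, **kwargs):
    """Instantiate a registered estimator by name (illustrative helper)."""
    try:
        estimator_cls = _REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Unknown classifier {name!r}; choose from {sorted(_REGISTRY)}"
        ) from None
    return estimator_cls(**kwargs)


# probability=True lets SVC expose predict_proba after fitting.
clf = make_classifier("SVC", probability=True)
```

Keeping the mapping in one dictionary means a new model type can be added in a single line, and an unknown name fails early with a message listing the valid choices.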
Metrics
Standardized calc_* functions:
calc_bacc – Balanced accuracy
calc_mcc – Matthews Correlation Coefficient
calc_conf_mat – Confusion matrix
calc_roc_auc – ROC-AUC score
calc_pre_rec_auc – Precision-Recall AUC
calc_feature_importance – Feature importance scores
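For orientation, the quantities named above can be computed directly with scikit-learn, as in the sketch below. It does not call the `calc_*` functions, whose signatures and return types are not shown in this document, and the labels and scores are made up. Average precision is used here as one common summary of the precision-recall curve; whether `calc_pre_rec_auc` uses the same definition is an assumption.

```python
import numpy as np
from sklearn.metrics import (
    average_precision_score,
    balanced_accuracy_score,
    confusion_matrix,
    matthews_corrcoef,
    roc_auc_score,
)

# Made-up binary labels and predicted probabilities for the positive class.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_proba = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.9, 0.6])
# Hard predictions from a 0.5 threshold.
y_pred = (y_proba >= 0.5).astype(int)

# Threshold-dependent metrics use the hard predictions.
bacc = balanced_accuracy_score(y_true, y_pred)
mcc = matthews_corrcoef(y_true, y_pred)
conf_mat = confusion_matrix(y_true, y_pred)  # rows: true, columns: predicted

# Ranking metrics use the scores, not the thresholded labels.
roc_auc = roc_auc_score(y_true, y_proba)
avg_precision = average_precision_score(y_true, y_proba)
```

Note the split between metrics computed from `y_pred` and those computed from `y_proba`. Passing thresholded labels to a ranking metric is a frequent source of misleading AUC values.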
Visualization
plot_learning_curve – Training/validation curves
stx_conf_mat – Confusion matrix heatmap
plot_roc_curve – ROC curve
plot_pre_rec_curve – Precision-Recall curve
plot_feature_importance – Feature importance bar plots
Other
MultiTaskLoss – Multi-task learning loss weighting
get_optimizer/set_optimizer – Optimizer management
GenAI – Generative AI wrapper (lazy-loaded)
Clustering: pca, umap
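As a rough picture of what a PCA projection produces, the sketch below uses scikit-learn's `PCA` on random data. It does not use the `pca` helper from this module, whose arguments and return values are not documented here; the array shapes and the random seed are arbitrary choices for the example.

```python
import numpy as np
from sklearn.decomposition import PCA

# Random feature matrix: 100 samples, 5 features (illustrative data only).
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))

# Project onto the two directions of largest variance.
pca_model = PCA(n_components=2)
X_2d = pca_model.fit_transform(X)

# Fraction of total variance captured by each retained component.
explained = pca_model.explained_variance_ratio_
```

The resulting `X_2d` has one row per sample and two columns, which is the form typically passed to a scatter plot for visual inspection of cluster structure.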