./split.py

./split.py#

import numpy as np

def split_data(X, Y, proportions, rng=None):
    """Shuffle and split input and output into 3 subsets for an ML model.

    The rows of X and Y are shuffled with the same random permutation, so
    each input row stays paired with its output row, and then cut into
    consecutive train / validation / test parts.

    Arguments
    =========
    X, Y:        ndarrays where rows are number of observations
                    (both arrays have identical number of rows)
    proportions: list with decimal fraction of original data defining
                 allocation into three parts (train, validate, test sets,
                 respectively). The list is len(proportions)=3, and
                 contains non-negative floats that should sum to 1.0.
    rng:         numpy random generator instance for reproducibility. If None,
                 a new generator is created without a fixed seed.

    Returns
    =======
    X_train, X_val, X_test, Y_train, Y_val, Y_test:
     6 ndarrays (3 splits each for input and output), where the number of
     columns corresponds to the original input and output (respectively)
     and the sum of the number of rows is equal to the rows of the original
     input/output. The train and validation sizes are the row count times
     the respective proportion, rounded down; the test set receives all
     remaining rows (so it absorbs any rows lost to rounding down).

    Raises
    ======
    AssertionError: if proportions does not have exactly three entries,
     contains a negative value, does not sum to 1.0, or if X and Y have a
     different number of rows.
    """
    # Validation is raised explicitly instead of using `assert` statements so
    # that it still runs under `python -O`. AssertionError is kept as the
    # exception type for compatibility with the earlier assert-based checks.
    if len(proportions) != 3:
        raise AssertionError("Three proportions must be provided")
    if not np.isclose(sum(proportions), 1.0):
        raise AssertionError("Sum of proportions should be one")
    # Negative entries can still sum to one, but would produce negative slice
    # bounds and therefore a silently wrong split.
    if any(p < 0 for p in proportions):
        raise AssertionError("Proportions must be non-negative")
    if len(X) != len(Y):
        raise AssertionError("X and Y arrays must have same dimensions")

    n_rows = len(X)

    # Shuffle data using random permutation of indices
    if rng is None:
        rng = np.random.default_rng()
    indices = rng.permutation(n_rows)

    train_prop = proportions[0]
    val_prop = proportions[1]

    # Products such as 0.29 * 100 evaluate to 28.999999999999996; truncating
    # that directly would drop a row. Rounding to 9 decimals removes the
    # floating-point noise while genuinely fractional counts still round down.
    train_end = int(round(train_prop * n_rows, 9))
    val_end = train_end + int(round(val_prop * n_rows, 9))

    train_idx = indices[:train_end]
    val_idx = indices[train_end:val_end]
    test_idx = indices[val_end:]  # remainder, including rounding leftovers

    # Create shuffled training, validation and test sets
    X_train, X_val, X_test = X[train_idx], X[val_idx], X[test_idx]
    Y_train, Y_val, Y_test = Y[train_idx], Y[val_idx], Y[test_idx]

    # Internal sanity check (not input validation): no row lost or duplicated.
    assert (len(X_train) + len(X_val) + len(X_test)) == len(X), "Generated datasets don't have same accumulated length as original"

    return X_train, X_val, X_test, Y_train, Y_val, Y_test