Coverage for addmo/s3_model_tuning/scoring/validation_splitting/abstract_splitter.py: 36%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import numpy as np 

2 

3import pandas as pd 

4 

5from sklearn.model_selection import BaseCrossValidator 

6 

7 

class AbstractSplitter(BaseCrossValidator):
    """
    Generate a splitter that is compatible with scikit-learn cross-validation tools.

    Split: On a CV splitter (not an estimator), this method accepts parameters (X, y, groups),
    where all may be optional, and returns an iterator over (train_idx, test_idx) pairs. Each of
    {train,test}_idx is a 1d integer array, with values from 0 to X.shape[0] - 1 of any length,
    such that no values appear in both some train_idx and its corresponding test_idx.

    cross-validation generator: A non-estimator family of classes used to split a dataset into a
    sequence of train and test portions (see Cross-validation: evaluating estimator performance),
    by providing split and get_n_splits methods. Note that unlike estimators, these do not have fit
    methods and do not provide set_params or get_params. Parameter validation may be performed in
    __init__.

    Subclasses must implement ``get_n_splits`` and either ``_iter_test_indices`` (yielding the
    test indices of each split) or ``_iter_test_masks`` (yielding a boolean test mask per split).
    """

    def __init__(self, **kwargs):
        # Arbitrary keyword arguments are stored unchanged so subclasses can read them.
        self.kwargs = kwargs

    def split(self, X: pd.DataFrame = None, y: pd.Series = None, groups=None):
        """Generate indices to split system_data into training and test sets.

        This implementation is adapted from scikit-learn. It ensures that the train set always
        contains the remaining indices compared to the test set. If you don't want this
        behavior, you can override this method in your custom splitter. Otherwise, I recommend
        keeping this method as it is and making changes to the _iter_test_indices method.

        Args:
            X: Feature data. A numpy array is converted to a ``pd.DataFrame`` first. Although
                it defaults to None (mirroring scikit-learn's signature), it is required because
                the number of samples is taken from ``len(X)``.
            y: Target data. A numpy array is converted to a ``pd.Series`` first.
            groups: Passed through unchanged to ``_iter_test_masks``.

        Yields:
            ``(train_index, test_index)`` pairs of 1d integer numpy arrays holding positional
            indices into X; the train indices are the complement of the test indices.

        Raises:
            TypeError: If X is None (raised when the generator is first advanced).
        """
        if X is None:
            raise TypeError(
                f"{type(self).__name__}.split requires X to determine the number of samples; "
                "got None."
            )

        # Convert from numpy to pandas, as custom splitters may work with pandas
        if isinstance(X, np.ndarray):
            X = pd.DataFrame(X)
        if isinstance(y, np.ndarray):
            y = pd.Series(y)

        indices = np.arange(len(X))
        for test_mask in self._iter_test_masks(X, y, groups):
            # Everything not selected for testing goes into the training set.
            train_index = indices[np.logical_not(test_mask)]
            test_index = indices[test_mask]
            yield train_index, test_index

    def get_n_splits(self, X: pd.DataFrame = None, y: pd.Series = None, groups=None):
        """Return the number of splitting iterations in the cross-validator.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _iter_test_masks(
        self, X: pd.DataFrame = None, y: pd.Series = None, groups=None
    ):
        """Generate boolean masks corresponding to test sets.

        By default, delegates to ``_iter_test_indices(X, y, groups)`` and turns each yielded
        set of test indices into a boolean array of length ``len(X)`` that is True at those
        positions.

        Yields:
            A 1d boolean numpy array per split, True where the sample belongs to the test set.
        """
        for test_index in self._iter_test_indices(X, y, groups):
            test_mask = np.zeros(len(X), dtype=bool)
            test_mask[test_index] = True
            yield test_mask

    def _iter_test_indices(
        self, X: pd.DataFrame = None, y: pd.Series = None, groups=None
    ) -> np.ndarray:
        """Generate integer indices corresponding to test sets.

        Subclasses are expected to implement this as a generator that yields, per split, the
        positional indices (0 .. len(X) - 1) of the test samples.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError