超参数调整
scikit-learn中定义了一种名为estimator(评估器)的对象,estimator主要用于对模型进行评估和解码。所有的estimator对象都必须拥有.fit()
方法,并且提供.set_params()
和.get_params()
方法。基本上scikit-learn中分类、回归、聚类等算法模型都是继承自base.BaseEstimator
类,所以基本上都是具有这些必备方法的。
超参数调整的一个主要途径就是通过不同超参数的搭配来获得最高的交叉验证评分。在estimator中可以通过.get_params()
方法来获取模型中的全部参数。scikit-learn根据选择参数搭配的方法提供了两种参数值搭配寻找方法:GridSearchCV
类和RandomizeSearchCV
类。
GridSearchCV
类采用穷举法,从给定的参数值中进行穷举搭配测试,最后从中选择出一组评分最好的搭配组合。GridSearchCV
类接受一个字典列表作为参数值来源,其中字典键为参数名,值为可使用的参数值,每一个字典表示一种需要探索的超参数组合空间。以下给出一个使用GridSearchCV
类进行参数搜索的示例。
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC
digits = datasets.load_digits()
n_samples = len(digits.images)
x = digits.images.reshape((n_samples, -1))
y = digits.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5, random_state=0)
tuned_parameters = [
{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': '[1, 10, 100, 1000]},
{'kernel': ['linear'], 'C': [1, 10, 100, 1000]}
]
scores = ['precision', 'recall']
for score in scores:
clf = GridSearchCV(SVC(), tuned_parameters, cv=5, scoring=f'{score}_marco')
clf.fit(x_train, y_train)
print(clf.best_params_)
y_true, y_pred = y_test, clf.predict(x_test)
print(classification_report(y_true, y_pred))
构造函数中的参数cv
可以接受多种值,但总起来是用来定义交叉验证的分类。当给定整型值时,表示使用K折验证的折数,不传值采用默认值,K折为3。当给定一个可以抛出(train, test)
结构的生成器时,将会按照生成器抛出的元素索引数组拆分样本。或者还可以给定一个拆分器(CV Splitter)实例,拆分器类可在sklearn.model_selection
包中找到。
与GridSearchCV
类不同的是,RandomizeSearchCV
类不需要预先对可能的参数值分配空间,只需要将所有可能的参数与值列成一个字典即可,注意,不是字典列表。
寻找得到的最佳参数值可以通过模型的.best_params_
属性获得。