百闻不如一练:可视化调试模型超参数 !

反射极昼
• 阅读 132

以下使用scikit-learn中数据集进行分享。

如果选用随机森林作为最终的模型,那么找出它的最佳参数可能有1000多种组合的可能,你可以使用使用穷尽的网格搜索(Exhaustive Grid Seaarch)方法,但时间成本将会很高(运行很久...),或者使用随机搜索(Randomized Search)方法,仅分析超参数集合中的子集合。

该例子以手写数据集为例,使用支持向量机的方法对数据进行建模,然后调用scikit-learn中validation_surve方法将模型交叉验证的结果进行可视化。需要注意的是,在使用validation_curve方法时,只能验证一个超参数与模型训练集和验证集得分的关系(即二维的可视化),而不能实现多参数与得分间关系的可视化。以下搜索的参数是gamma,需要给定参数范围,用param_range进行传递,评分策略用scoring参数进行传递。其代码示例如下所示:

print(__doc__)

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import validation_curve

X, y = load_digits(return_X_y=True)

param_range = np.logspace(-6, -1, 5)
train_scores, test_scores = validation_curve(
   SVC(), X, y, param_name="gamma", param_range=param_range,
   scoring="accuracy", n_jobs=1)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plt.title("Validation Curve with SVM")
plt.xlabel(r"$\\gamma$")
plt.ylabel("Score")
plt.ylim(0.0, 1.1)
lw = 2
plt.semilogx(param_range, train_scores_mean, label="Training score",
            color="darkorange", lw=lw)
plt.fill_between(param_range, train_scores_mean - train_scores_std,
                train_scores_mean + train_scores_std, alpha=0.2,
                color="darkorange", lw=lw)
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
            color="navy", lw=lw)
plt.fill_between(param_range, test_scores_mean - test_scores_std,
                test_scores_mean + test_scores_std, alpha=0.2,
                color="navy", lw=lw)
plt.legend(loc="best")
plt.show();

代码中:

X, y = load_digits(return_X_y=True)

等价于


digits = load_digits()
X_digits = digits.data
y_digits = digits.target  

以下是支持向量机的验证曲线,调节的超参数gamma共有5个值,每一个点的分数是五折交叉验证(cv=5)的均值。

百闻不如一练:可视化调试模型超参数 !

当想看模型多个超参数与模型评分之间的关系时,使用scikit-learn中validation curve就难以实现,因此可以考虑绘制三维坐标图。

主要用plotly的库绘制3D Scatter(3d散点图)。以下的例子使用scikit-learn中的莺尾花的数据集(iris)。以下例子选用随机森林模型(RandomForestRegressor),利用scikit-learn中的GridSearchCV方法调试最佳超参(tuning hyper-parameters),分别设置超参数"n_estimators","max_features","min_samples_split"的参数范围,详见代码如下:

import numpy as np
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestRegressor  
from plotly.offline import iplot  
from plotly.graph_objs as go

model = RandomForestRegressor(n_jobs=-1, random_state=2, verbose=2)

grid = {'n_estimators': [10,110,200],
       'max_features': [0.05, 0.07, 0.09, 0.11, 0.13],
       'min_samples_split': [2, 3, 5, 8]}

rf_gridsearch = GridSearchCV(estimator=model, param_grid=grid, n_jobs=4, cv=5, verbose=2, return_train_score=True)

rf_gridsearch.fit(X, y)

and after some hours...

df_gridsearch = pd.DataFrame(rf_gridsearch.cv_results_)    

trace = go.Scatter3d(
   x=df_gridsearch['param_max_features'],
   y=df_gridsearch['param_n_estimators'],
   z=df_gridsearch['param_min_samples_split'],
   mode='markers',
   marker=dict(
       # size=df_gridsearch.mean_fit_time ** (1 / 3),
       size = 10,
       color=df_gridsearch.mean_test_score,
       opacity=0.99,
       colorscale='Viridis',
       colorbar=dict(title = 'Test score'),
       line=dict(color='rgb(140, 140, 170)'),
  ),
   text=df_gridsearch.Text,
   hoverinfo='text'
)

data = [trace]
layout = go.Layout(
   title='3D visualization of the grid search results',
   margin=dict(
       l=30,
       r=30,
       b=30,
       t=30
  ),

   scene = dict(
       xaxis = dict(
           title='max_features',
           nticks=10
      ),
       yaxis = dict(
           title='n_estimators',
      ),
       zaxis = dict(
           title='min_samples_split',

      ),
  ),

)

fig = go.Figure(data=data, layout=layout)
iplot(fig)  

其运行结果如果,是一个三维散点图(3D Scatter)。

百闻不如一练:可视化调试模型超参数 !

可以看到颜色越浅,分数越高。n_estimators(子估计器)越多,分数越高,max_features的变化对模型分数的影响较小,在图中看不到变化,min_samples_split的个数并不是越高越好,但与模型分数并不呈单调关系,在min_samples_split取2时(此时,其它条件不变),模型分数最高。

除了使用scikit-learn中validation curve绘制超参数与得分的可视化,还可以利用seaborn库中heatmap方法来实现两个超参数之间的关系图,如下代码示例:

import seaborn as sns

title = '''Maximum R2 score on test set VS
max_features, min_samples_split'''  
sns.heatmap(max_scores.mean_test_score, annot=True, fmt='.4g');
plt.title(title);
plt.savefig("heatmap_test.png", dpi = 300);  

import seaborn as sns

title = '''Maximum R2 score on train set VS
max_features, min_samples_split'''
sns.heatmap(max_scores.mean_train_score, annot=True, fmt='.4g');
plt.title(title);
plt.savefig("heatmap_train.png", dpi = 300);  

max_features和min_samples与模型得分关系的可视化如下图所示(分别为网格搜索中测试集和训练集的得分):

百闻不如一练:可视化调试模型超参数 !

百闻不如一练:可视化调试模型超参数 !

由于一般人很难迅速的在大量数据中找到隐藏的关系,因此,可以考虑绘图,将数据关系以图表的形式,清晰的显现出来。

综上,当关注单个超参数的学习曲线时,可以使用scikit-learn中validation curve,找到拐点,作为模型的最佳参数。

当关注两个超参数的共同变化对模型分数的影响时,可以使用seaborn库中的heatmap方法,制作“热图”,以找到超参数协同变化对分数影响的趋势。

当关注三个参数的协同变化与模型得分的关系时,可以使用poltly库中的iplot和go方法,绘制3d散点图(3D Scatter),将其协同变化对模型分数的影响展现在高维图中。

(1)获取更多优质内容及精彩资讯,可前往:https://www.cda.cn/?seo

(2)了解更多数据领域的优质课程:

百闻不如一练:可视化调试模型超参数 !

点赞
收藏
评论区
推荐文章
Souleigh ✨ Souleigh ✨
4年前
几个有点意思的 CSS 技巧
如果你是一名前端开发人员或者想成为一名开发人员,那么,我今天与你分享的9个CSS技巧,你需要知道一下。现在,我们开始吧。1、学习盒子模型你在学习CSS时,应该避免使用Bootstrap或TailwindCSS等框架,这些工具非常适合构建漂亮的网站,但如果你还不能正确的了解CSS,则建议不要使用这些框架中的任何一个。因为如果你使用了这些工具,你将无法学习
冴羽 冴羽
4年前
VuePress 博客优化之开启 Algolia 全文搜索
前言在中,我们使用VuePress搭建了一个博客,最终的效果查看:。由于VuePress的内置搜索只会为页面的标题、h2、h3以及tags构建搜索索引。如果你需要全文搜索,可则以使用Algolia搜索,本篇讲讲如何申请以及配置Algolia搜索。AlgoliaAlgolia是一个数据库实时搜索服务,能够提供毫秒级的数据库搜
皕杰报表(关于日期时间时分秒显示不出来)
在使用皕杰报表设计器时,数据据里面是日期型,但当你web预览时候,发现有日期时间类型的数据时分秒显示不出来,只有年月日能显示出来,时分秒显示为0:00:00。1.可以使用tochar解决,数据集用selecttochar(flowdate,"yyyyMMddHH:mm:ss")fromtablename2.也可以把数据库日期类型date改成timestamp
皕杰报表中参数和变量的区别
在皕杰报表中,参数是有数据类型的变量,在报表运算过程中作为变量使用。参数那么参数在皕杰报表中具体如何使用呢?1、作为sql语句的where条件:通过给参数赋值可以实现动态查询,给参数赋予不同的值,从而查询出来不同的数据结果。需
Elasticsearch Head插件使用小结
ElasticSearchhead就是一款能连接ElasticSearch搜索引擎,并提供可视化的操作页面对ElasticSearch搜索引擎进行各种设置和数据检索功能的管理插件,如在head插件页面编写RESTful接口风格的请求,就可以对ElasticSearch中的数据进行增删改查、创建或者删除索引等操作。类似于使用navicat工具连接MySQL这种关系型数据库,对数据库做操作
Wesley13 Wesley13
4年前
mysql中时间比较的实现
MySql中时间比较的实现unix\_timestamp()unix\_timestamp函数可以接受一个参数,也可以不使用参数。它的返回值是一个无符号的整数。不使用参数,它返回自1970年1月1日0时0分0秒到现在所经过的秒数,如果使用参数,参数的类型为时间类型或者时间类型的字符串表示,则是从1970010100:00:0
Wesley13 Wesley13
4年前
(绝对有用)iOS获取UUID,并使用keychain存储
UDID被弃用,使用UUID来作为设备的唯一标识。获取到UUID后,如果用NSUserDefaults存储,当程序被卸载后重装时,再获得的UUID和之前就不同了。使用keychain存储可以保证程序卸载重装时,UUID不变。但当刷机或者升级系统后,UUID还是会改变的。但这仍是目前为止最佳的解决办法了,如果有更好的解决办法,欢迎留言。(我整理的解决办法的参
Wesley13 Wesley13
4年前
mysql笔记
如果需要查询id不是连续的一段,最佳的方法就是先找出id,然后用in查询SELECT\FROMtableWHEREidIN(10000,100000,1000000...);索引列上使用in速度是很快的1.SELECT\FROMtableORDERBYidLIMIT1000000,10;
为什么mysql不推荐使用雪花ID作为主键
作者:毛辰飞背景在mysql中设计表的时候,mysql官方推荐不要使用uuid或者不连续不重复的雪花id(long形且唯一),而是推荐连续自增的主键id,官方的推荐是auto_increment,那么为什么不建议采用uuid,使用uuid究
小万哥 小万哥
1年前
NumPy 随机数据分布与 Seaborn 可视化详解
随机数据分布什么是数据分布?数据分布是指数据集中所有可能值出现的频率,并用概率来表示。它描述了数据取值的可能性。在统计学和数据科学中,数据分布是分析数据的重要基础。NumPy中的随机分布NumPy的random模块提供了多种方法来生成服从不同分布的随机数。
小万哥 小万哥
1年前
NumPy 泊松分布模拟与 Seaborn 可视化技巧
泊松分布是描述单位时间间隔内随机事件发生次数的离散概率分布,参数λ表示平均速率。公式为P(k)e^(λ)(λ^k)/k!。NumPy的random.poisson()可生成泊松分布数据。当λ很大时,泊松分布近似正态分布。练习包括模拟顾客到达、比较不同λ下的分布及模拟电话呼叫中心。使用Seaborn可进行可视化。关注公众号LetusCoding获取更多文章。