Comments (5)
Glad to hear it's working for you 😄
This isn't currently implemented but should be easy to do - we'll get it done soon, especially if you're putting this into AutoGluon!
@aagarwal1996 @yanshuotan @OmerRonen - someone wanna take this on?
from imodels.
Thanks for the question Nick! Indeed we should have publicly released this with some documentation by now - will get to it shortly!
@aagarwal1996 - can you add in the HS + RF / ExtraTrees code with a little doc?
Best,
Chandan
Here is a hyperparameter grid result on Adult, sampled down to 2000 rows of training data and sorted by test score (AUC):
og = vanilla sklearn
hs=10_n300 = reg_param=10, n_estimators=300
Amazingly, HS with a reg_param of 50, 100, or 500 and only 10 n_estimators gets a better test score than OG with 1000 n_estimators!
model score_test score_val pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 RandomForest_hs=50_n300 0.906977 0.904691 0.095684 0.050235 2.733138 0.095684 0.050235 2.733138 1 True 20
1 RandomForest_hs=50_n100 0.906954 0.905086 0.037076 0.087729 0.922209 0.037076 0.087729 0.922209 1 True 21
2 RandomForest_hs=50_n1000 0.906939 0.905020 0.287926 0.150001 9.144297 0.287926 0.150001 9.144297 1 True 19
3 RandomForest_hs=100_n300 0.906854 0.905185 0.098557 0.050372 2.767945 0.098557 0.050372 2.767945 1 True 26
4 RandomForest_hs=100_n1000 0.906818 0.905020 0.297747 0.155764 9.122136 0.297747 0.155764 9.122136 1 True 25
5 RandomForest_hs=100_n100 0.906746 0.904823 0.040018 0.022333 0.951915 0.040018 0.022333 0.951915 1 True 27
6 WeightedEnsemble_L2 0.906464 0.906106 0.822571 0.493818 25.186706 0.003169 0.000933 0.448992 2 True 43
7 RandomForest_hs=50_n40 0.905625 0.903869 0.019060 0.012893 0.386387 0.019060 0.012893 0.386387 1 True 22
8 RandomForest_hs=100_n40 0.905619 0.904691 0.019219 0.015171 0.412162 0.019219 0.015171 0.412162 1 True 28
9 RandomForest_hs=10_n100 0.904859 0.904955 0.038016 0.021933 0.915743 0.038016 0.021933 0.915743 1 True 15
10 RandomForest_hs=10_n300 0.904659 0.905481 0.098096 0.049020 2.781127 0.098096 0.049020 2.781127 1 True 14
11 RandomForest_hs=10_n1000 0.904594 0.904659 0.296132 0.150198 9.204939 0.296132 0.150198 9.204939 1 True 13
12 RandomForest_hs=100_n20 0.904273 0.901533 0.013146 0.013468 0.245283 0.013146 0.013468 0.245283 1 True 29
13 RandomForest_hs=50_n20 0.903543 0.901336 0.012793 0.009559 0.201715 0.012793 0.009559 0.201715 1 True 23
14 RandomForest_hs=500_n1000 0.903443 0.903310 0.292724 0.153926 9.449500 0.292724 0.153926 9.449500 1 True 31
15 RandomForest_hs=500_n300 0.903273 0.904033 0.098899 0.058785 2.789253 0.098899 0.058785 2.789253 1 True 32
16 RandomForest_hs=10_n40 0.902558 0.899888 0.020626 0.015472 0.387508 0.020626 0.015472 0.387508 1 True 16
17 RandomForest_hs=500_n100 0.902558 0.904066 0.055166 0.022750 0.920178 0.055166 0.022750 0.920178 1 True 33
18 RandomForest_hs=500_n40 0.901187 0.901928 0.038185 0.013044 0.375549 0.038185 0.013044 0.375549 1 True 34
19 RandomForest_hs=500_n20 0.901066 0.899888 0.019732 0.012050 0.197903 0.019732 0.012050 0.197903 1 True 35
20 RandomForest_hs=100_n10 0.901057 0.897684 0.010498 0.008219 0.132994 0.010498 0.008219 0.132994 1 True 30
21 RandomForest_hs=50_n10 0.900010 0.896697 0.010349 0.007744 0.106932 0.010349 0.007744 0.106932 1 True 24
22 RandomForest_hs=1_n100 0.899934 0.903639 0.038109 0.021423 0.925629 0.038109 0.021423 0.925629 1 True 9
23 RandomForest_hs=1_n1000 0.899916 0.904823 0.292249 0.153203 9.213101 0.292249 0.153203 9.213101 1 True 7
24 RandomForest_hs=500_n10 0.899908 0.899296 0.011997 0.008104 0.118062 0.011997 0.008104 0.118062 1 True 36
25 RandomForest_hs=1_n300 0.899902 0.903704 0.097325 0.054082 2.725383 0.097325 0.054082 2.725383 1 True 8
26 RandomForest_hs=10_n20 0.898988 0.899197 0.017777 0.009672 0.217858 0.017777 0.009672 0.217858 1 True 17
27 RandomForest_hs=0.1_n100 0.898405 0.901961 0.038886 0.019792 1.188789 0.038886 0.019792 1.188789 1 True 39
28 RandomForest_hs=0.1_n1000 0.898272 0.903869 0.300892 0.151734 11.785147 0.300892 0.151734 11.785147 1 True 37
29 RandomForest_hs=0.1_n300 0.898254 0.903211 0.102140 0.059019 3.490166 0.102140 0.059019 3.490166 1 True 38
30 RandomForest_og_n1000 0.898019 0.903721 0.300054 0.163219 2.121544 0.300054 0.163219 2.121544 1 True 1
31 RandomForest_og_n300 0.897952 0.903145 0.098421 0.053610 0.375765 0.098421 0.053610 0.375765 1 True 2
32 RandomForest_og_n100 0.897658 0.901533 0.044796 0.021702 0.158356 0.044796 0.021702 0.158356 1 True 3
33 RandomForest_hs=1_n40 0.895661 0.894657 0.019914 0.014062 0.378873 0.019914 0.014062 0.378873 1 True 10
34 RandomForest_hs=10_n10 0.894480 0.893670 0.011885 0.008050 0.107222 0.011885 0.008050 0.107222 1 True 18
35 RandomForest_hs=0.1_n40 0.894129 0.893604 0.018879 0.011820 0.469591 0.018879 0.011820 0.469591 1 True 40
36 RandomForest_og_n40 0.891790 0.891910 0.018995 0.014379 0.073364 0.018995 0.014379 0.073364 1 True 4
37 RandomForest_hs=1_n20 0.890021 0.896532 0.013560 0.010845 0.239654 0.013560 0.010845 0.239654 1 True 11
38 RandomForest_hs=0.1_n20 0.888886 0.896105 0.013955 0.010772 0.243502 0.013955 0.010772 0.243502 1 True 41
39 RandomForest_og_n20 0.883450 0.890907 0.013461 0.010590 0.045405 0.013461 0.010590 0.045405 1 True 5
40 RandomForest_hs=1_n10 0.883443 0.888012 0.010171 0.008007 0.111472 0.010171 0.008007 0.111472 1 True 12
41 RandomForest_hs=0.1_n10 0.882777 0.887584 0.009979 0.007687 0.138142 0.009979 0.007687 0.138142 1 True 42
42 RandomForest_og_n10 0.869862 0.874079 0.010571 0.008906 0.030527 0.010571 0.008906 0.030527 1 True 6
Similar findings with Extra Trees!
model score_test score_val pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 ExtraTrees_hs=50_n1000 0.898862 0.895578 0.360272 0.152184 12.766335 0.360272 0.152184 12.766335 1 True 19
1 ExtraTrees_hs=10_n1000 0.898727 0.898638 0.375004 0.154642 13.135831 0.375004 0.154642 13.135831 1 True 13
2 ExtraTrees_hs=50_n300 0.898286 0.896006 0.126808 0.057546 3.828264 0.126808 0.057546 3.828264 1 True 20
3 ExtraTrees_hs=50_n100 0.898091 0.895545 0.050484 0.020662 1.293077 0.050484 0.020662 1.293077 1 True 21
4 ExtraTrees_hs=10_n300 0.897934 0.898802 0.129800 0.050548 3.923364 0.129800 0.050548 3.923364 1 True 14
5 ExtraTrees_hs=10_n100 0.897840 0.897848 0.050855 0.021138 1.308959 0.050855 0.021138 1.308959 1 True 15
6 ExtraTrees_hs=100_n1000 0.897263 0.894295 0.363426 0.148167 12.866699 0.363426 0.148167 12.866699 1 True 25
7 ExtraTrees_hs=100_n300 0.896848 0.894756 0.129441 0.049175 3.806397 0.129441 0.049175 3.806397 1 True 26
8 ExtraTrees_hs=50_n40 0.896566 0.893868 0.024298 0.012653 0.611947 0.024298 0.012653 0.611947 1 True 22
9 WeightedEnsemble_L2 0.896430 0.900053 1.370685 0.559450 50.761817 0.003386 0.000914 0.446805 2 True 43
10 ExtraTrees_hs=100_n100 0.896420 0.893868 0.046713 0.026253 1.349054 0.046713 0.026253 1.349054 1 True 27
11 ExtraTrees_hs=10_n40 0.895319 0.892486 0.025234 0.012230 0.530321 0.025234 0.012230 0.530321 1 True 16
12 ExtraTrees_hs=100_n40 0.895164 0.893341 0.022953 0.013198 0.523787 0.022953 0.013198 0.523787 1 True 28
13 ExtraTrees_hs=1_n1000 0.894024 0.898441 0.351593 0.148842 13.060638 0.351593 0.148842 13.060638 1 True 7
14 ExtraTrees_hs=1_n300 0.892950 0.897848 0.125843 0.050800 3.828052 0.125843 0.050800 3.828052 1 True 8
15 ExtraTrees_hs=1_n100 0.892657 0.895743 0.052110 0.021437 1.310609 0.052110 0.021437 1.310609 1 True 9
16 ExtraTrees_hs=0.1_n1000 0.891799 0.897914 0.385060 0.153704 16.367127 0.385060 0.153704 16.367127 1 True 37
17 ExtraTrees_og_n1000 0.891447 0.897816 0.372907 0.220728 2.279349 0.372907 0.220728 2.279349 1 True 1
18 ExtraTrees_hs=0.1_n300 0.890612 0.896532 0.136989 0.050140 4.911353 0.136989 0.050140 4.911353 1 True 38
19 ExtraTrees_hs=0.1_n100 0.890299 0.893999 0.050237 0.020473 1.644582 0.050237 0.020473 1.644582 1 True 39
20 ExtraTrees_hs=50_n10 0.890179 0.895940 0.011973 0.007939 0.142516 0.011973 0.007939 0.142516 1 True 24
21 ExtraTrees_og_n300 0.890174 0.895924 0.124822 0.050993 0.383026 0.124822 0.050993 0.383026 1 True 2
22 ExtraTrees_hs=50_n20 0.890151 0.893078 0.015931 0.009639 0.273224 0.015931 0.009639 0.273224 1 True 23
23 ExtraTrees_hs=500_n1000 0.890060 0.888834 0.361120 0.149288 12.790663 0.361120 0.149288 12.790663 1 True 31
24 ExtraTrees_hs=500_n300 0.889899 0.889657 0.188049 0.047488 3.835120 0.188049 0.047488 3.835120 1 True 32
25 ExtraTrees_og_n100 0.889513 0.893440 0.049970 0.024512 0.153703 0.049970 0.024512 0.153703 1 True 3
26 ExtraTrees_hs=10_n20 0.889493 0.894657 0.016829 0.010440 0.279750 0.016829 0.010440 0.279750 1 True 17
27 ExtraTrees_hs=500_n100 0.888305 0.888999 0.090708 0.020563 1.297165 0.090708 0.020563 1.297165 1 True 33
28 ExtraTrees_hs=1_n40 0.888270 0.885972 0.025700 0.013705 0.542178 0.025700 0.013705 0.542178 1 True 10
29 ExtraTrees_hs=100_n10 0.888002 0.893604 0.011640 0.009796 0.155926 0.011640 0.009796 0.155926 1 True 30
30 ExtraTrees_hs=10_n10 0.887781 0.893868 0.012549 0.007775 0.151623 0.012549 0.007775 0.151623 1 True 18
31 ExtraTrees_hs=500_n40 0.887740 0.889031 0.024130 0.012207 0.537571 0.024130 0.012207 0.537571 1 True 34
32 ExtraTrees_hs=100_n20 0.887497 0.890413 0.015986 0.009240 0.286911 0.015986 0.009240 0.286911 1 True 29
33 ExtraTrees_hs=0.1_n40 0.885989 0.883340 0.024764 0.011959 0.671531 0.024764 0.011959 0.671531 1 True 40
34 ExtraTrees_og_n40 0.883635 0.881021 0.025266 0.014620 0.075054 0.025266 0.014620 0.075054 1 True 4
35 ExtraTrees_hs=1_n20 0.880430 0.889196 0.017302 0.009496 0.323042 0.017302 0.009496 0.323042 1 True 11
36 ExtraTrees_hs=0.1_n20 0.878517 0.888044 0.016657 0.009894 0.348929 0.016657 0.009894 0.348929 1 True 41
37 ExtraTrees_hs=500_n20 0.876458 0.881596 0.017727 0.009557 0.295510 0.017727 0.009557 0.295510 1 True 35
38 ExtraTrees_hs=500_n10 0.876298 0.883208 0.013577 0.007943 0.155019 0.013577 0.007943 0.155019 1 True 36
39 ExtraTrees_hs=1_n10 0.875740 0.881333 0.012778 0.008113 0.143919 0.012778 0.008113 0.143919 1 True 12
40 ExtraTrees_hs=0.1_n10 0.874861 0.880445 0.012299 0.007819 0.180453 0.012299 0.007819 0.180453 1 True 42
41 ExtraTrees_og_n20 0.872678 0.882090 0.016654 0.013582 0.046497 0.016654 0.013582 0.046497 1 True 5
42 ExtraTrees_og_n10 0.861819 0.862235 0.012075 0.013637 0.047344 0.012075 0.013637 0.047344 1 True 6
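For readers unfamiliar with what reg_param controls in the grids above, here is a minimal pure-Python sketch of hierarchical shrinkage along a single root-to-leaf path, following the formula as I understand it from the HS paper: each step from a parent node to its child is shrunk by a factor of 1 / (1 + reg_param / N(parent)). This is not the imodels implementation. The function name `shrunk_prediction` and the toy node means and counts are hypothetical, chosen only for illustration. In a forest, the same adjustment would be applied to each tree independently.

```python
def shrunk_prediction(path_means, path_counts, reg_param):
    """Hierarchical shrinkage along one root-to-leaf path (illustrative sketch).

    path_means[l]  : mean training response in the l-th node on the path (root first)
    path_counts[l] : number of training samples in that node
    reg_param      : shrinkage strength (0 recovers the unshrunk leaf mean)
    """
    pred = path_means[0]  # start from the root mean
    for parent in range(len(path_means) - 1):
        # Change in mean when moving from a node to its child on the path
        delta = path_means[parent + 1] - path_means[parent]
        # Shrink that change more when the parent holds few samples
        pred += delta / (1.0 + reg_param / path_counts[parent])
    return pred


# Toy path (made-up numbers): root -> internal node -> small leaf
means = [0.25, 0.40, 0.90]
counts = [2000, 400, 10]

for lam in [0, 10, 50, 500]:
    print(lam, shrunk_prediction(means, counts, lam))
```

With reg_param=0 the function returns the plain leaf mean, and as reg_param grows the prediction is pulled back toward the root mean. Steps below small nodes are damped the most, which is one plausible reading of why a few shrunk trees can compete with many unshrunk ones in the tables above.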
Just fixed it in this commit :)
Will have to bump the version to 1.3.3 to get it to work, but the code you were using above should work now. Also added a little snippet to the doc here.
Cheers,
Chandan
Incredible. It works now and is showing very strong results on the Adult dataset (I plan to test more). For example, n_estimators=40 with HSTreeClassifierCV is getting significantly better test scores than RandomForestClassifier(n_estimators=300). The API is perfect, and I can add it to AutoGluon with <10 lines of code.
One question related to the ICML oral presentation / paper: I recall a mention of efficient leave-one-out CV to optimize reg_param. Is this implemented, or do I need to pay the cost of HSTreeClassifierCV to get the optimal reg_param value?