Speeding up HS with LOOCV (imodels issue, closed)

csinva commented on August 27, 2024
Speeding up HS with LOOCV

from imodels.

Comments (5)

csinva commented on August 27, 2024

Glad to hear it's working for you 😄

This isn't currently implemented but should be easy to do - we'll get it done soon, especially if you're putting this into AutoGluon!

@aagarwal1996 @yanshuotan @OmerRonen - someone wanna take this on?

csinva commented on August 27, 2024

Thanks for the question Nick! Indeed we should have publicly released this with some documentation by now - will get to it shortly!

@aagarwal1996 - can you add in the HS + RF / ExtraTrees code with a little doc?

Best,
Chandan

Innixma commented on August 27, 2024

Here is a hyperparameter grid result on Adult, sampled to 2000 rows of training data, sorted by test score (AUC):

og = vanilla sklearn
hs=10_n300 = reg_param=10, n_estimators=300

Amazingly, HS with reg_param 50, 100, or 500 and only 10 n_estimators gets a better test score than OG with 1000 n_estimators!

                        model  score_test  score_val  pred_time_test  pred_time_val   fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0     RandomForest_hs=50_n300    0.906977   0.904691        0.095684       0.050235   2.733138                 0.095684                0.050235           2.733138            1       True         20
1     RandomForest_hs=50_n100    0.906954   0.905086        0.037076       0.087729   0.922209                 0.037076                0.087729           0.922209            1       True         21
2    RandomForest_hs=50_n1000    0.906939   0.905020        0.287926       0.150001   9.144297                 0.287926                0.150001           9.144297            1       True         19
3    RandomForest_hs=100_n300    0.906854   0.905185        0.098557       0.050372   2.767945                 0.098557                0.050372           2.767945            1       True         26
4   RandomForest_hs=100_n1000    0.906818   0.905020        0.297747       0.155764   9.122136                 0.297747                0.155764           9.122136            1       True         25
5    RandomForest_hs=100_n100    0.906746   0.904823        0.040018       0.022333   0.951915                 0.040018                0.022333           0.951915            1       True         27
6         WeightedEnsemble_L2    0.906464   0.906106        0.822571       0.493818  25.186706                 0.003169                0.000933           0.448992            2       True         43
7      RandomForest_hs=50_n40    0.905625   0.903869        0.019060       0.012893   0.386387                 0.019060                0.012893           0.386387            1       True         22
8     RandomForest_hs=100_n40    0.905619   0.904691        0.019219       0.015171   0.412162                 0.019219                0.015171           0.412162            1       True         28
9     RandomForest_hs=10_n100    0.904859   0.904955        0.038016       0.021933   0.915743                 0.038016                0.021933           0.915743            1       True         15
10    RandomForest_hs=10_n300    0.904659   0.905481        0.098096       0.049020   2.781127                 0.098096                0.049020           2.781127            1       True         14
11   RandomForest_hs=10_n1000    0.904594   0.904659        0.296132       0.150198   9.204939                 0.296132                0.150198           9.204939            1       True         13
12    RandomForest_hs=100_n20    0.904273   0.901533        0.013146       0.013468   0.245283                 0.013146                0.013468           0.245283            1       True         29
13     RandomForest_hs=50_n20    0.903543   0.901336        0.012793       0.009559   0.201715                 0.012793                0.009559           0.201715            1       True         23
14  RandomForest_hs=500_n1000    0.903443   0.903310        0.292724       0.153926   9.449500                 0.292724                0.153926           9.449500            1       True         31
15   RandomForest_hs=500_n300    0.903273   0.904033        0.098899       0.058785   2.789253                 0.098899                0.058785           2.789253            1       True         32
16     RandomForest_hs=10_n40    0.902558   0.899888        0.020626       0.015472   0.387508                 0.020626                0.015472           0.387508            1       True         16
17   RandomForest_hs=500_n100    0.902558   0.904066        0.055166       0.022750   0.920178                 0.055166                0.022750           0.920178            1       True         33
18    RandomForest_hs=500_n40    0.901187   0.901928        0.038185       0.013044   0.375549                 0.038185                0.013044           0.375549            1       True         34
19    RandomForest_hs=500_n20    0.901066   0.899888        0.019732       0.012050   0.197903                 0.019732                0.012050           0.197903            1       True         35
20    RandomForest_hs=100_n10    0.901057   0.897684        0.010498       0.008219   0.132994                 0.010498                0.008219           0.132994            1       True         30
21     RandomForest_hs=50_n10    0.900010   0.896697        0.010349       0.007744   0.106932                 0.010349                0.007744           0.106932            1       True         24
22     RandomForest_hs=1_n100    0.899934   0.903639        0.038109       0.021423   0.925629                 0.038109                0.021423           0.925629            1       True          9
23    RandomForest_hs=1_n1000    0.899916   0.904823        0.292249       0.153203   9.213101                 0.292249                0.153203           9.213101            1       True          7
24    RandomForest_hs=500_n10    0.899908   0.899296        0.011997       0.008104   0.118062                 0.011997                0.008104           0.118062            1       True         36
25     RandomForest_hs=1_n300    0.899902   0.903704        0.097325       0.054082   2.725383                 0.097325                0.054082           2.725383            1       True          8
26     RandomForest_hs=10_n20    0.898988   0.899197        0.017777       0.009672   0.217858                 0.017777                0.009672           0.217858            1       True         17
27   RandomForest_hs=0.1_n100    0.898405   0.901961        0.038886       0.019792   1.188789                 0.038886                0.019792           1.188789            1       True         39
28  RandomForest_hs=0.1_n1000    0.898272   0.903869        0.300892       0.151734  11.785147                 0.300892                0.151734          11.785147            1       True         37
29   RandomForest_hs=0.1_n300    0.898254   0.903211        0.102140       0.059019   3.490166                 0.102140                0.059019           3.490166            1       True         38
30      RandomForest_og_n1000    0.898019   0.903721        0.300054       0.163219   2.121544                 0.300054                0.163219           2.121544            1       True          1
31       RandomForest_og_n300    0.897952   0.903145        0.098421       0.053610   0.375765                 0.098421                0.053610           0.375765            1       True          2
32       RandomForest_og_n100    0.897658   0.901533        0.044796       0.021702   0.158356                 0.044796                0.021702           0.158356            1       True          3
33      RandomForest_hs=1_n40    0.895661   0.894657        0.019914       0.014062   0.378873                 0.019914                0.014062           0.378873            1       True         10
34     RandomForest_hs=10_n10    0.894480   0.893670        0.011885       0.008050   0.107222                 0.011885                0.008050           0.107222            1       True         18
35    RandomForest_hs=0.1_n40    0.894129   0.893604        0.018879       0.011820   0.469591                 0.018879                0.011820           0.469591            1       True         40
36        RandomForest_og_n40    0.891790   0.891910        0.018995       0.014379   0.073364                 0.018995                0.014379           0.073364            1       True          4
37      RandomForest_hs=1_n20    0.890021   0.896532        0.013560       0.010845   0.239654                 0.013560                0.010845           0.239654            1       True         11
38    RandomForest_hs=0.1_n20    0.888886   0.896105        0.013955       0.010772   0.243502                 0.013955                0.010772           0.243502            1       True         41
39        RandomForest_og_n20    0.883450   0.890907        0.013461       0.010590   0.045405                 0.013461                0.010590           0.045405            1       True          5
40      RandomForest_hs=1_n10    0.883443   0.888012        0.010171       0.008007   0.111472                 0.010171                0.008007           0.111472            1       True         12
41    RandomForest_hs=0.1_n10    0.882777   0.887584        0.009979       0.007687   0.138142                 0.009979                0.007687           0.138142            1       True         42
42        RandomForest_og_n10    0.869862   0.874079        0.010571       0.008906   0.030527                 0.010571                0.008906           0.030527            1       True          6
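
The headline comparison can be checked directly against the table. The snippet below is a small sketch that copies four test AUC values from the leaderboard above and computes each 10-tree HS model's margin over the 1000-tree vanilla forest. The helper name parse_name is hypothetical, and it assumes the naming pattern given in the legend (RandomForest_hs=<reg_param>_n<n_estimators>, or RandomForest_og_n<n_estimators>).

```python
# Test AUC values copied verbatim from the RandomForest leaderboard above.
test_auc = {
    "RandomForest_hs=100_n10": 0.901057,
    "RandomForest_hs=50_n10": 0.900010,
    "RandomForest_hs=500_n10": 0.899908,
    "RandomForest_og_n1000": 0.898019,
}


def parse_name(name):
    """Split a leaderboard name into (reg_param or None, n_estimators).

    Hypothetical helper: assumes the 'Model_hs=<reg>_n<trees>' or
    'Model_og_n<trees>' naming used in this thread.
    """
    _, setting, n_part = name.split("_")
    reg_param = None if setting == "og" else float(setting.split("=")[1])
    return reg_param, int(n_part[1:])


baseline = test_auc["RandomForest_og_n1000"]

# Margin of each HS model over the vanilla 1000-tree forest,
# keyed by (reg_param, n_estimators).
gains = {}
for name, auc in test_auc.items():
    reg_param, n_estimators = parse_name(name)
    if reg_param is not None:
        gains[(reg_param, n_estimators)] = auc - baseline

for key, gain in sorted(gains.items()):
    print(key, f"{gain:+.6f}")
```

All three margins are positive but small (roughly 0.002 to 0.003 AUC) and come from a single 2000-row sample, so they are best read as suggestive rather than conclusive.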

Similar findings with Extra Trees!

                      model  score_test  score_val  pred_time_test  pred_time_val   fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0    ExtraTrees_hs=50_n1000    0.898862   0.895578        0.360272       0.152184  12.766335                 0.360272                0.152184          12.766335            1       True         19
1    ExtraTrees_hs=10_n1000    0.898727   0.898638        0.375004       0.154642  13.135831                 0.375004                0.154642          13.135831            1       True         13
2     ExtraTrees_hs=50_n300    0.898286   0.896006        0.126808       0.057546   3.828264                 0.126808                0.057546           3.828264            1       True         20
3     ExtraTrees_hs=50_n100    0.898091   0.895545        0.050484       0.020662   1.293077                 0.050484                0.020662           1.293077            1       True         21
4     ExtraTrees_hs=10_n300    0.897934   0.898802        0.129800       0.050548   3.923364                 0.129800                0.050548           3.923364            1       True         14
5     ExtraTrees_hs=10_n100    0.897840   0.897848        0.050855       0.021138   1.308959                 0.050855                0.021138           1.308959            1       True         15
6   ExtraTrees_hs=100_n1000    0.897263   0.894295        0.363426       0.148167  12.866699                 0.363426                0.148167          12.866699            1       True         25
7    ExtraTrees_hs=100_n300    0.896848   0.894756        0.129441       0.049175   3.806397                 0.129441                0.049175           3.806397            1       True         26
8      ExtraTrees_hs=50_n40    0.896566   0.893868        0.024298       0.012653   0.611947                 0.024298                0.012653           0.611947            1       True         22
9       WeightedEnsemble_L2    0.896430   0.900053        1.370685       0.559450  50.761817                 0.003386                0.000914           0.446805            2       True         43
10   ExtraTrees_hs=100_n100    0.896420   0.893868        0.046713       0.026253   1.349054                 0.046713                0.026253           1.349054            1       True         27
11     ExtraTrees_hs=10_n40    0.895319   0.892486        0.025234       0.012230   0.530321                 0.025234                0.012230           0.530321            1       True         16
12    ExtraTrees_hs=100_n40    0.895164   0.893341        0.022953       0.013198   0.523787                 0.022953                0.013198           0.523787            1       True         28
13    ExtraTrees_hs=1_n1000    0.894024   0.898441        0.351593       0.148842  13.060638                 0.351593                0.148842          13.060638            1       True          7
14     ExtraTrees_hs=1_n300    0.892950   0.897848        0.125843       0.050800   3.828052                 0.125843                0.050800           3.828052            1       True          8
15     ExtraTrees_hs=1_n100    0.892657   0.895743        0.052110       0.021437   1.310609                 0.052110                0.021437           1.310609            1       True          9
16  ExtraTrees_hs=0.1_n1000    0.891799   0.897914        0.385060       0.153704  16.367127                 0.385060                0.153704          16.367127            1       True         37
17      ExtraTrees_og_n1000    0.891447   0.897816        0.372907       0.220728   2.279349                 0.372907                0.220728           2.279349            1       True          1
18   ExtraTrees_hs=0.1_n300    0.890612   0.896532        0.136989       0.050140   4.911353                 0.136989                0.050140           4.911353            1       True         38
19   ExtraTrees_hs=0.1_n100    0.890299   0.893999        0.050237       0.020473   1.644582                 0.050237                0.020473           1.644582            1       True         39
20     ExtraTrees_hs=50_n10    0.890179   0.895940        0.011973       0.007939   0.142516                 0.011973                0.007939           0.142516            1       True         24
21       ExtraTrees_og_n300    0.890174   0.895924        0.124822       0.050993   0.383026                 0.124822                0.050993           0.383026            1       True          2
22     ExtraTrees_hs=50_n20    0.890151   0.893078        0.015931       0.009639   0.273224                 0.015931                0.009639           0.273224            1       True         23
23  ExtraTrees_hs=500_n1000    0.890060   0.888834        0.361120       0.149288  12.790663                 0.361120                0.149288          12.790663            1       True         31
24   ExtraTrees_hs=500_n300    0.889899   0.889657        0.188049       0.047488   3.835120                 0.188049                0.047488           3.835120            1       True         32
25       ExtraTrees_og_n100    0.889513   0.893440        0.049970       0.024512   0.153703                 0.049970                0.024512           0.153703            1       True          3
26     ExtraTrees_hs=10_n20    0.889493   0.894657        0.016829       0.010440   0.279750                 0.016829                0.010440           0.279750            1       True         17
27   ExtraTrees_hs=500_n100    0.888305   0.888999        0.090708       0.020563   1.297165                 0.090708                0.020563           1.297165            1       True         33
28      ExtraTrees_hs=1_n40    0.888270   0.885972        0.025700       0.013705   0.542178                 0.025700                0.013705           0.542178            1       True         10
29    ExtraTrees_hs=100_n10    0.888002   0.893604        0.011640       0.009796   0.155926                 0.011640                0.009796           0.155926            1       True         30
30     ExtraTrees_hs=10_n10    0.887781   0.893868        0.012549       0.007775   0.151623                 0.012549                0.007775           0.151623            1       True         18
31    ExtraTrees_hs=500_n40    0.887740   0.889031        0.024130       0.012207   0.537571                 0.024130                0.012207           0.537571            1       True         34
32    ExtraTrees_hs=100_n20    0.887497   0.890413        0.015986       0.009240   0.286911                 0.015986                0.009240           0.286911            1       True         29
33    ExtraTrees_hs=0.1_n40    0.885989   0.883340        0.024764       0.011959   0.671531                 0.024764                0.011959           0.671531            1       True         40
34        ExtraTrees_og_n40    0.883635   0.881021        0.025266       0.014620   0.075054                 0.025266                0.014620           0.075054            1       True          4
35      ExtraTrees_hs=1_n20    0.880430   0.889196        0.017302       0.009496   0.323042                 0.017302                0.009496           0.323042            1       True         11
36    ExtraTrees_hs=0.1_n20    0.878517   0.888044        0.016657       0.009894   0.348929                 0.016657                0.009894           0.348929            1       True         41
37    ExtraTrees_hs=500_n20    0.876458   0.881596        0.017727       0.009557   0.295510                 0.017727                0.009557           0.295510            1       True         35
38    ExtraTrees_hs=500_n10    0.876298   0.883208        0.013577       0.007943   0.155019                 0.013577                0.007943           0.155019            1       True         36
39      ExtraTrees_hs=1_n10    0.875740   0.881333        0.012778       0.008113   0.143919                 0.012778                0.008113           0.143919            1       True         12
40    ExtraTrees_hs=0.1_n10    0.874861   0.880445        0.012299       0.007819   0.180453                 0.012299                0.007819           0.180453            1       True         42
41        ExtraTrees_og_n20    0.872678   0.882090        0.016654       0.013582   0.046497                 0.016654                0.013582           0.046497            1       True          5
42        ExtraTrees_og_n10    0.861819   0.862235        0.012075       0.013637   0.047344                 0.012075                0.013637           0.047344            1       True          6

csinva commented on August 27, 2024

Just fixed it in this commit :)

You will have to bump the version to 1.3.3 to get it to work, but the code you were using above should work now. I also added a little snippet to the doc here.

Cheers,
Chandan

Innixma commented on August 27, 2024

Incredible. It works now and is showing very strong results on the Adult dataset (I plan to test more). For example, n_estimators=40 with HSTreeClassifierCV is getting significantly better test scores than RandomForestClassifier(n_estimators=300). The API is perfect, and I can add it to AutoGluon with <10 lines of code.

A question related to the ICML oral presentation and paper: I recall a mention of efficient leave-one-out CV for optimizing reg_param. Is this implemented, or do I need to pay the cost of HSTreeClassifierCV to get the optimal reg_param value?
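
For context on what reg_param controls: as described in the hierarchical shrinkage paper, HS keeps the root mean and divides each subsequent parent-to-child change along a root-to-leaf path by 1 + reg_param / (number of training samples in the parent). The following is a minimal pure-Python sketch of that adjustment for a single path. It is an illustration with made-up numbers, not the imodels implementation, and shrink_path is a hypothetical name.

```python
def shrink_path(node_means, node_counts, reg_param):
    """Return the HS-adjusted prediction for one root-to-leaf path.

    node_means[i]  : mean response of the training samples in the i-th
                     node on the path (index 0 is the root, last is the leaf)
    node_counts[i] : number of training samples in that node
                     (only parent counts are used, so the leaf count is ignored)
    reg_param      : shrinkage strength; 0 leaves the leaf value unchanged
    """
    prediction = node_means[0]  # the root mean is not shrunk
    for parent in range(len(node_means) - 1):
        step = node_means[parent + 1] - node_means[parent]
        prediction += step / (1.0 + reg_param / node_counts[parent])
    return prediction


# Toy path: root -> internal node -> leaf (values are invented).
path_means = [0.5, 0.7, 1.0]
path_counts = [100, 40, 10]

unshrunk = shrink_path(path_means, path_counts, reg_param=0)
shrunk = shrink_path(path_means, path_counts, reg_param=100)
print(unshrunk, shrunk)
```

With reg_param=0 the original leaf value is recovered, while larger values pull the prediction back toward the root mean, most strongly for steps taken from small nodes. That sensitivity is why reg_param is worth tuning, and why the cost of tuning it by cross-validation is the concern raised above.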
