Comments (4)

richardliaw commented on July 25, 2024

Can you post the full reproducible script?

Thanks!

from tune-sklearn.

rohan-gt commented on July 25, 2024

Code:

from catboost import CatBoostClassifier
from tune_sklearn import TuneSearchCV
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load breast cancer dataset
cancer = load_breast_cancer()
X = cancer.data
y = cancer.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

model = CatBoostClassifier()
param_dists = {
    "auto_class_weights": [None, 'Balanced', 'SqrtBalanced']
}

gs = TuneSearchCV(model, param_dists, n_iter=5, scoring="accuracy")
gs.fit(X_train, y_train)
print(gs.cv_results_)

pred = gs.predict(X_test)
correct = 0
for i in range(len(y_test)):
    if pred[i] == y_test[i]:
        correct += 1
print("Accuracy:", correct / len(pred))

Complete error log:

/usr/local/lib/python3.6/dist-packages/tune_sklearn/tune_basesearch.py:249: UserWarning: Early stopping is not enabled. To enable early stopping, pass in a supported scheduler from Tune and ensure the estimator has `partial_fit`.
  warnings.warn("Early stopping is not enabled. "
Redis failed to start, retrying now.
/usr/local/lib/python3.6/dist-packages/tune_sklearn/tune_basesearch.py:382: UserWarning: Hiding process output by default. To show process output, set verbose=2.
  warnings.warn("Hiding process output by default. "
The `start_trial` operation took 1.363807201385498 seconds to complete, which may be a performance bottleneck.
The dashboard on node 04f6570b4f5d failed with the following error:
Traceback (most recent call last):
  File "/usr/lib/python3.6/asyncio/base_events.py", line 1062, in create_server
    sock.bind(sa)
OSError: [Errno 99] Cannot assign requested address

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/ray/dashboard/dashboard.py", line 961, in <module>
    dashboard.run()
  File "/usr/local/lib/python3.6/dist-packages/ray/dashboard/dashboard.py", line 576, in run
    aiohttp.web.run_app(self.app, host=self.host, port=self.port)
  File "/usr/local/lib/python3.6/dist-packages/aiohttp/web.py", line 433, in run_app
    reuse_port=reuse_port))
  File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.6/dist-packages/aiohttp/web.py", line 359, in _run_app
    await site.start()
  File "/usr/local/lib/python3.6/dist-packages/aiohttp/web_runner.py", line 104, in start
    reuse_port=self._reuse_port)
  File "/usr/lib/python3.6/asyncio/base_events.py", line 1066, in create_server
    % (sa, err.strerror.lower()))
OSError: [Errno 99] error while attempting to bind on address ('::1', 8265, 0, 0): cannot assign requested address

Trial _Trainable_822f1_00002: Error processing event.
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/ray/tune/trial_runner.py", line 468, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/usr/local/lib/python3.6/dist-packages/ray/tune/ray_trial_executor.py", line 430, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/usr/local/lib/python3.6/dist-packages/ray/worker.py", line 1474, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::_Trainable.train() (pid=571, ip=172.28.0.2)
  File "/usr/lib/python3.6/queue.py", line 161, in get
    raise Empty
queue.Empty

During handling of the above exception, another exception occurred:

ray::_Trainable.train() (pid=571, ip=172.28.0.2)
  File "python/ray/_raylet.pyx", line 442, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 446, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 400, in ray._raylet.execute_task.function_executor
  File "/usr/local/lib/python3.6/dist-packages/ray/tune/trainable.py", line 261, in train
    result = self._train()
  File "/usr/local/lib/python3.6/dist-packages/tune_sklearn/_trainable.py", line 128, in _train
    return_train_score=self.return_train_score,
  File "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_validation.py", line 236, in cross_validate
    for train, test in cv.split(X, y, groups))
  File "/usr/local/lib/python3.6/dist-packages/joblib/parallel.py", line 1029, in __call__
    if self.dispatch_one_batch(iterator):
  File "/usr/local/lib/python3.6/dist-packages/joblib/parallel.py", line 819, in dispatch_one_batch
    islice = list(itertools.islice(iterator, big_batch_size))
  File "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_validation.py", line 236, in <genexpr>
    for train, test in cv.split(X, y, groups))
  File "/usr/local/lib/python3.6/dist-packages/sklearn/base.py", line 78, in clone
    param2 = params_set[name]
KeyError: 'auto_class_weights'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-9-62e3c8612684> in <module>()
     19 gs = TuneSearchCV(model, param_dists, n_iter=5, scoring="accuracy")
     20 gs.fit(X_train, y_train)
---> 21 print(gs.cv_results_)
     22 
     23 pred = gs.predict(X_test)

AttributeError: 'TuneSearchCV' object has no attribute 'cv_results_'

inventormc commented on July 25, 2024

I tried running this as well and looked into the CatBoost docs. I see the auto_class_weights parameter mentioned in one place, but I don't see it listed among the parameters for CatBoostClassifier. Could it be that this parameter is not supposed to be set on CatBoostClassifier?

This may not be what you're trying to do, but just to illustrate that it might be an issue with setting the wrong parameters, this script worked for me:

from catboost import CatBoostClassifier
from tune_sklearn import TuneSearchCV
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load breast cancer dataset
cancer = load_breast_cancer()
X = cancer.data
y = cancer.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

model = CatBoostClassifier()
param_dists = {
    "class_weights": [[0.8, 0.2], [0.1, 0.9]]
}

gs = TuneSearchCV(model, param_dists, n_iter=5, scoring="accuracy")
gs.fit(X_train, y_train)
print(gs.cv_results_)

pred = gs.predict(X_test)
correct = 0
for i in range(len(y_test)):
    if pred[i] == y_test[i]:
        correct += 1
print("Accuracy:", correct / len(pred))
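
For what it's worth, the KeyError in the log above is raised inside sklearn's clone(), on the line that reads each parameter back from the freshly built copy (param2 = params_set[name]). Below is a minimal sketch of how that lookup can fail. ToyEstimator, clone_like_sklearn and survives_clone are hypothetical names made up for illustration, and the "drop None values in the constructor" behaviour is an assumption, not something confirmed about CatBoost:

```python
class ToyEstimator:
    """Hypothetical stand-in (not CatBoost) that forgets None-valued params."""

    def __init__(self, **params):
        # Assumption for illustration: parameters passed as None are not stored.
        self._params = {k: v for k, v in params.items() if v is not None}

    def get_params(self, deep=False):
        return dict(self._params)

    def set_params(self, **params):
        # set_params stores values verbatim, including None.
        self._params.update(params)
        return self


def clone_like_sklearn(estimator):
    """Simplified imitation of the round-trip check in sklearn.base.clone."""
    requested = estimator.get_params(deep=False)
    new_estimator = type(estimator)(**requested)
    reported = new_estimator.get_params(deep=False)
    for name in requested:
        # Raises KeyError if the rebuilt estimator no longer reports `name`.
        if reported[name] is not requested[name]:
            raise RuntimeError(f"constructor altered parameter {name!r}")
    return new_estimator


def survives_clone(**params):
    """Return True if a ToyEstimator with these params can be cloned."""
    try:
        clone_like_sklearn(ToyEstimator().set_params(**params))
        return True
    except KeyError:
        return False


# Try the three candidate values from the original search space.
for value in [None, "Balanced", "SqrtBalanced"]:
    print(value, survives_clone(auto_class_weights=value))
```

In this toy setup only the None candidate fails the round trip. If the real estimator behaves similarly, that would fit the traceback pointing at 'auto_class_weights', but I haven't verified that against CatBoost itself.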

rohan-gt commented on July 25, 2024

This issue only occurs with the scikit-optimize backend. It works with BOHB and Hyperopt, so this can be closed.
