Comments (4)
Can you post the full reproducible script?
Thanks!
from tune-sklearn.
Code:
from catboost import CatBoostClassifier
from tune_sklearn import TuneSearchCV
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load breast cancer dataset
cancer = load_breast_cancer()
X = cancer.data
y = cancer.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

model = CatBoostClassifier()
param_dists = {
    "auto_class_weights": [None, 'Balanced', 'SqrtBalanced']
}
gs = TuneSearchCV(model, param_dists, n_iter=5, scoring="accuracy")
gs.fit(X_train, y_train)
print(gs.cv_results_)

pred = gs.predict(X_test)
correct = 0
for i in range(len(y_test)):
    if pred[i] == y_test[i]:
        correct += 1
print("Accuracy:", correct / len(pred))
Complete error log:
/usr/local/lib/python3.6/dist-packages/tune_sklearn/tune_basesearch.py:249: UserWarning: Early stopping is not enabled. To enable early stopping, pass in a supported scheduler from Tune and ensure the estimator has `partial_fit`.
warnings.warn("Early stopping is not enabled. "
Redis failed to start, retrying now.
/usr/local/lib/python3.6/dist-packages/tune_sklearn/tune_basesearch.py:382: UserWarning: Hiding process output by default. To show process output, set verbose=2.
warnings.warn("Hiding process output by default. "
The `start_trial` operation took 1.363807201385498 seconds to complete, which may be a performance bottleneck.
The dashboard on node 04f6570b4f5d failed with the following error:
Traceback (most recent call last):
File "/usr/lib/python3.6/asyncio/base_events.py", line 1062, in create_server
sock.bind(sa)
OSError: [Errno 99] Cannot assign requested address
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/ray/dashboard/dashboard.py", line 961, in <module>
dashboard.run()
File "/usr/local/lib/python3.6/dist-packages/ray/dashboard/dashboard.py", line 576, in run
aiohttp.web.run_app(self.app, host=self.host, port=self.port)
File "/usr/local/lib/python3.6/dist-packages/aiohttp/web.py", line 433, in run_app
reuse_port=reuse_port))
File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
return future.result()
File "/usr/local/lib/python3.6/dist-packages/aiohttp/web.py", line 359, in _run_app
await site.start()
File "/usr/local/lib/python3.6/dist-packages/aiohttp/web_runner.py", line 104, in start
reuse_port=self._reuse_port)
File "/usr/lib/python3.6/asyncio/base_events.py", line 1066, in create_server
% (sa, err.strerror.lower()))
OSError: [Errno 99] error while attempting to bind on address ('::1', 8265, 0, 0): cannot assign requested address
Trial _Trainable_822f1_00002: Error processing event.
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/ray/tune/trial_runner.py", line 468, in _process_trial
result = self.trial_executor.fetch_result(trial)
File "/usr/local/lib/python3.6/dist-packages/ray/tune/ray_trial_executor.py", line 430, in fetch_result
result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
File "/usr/local/lib/python3.6/dist-packages/ray/worker.py", line 1474, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::_Trainable.train() (pid=571, ip=172.28.0.2)
File "/usr/lib/python3.6/queue.py", line 161, in get
raise Empty
queue.Empty
During handling of the above exception, another exception occurred:
ray::_Trainable.train() (pid=571, ip=172.28.0.2)
File "python/ray/_raylet.pyx", line 442, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 446, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 400, in ray._raylet.execute_task.function_executor
File "/usr/local/lib/python3.6/dist-packages/ray/tune/trainable.py", line 261, in train
result = self._train()
File "/usr/local/lib/python3.6/dist-packages/tune_sklearn/_trainable.py", line 128, in _train
return_train_score=self.return_train_score,
File "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_validation.py", line 236, in cross_validate
for train, test in cv.split(X, y, groups))
File "/usr/local/lib/python3.6/dist-packages/joblib/parallel.py", line 1029, in __call__
if self.dispatch_one_batch(iterator):
File "/usr/local/lib/python3.6/dist-packages/joblib/parallel.py", line 819, in dispatch_one_batch
islice = list(itertools.islice(iterator, big_batch_size))
File "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_validation.py", line 236, in <genexpr>
for train, test in cv.split(X, y, groups))
File "/usr/local/lib/python3.6/dist-packages/sklearn/base.py", line 78, in clone
param2 = params_set[name]
KeyError: 'auto_class_weights'
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-9-62e3c8612684> in <module>()
19 gs = TuneSearchCV(model, param_dists, n_iter=5, scoring="accuracy")
20 gs.fit(X_train, y_train)
---> 21 print(gs.cv_results_)
22
23 pred = gs.predict(X_test)
AttributeError: 'TuneSearchCV' object has no attribute 'cv_results_'
I tried running this as well and looked into the CatBoost docs. I see the auto_class_weights parameter in the docs, but I don't see it listed as one of the parameters for CatBoostClassifier. Could it be that this parameter is not supposed to be set on CatBoostClassifier?
This may not be what you're trying to do, but just to illustrate that it might be an issue with setting the wrong parameters, this script worked for me:
from catboost import CatBoostClassifier
from tune_sklearn import TuneSearchCV
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load breast cancer dataset
cancer = load_breast_cancer()
X = cancer.data
y = cancer.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

model = CatBoostClassifier()
param_dists = {
    "class_weights": [[0.8, 0.2], [0.1, 0.9]]
}
gs = TuneSearchCV(model, param_dists, n_iter=5, scoring="accuracy")
gs.fit(X_train, y_train)
print(gs.cv_results_)

pred = gs.predict(X_test)
correct = 0
for i in range(len(y_test)):
    if pred[i] == y_test[i]:
        correct += 1
print("Accuracy:", correct / len(pred))
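
One way to check the "wrong parameter name" theory before starting a search is to compare the keys of the search space against the estimator's constructor signature. The sketch below is only an illustration: `unknown_params` is a hypothetical helper, and `ToyClassifier` is a made-up stand-in whose signature does not reflect CatBoost's real one. To try it on the real estimator, you would pass `CatBoostClassifier` instead, assuming its constructor lists its keyword arguments explicitly.

```python
import inspect


def unknown_params(estimator_cls, param_dists):
    """Return search-space keys that are not named constructor arguments.

    Note: if the constructor accepts **kwargs, a key reported here may
    still be accepted at runtime; this check only looks at named arguments.
    """
    signature = inspect.signature(estimator_cls.__init__)
    accepted = set(signature.parameters) - {"self"}
    return sorted(name for name in param_dists if name not in accepted)


class ToyClassifier:
    # Hypothetical stand-in for illustration; NOT CatBoost's real signature.
    def __init__(self, class_weights=None, depth=6):
        self.class_weights = class_weights
        self.depth = depth


param_dists = {
    "auto_class_weights": [None, "Balanced", "SqrtBalanced"],
    "class_weights": [[0.8, 0.2], [0.1, 0.9]],
}

missing = unknown_params(ToyClassifier, param_dists)
print("Not in constructor signature:", missing)
```

An empty list means every key in the search space matches a named constructor argument. A non-empty list does not prove the parameter is invalid, but it is a cheap hint to look at before digging into the Ray traceback.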
This issue only occurs with the scikit-optimize backend; it works with BOHB and Hyperopt, so it can be closed.