ygzwqzd / lamda-ssl

30 Semi-Supervised Learning Algorithms

License: MIT License

Python 100.00%
machine-learning semi-supervised-learning deep-learning pytorch scikit-learn toolkit python computer-vision natural-language-processing generative-model

lamda-ssl's Introduction

Introduction

To promote the research and application of semi-supervised learning (SSL) algorithms, we have developed LAMDA-SSL, a comprehensive and easy-to-use SSL toolkit for Python. LAMDA-SSL has powerful functions, simple interfaces, and extensive documentation, and it integrates statistical SSL algorithms and deep SSL algorithms into the same framework. At present, LAMDA-SSL contains 30 semi-supervised learning algorithms: 12 based on statistical machine learning models and 18 based on deep learning models. It also contains 45 data processing methods for 4 types of data (tabular, image, text, and graph) and 15 model evaluation criteria for 3 types of tasks (classification, regression, and clustering). It is compatible with the popular machine learning toolkit scikit-learn and the popular deep learning toolkit PyTorch: it supports the Pipeline mechanism and parameter search like scikit-learn, and GPU acceleration and distributed training like PyTorch. LAMDA-SSL includes multiple modules, such as data management, data transformation, model application, and model deployment, which together facilitate end-to-end SSL.

At present, LAMDA-SSL has implemented 30 SSL algorithms, including 12 statistical SSL algorithms and 18 deep SSL algorithms.

For statistical SSL, the algorithms in LAMDA-SSL can be used for classification, regression, and clustering. The classification algorithms include the generative method SSGMM; the semi-supervised support vector machine methods TSVM and LapSVM; the graph-based methods Label Propagation and Label Spreading; the disagreement-based methods Co-Training and Tri-Training; and the ensemble methods SemiBoost and Assemble. The regression algorithm is CoReg. The clustering algorithms are Constrained K Means and Constrained Seed K Means.

For deep SSL, the algorithms in LAMDA-SSL can be used for classification and regression. The classification algorithms include the consistency methods Ladder Network, Π Model, Temporal Ensembling, Mean Teacher, VAT, and UDA; the pseudo-label-based methods Pseudo Label and S4L; the hybrid methods ICT, MixMatch, ReMixMatch, FixMatch, and FlexMatch; the deep generative methods ImprovedGAN and SSVAE; and the deep graph-based methods SDNE, GCN, and GAT. The regression algorithms include the consistency methods Π Model Reg and Mean Teacher Reg and the hybrid method ICT Reg. These 3 deep SSL regression algorithms are our extensions of their prototypes used for classification.

Superiority

  • LAMDA-SSL contains 30 SSL algorithms.
  • LAMDA-SSL can handle 4 types of data and provides 45 functions for data processing and data augmentation.
  • LAMDA-SSL can handle 3 types of tasks and supports 16 metrics for model evaluation.
  • LAMDA-SSL supports 5 hyperparameter search methods including random search, grid search, Bayesian optimization, evolution strategy, and meta-learner.
  • LAMDA-SSL supports both statistical SSL algorithms and deep SSL algorithms and uses a unified implementation framework.
  • LAMDA-SSL is compatible with the popular machine learning toolkit scikit-learn and the popular deep learning toolkit PyTorch.
  • LAMDA-SSL has simple interfaces similar to scikit-learn, so it is easy to use.
  • LAMDA-SSL has powerful functions. It supports the Pipeline mechanism and parameter search like scikit-learn, and GPU acceleration and distributed training like PyTorch.
  • LAMDA-SSL considers the needs of different user groups. It provides well-tuned default parameters and modules for entry-level users, and supports flexible module replacement and customization for professional users.
  • LAMDA-SSL has strong extensibility, which is convenient for users to customize new modules and algorithms.
  • LAMDA-SSL has been verified by a large number of experiments and has strong reliability.
  • LAMDA-SSL has comprehensive user documentation.

Dependencies

LAMDA-SSL requires:

  • python (>= 3.7)
  • scikit-learn (>= 1.0.2)
  • torch (>= 1.9.0)
  • torchvision (>= 0.11.2)
  • torchtext (>= 0.11.1)
  • torch-geometric (>= 2.0.3)
  • Pillow (>= 8.4.0)
  • numpy (>= 1.19.2)
  • scipy (>= 1.5.2)
  • pandas (>= 1.3.4)
  • matplotlib (>= 3.5.0)

You can create the environment directly with Anaconda.

conda env create -f environment.yaml
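As a quick sanity check after creating the environment, the minimum versions listed above can be compared against what is actually installed. The sketch below is only an illustration: the helper names `version_tuple` and `check` are hypothetical (they are not part of LAMDA-SSL), and the version parsing is deliberately simple, so it only handles plain numeric versions such as `2.0.1` or `2.0.1+cu118`.

```python
from importlib import metadata

# Minimum versions copied from the dependency list above.
MINIMUMS = {
    "scikit-learn": "1.0.2",
    "torch": "1.9.0",
    "torchvision": "0.11.2",
    "torchtext": "0.11.1",
    "torch-geometric": "2.0.3",
    "Pillow": "8.4.0",
    "numpy": "1.19.2",
    "scipy": "1.5.2",
    "pandas": "1.3.4",
    "matplotlib": "3.5.0",
}


def version_tuple(text):
    """Turn '2.0.1+cu118' into (2, 0, 1); non-numeric suffixes are ignored."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:  # pad so that '1.9' compares equal to '1.9.0'
        parts.append(0)
    return tuple(parts)


def check(minimums, lookup=metadata.version):
    """Return {package: (installed_or_None, required, ok)} for each entry."""
    report = {}
    for name, required in minimums.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None
        ok = installed is not None and version_tuple(installed) >= version_tuple(required)
        report[name] = (installed, required, ok)
    return report
```

Calling `check(MINIMUMS)` returns one entry per package. Note that this only verifies lower bounds; it says nothing about whether a much newer release is still compatible.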

Installation

Install from pip

You can install LAMDA-SSL directly from PyPI.

pip install LAMDA-SSL

Install from anaconda

You can also install LAMDA-SSL directly from Anaconda.

conda install -c ygzwqzd LAMDA-SSL

Install from the source

If you want to try the latest features that have not been released yet, you can install LAMDA-SSL from the source.

git clone https://github.com/ygzwqzd/LAMDA-SSL.git
cd LAMDA-SSL
pip install .

Quick Start

For example, train a FixMatch classifier for CIFAR10.

Firstly, import and initialize CIFAR10.

from LAMDA_SSL.Dataset.Vision.CIFAR10 import CIFAR10

# Forward slashes avoid the invalid '\D' and '\C' escape sequences and work on every platform.
dataset = CIFAR10(root='../Download/CIFAR10',
                  labeled_size=4000, download=True)
labeled_X, labeled_y = dataset.labeled_X, dataset.labeled_y
unlabeled_X = dataset.unlabeled_X
test_X, test_y = dataset.test_X, dataset.test_y

Then import and initialize FixMatch.

from LAMDA_SSL.Algorithm.Classification.FixMatch import FixMatch
model=FixMatch(threshold=0.95,lambda_u=1.0,mu=7,T=0.5,epoch=1,num_it_epoch=2**20,device='cuda:0')

Next, call the fit() method to complete the training process of the model.

model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)

Finally, call the predict() method to predict the labels of new samples.

pred_y=model.predict(X=test_X)
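To give some intuition for the `threshold=0.95` argument used above: in the FixMatch method, a pseudo label predicted on a weakly augmented view is only used for training when the model's highest class probability reaches a confidence threshold, and the resulting loss on the strongly augmented view is weighted by a coefficient (`lambda_u` in the call above). The following is a minimal plain-Python sketch of that masking idea. It is not LAMDA-SSL's implementation, and the function names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_labels(weak_logits, threshold=0.95):
    """Return (label, keep) per sample; keep is True when confidence >= threshold."""
    out = []
    for logits in weak_logits:
        probs = softmax(logits)
        confidence = max(probs)
        out.append((probs.index(confidence), confidence >= threshold))
    return out


def unsupervised_loss(strong_logits, labels_and_masks):
    """Cross-entropy on strongly augmented views, counting only the kept samples.

    The sum is divided by the full unlabeled batch size, so discarded
    samples contribute zero rather than being dropped from the average.
    """
    total = 0.0
    for logits, (label, keep) in zip(strong_logits, labels_and_masks):
        if keep:
            total += -math.log(softmax(logits)[label])
    return total / len(strong_logits)


# Toy batch: the first prediction is confident, the second is not.
weak = [[8.0, 0.0, 0.0], [1.0, 0.8, 0.5]]
strong = [[8.0, 0.0, 0.0], [0.0, 0.0, 5.0]]
masks = pseudo_labels(weak, threshold=0.95)
loss_u = unsupervised_loss(strong, masks)
```

In this toy batch only the first sample passes the threshold, so the second sample's disagreement between views does not affect `loss_u`.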

Performance

We evaluated LAMDA-SSL on a semi-supervised classification task on tabular data using the BreastCancer dataset. In this experiment, 30% of the instances are randomly sampled, following the class distribution, to form the test set. Then 10% of the remaining instances are randomly sampled to form the labeled training set, and the rest form the unlabeled training set with their labels dropped. For detailed parameter settings of each method, please refer to the 'Config' module of LAMDA-SSL.

| Method | Accuracy (%) | F1 Score |
| --- | --- | --- |
| SSGMM | 87.13 | 86.85 |
| TSVM | 95.91 | 95.56 |
| LapSVM | 96.49 | 96.20 |
| Label Propagation | 95.32 | 94.86 |
| Label Spreading | 95.32 | 94.90 |
| Co-Training | 94.74 | 94.20 |
| Tri-Training | 97.66 | 97.47 |
| Assemble | 94.15 | 93.75 |
| SemiBoost | 96.49 | 96.20 |
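The splitting protocol used for the BreastCancer experiment (30% stratified test set, then 10% of the remainder as labeled data) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the code used to produce the numbers above: the helper names are hypothetical, and the labeled subset is also drawn per class here, which the description does not state explicitly.

```python
import random
from collections import defaultdict


def stratified_split(labels, fraction, rng):
    """Pick about `fraction` of the indices from every class; return (picked, rest)."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    picked, rest = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        cut = round(len(indices) * fraction)
        picked.extend(indices[:cut])
        rest.extend(indices[cut:])
    return sorted(picked), sorted(rest)


def ssl_split(labels, test_fraction=0.3, labeled_fraction=0.1, seed=0):
    """Return index lists (test, labeled, unlabeled)."""
    rng = random.Random(seed)
    test, train = stratified_split(labels, test_fraction, rng)
    train_labels = [labels[i] for i in train]
    # Assumption: the labeled subset is stratified as well.
    labeled_pos, unlabeled_pos = stratified_split(train_labels, labeled_fraction, rng)
    labeled = [train[p] for p in labeled_pos]
    unlabeled = [train[p] for p in unlabeled_pos]
    return test, labeled, unlabeled


# Toy labels: 200 instances of class 0 and 300 of class 1.
toy_labels = [0] * 200 + [1] * 300
test_idx, labeled_idx, unlabeled_idx = ssl_split(toy_labels)
```

With these toy labels the three index lists are disjoint and together cover every instance; the labels of `unlabeled_idx` would simply be withheld from the learner.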

We evaluated LAMDA-SSL on a semi-supervised regression task on tabular data using the Boston dataset. In this experiment, 30% of the instances are randomly sampled, following the class distribution, to form the test set. Then 10% of the remaining instances are randomly sampled to form the labeled training set, and the rest form the unlabeled training set with their labels dropped. For detailed parameter settings of each method, please refer to the 'Config' module of LAMDA-SSL.

| Method | Mean Absolute Error | Mean Squared Error |
| --- | --- | --- |
| CoReg | 4.66 | 59.52 |
| Π Model Reg | 4.32 | 37.64 |
| ICT Reg | 4.11 | 37.14 |
| Mean Teacher Reg | 4.51 | 45.56 |
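For readers comparing the two columns of the regression results, the sketch below shows how the two metrics are defined. It is a plain-Python illustration with hypothetical function names, not the evaluation code in LAMDA-SSL. Because the squared error grows quadratically, a method with a few large errors can have a similar mean absolute error but a much larger mean squared error.

```python
def mean_absolute_error(y_true, y_pred):
    """Average of |truth - prediction|."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_squared_error(y_true, y_pred):
    """Average of (truth - prediction) squared."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy example: two small errors (1 and 1) and one larger error (4).
y_true = [3.0, 5.0, 10.0]
y_pred = [2.0, 6.0, 14.0]
mae = mean_absolute_error(y_true, y_pred)  # (1 + 1 + 4) / 3
mse = mean_squared_error(y_true, y_pred)   # (1 + 1 + 16) / 3
```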

We evaluated LAMDA-SSL on a semi-supervised clustering task on tabular data using the Wine dataset. In this experiment, 20% of the instances are randomly sampled to form the labeled dataset, and the rest form the unlabeled dataset with their labels dropped. For detailed parameter settings of each method, please refer to the 'Config' module of LAMDA-SSL.

| Method | Davies Bouldin Score | Fowlkes Mallows Score |
| --- | --- | --- |
| Constrained k-means | 1.76 | 0.75 |
| Constrained Seed k-means | 1.38 | 0.93 |
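The two clustering metrics point in opposite directions: a lower Davies Bouldin score indicates better-separated clusters, while a higher Fowlkes Mallows score (at most 1) indicates better agreement with the reference labels. As an illustration of the latter, here is a minimal pair-counting sketch of the Fowlkes Mallows score, TP / sqrt((TP + FP)(TP + FN)). The function name is hypothetical and this is not the evaluation code shipped with LAMDA-SSL.

```python
import math
from itertools import combinations


def fowlkes_mallows(true_labels, cluster_labels):
    """Pair-counting Fowlkes Mallows score between two labelings."""
    tp = fp = fn = 0
    for i, j in combinations(range(len(true_labels)), 2):
        same_true = true_labels[i] == true_labels[j]
        same_cluster = cluster_labels[i] == cluster_labels[j]
        if same_true and same_cluster:
            tp += 1  # pair correctly placed together
        elif same_cluster:
            fp += 1  # pair wrongly placed together
        elif same_true:
            fn += 1  # pair wrongly separated
    if tp == 0:
        return 0.0
    return tp / math.sqrt((tp + fp) * (tp + fn))


# Cluster ids are arbitrary: a relabeled but otherwise identical partition is perfect.
perfect = fowlkes_mallows([0, 0, 1, 1], [1, 1, 0, 0])
partial = fowlkes_mallows([0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 1, 1])
```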

We evaluated LAMDA-SSL on a semi-supervised classification task on simple vision data using the MNIST dataset. In this experiment, 10% of the instances in the training dataset are randomly sampled to form the labeled dataset, and the rest form the unlabeled dataset with their labels dropped. For detailed parameter settings of each method, please refer to the 'Config' module of LAMDA-SSL.

| Method | Accuracy (%) | F1 Score |
| --- | --- | --- |
| Ladder Network | 97.37 | 97.36 |
| ImprovedGAN | 98.81 | 98.81 |
| SSVAE | 96.69 | 96.67 |

We evaluated LAMDA-SSL on a semi-supervised classification task on complex vision data using the CIFAR10 dataset. In this experiment, 4000 instances in the training dataset are randomly sampled to form the labeled training set, and the rest form the unlabeled training set with their labels dropped. WideResNet is used as the backbone network. For detailed parameter settings of each method, please refer to the 'Config' module of LAMDA-SSL.

| Method | Accuracy (%) | F1 Score |
| --- | --- | --- |
| UDA | 95.41 | 95.40 |
| Π Model | 87.09 | 87.07 |
| Temporal Ensembling | 89.30 | 89.31 |
| Mean Teacher | 92.01 | 91.99 |
| VAT | 88.22 | 88.19 |
| Pseudo Label | 85.90 | 85.85 |
| S4L | 89.59 | 89.54 |
| ICT | 92.64 | 92.62 |
| MixMatch | 93.43 | 93.43 |
| ReMixMatch | 96.24 | 96.24 |
| FixMatch | 95.34 | 95.33 |
| FlexMatch | 95.39 | 95.40 |

We have evaluated the performance of LAMDA-SSL for semi-supervised classification task on graph data using Cora dataset. In this experiment, 20% of the instances are randomly sampled to form the labeled training dataset and the others are used to form the unlabeled training dataset by dropping their labels. For detailed parameter settings of each method, please refer to the 'Config' module of LAMDA-SSL.

| Method | Accuracy (%) | F1 Score |
| --- | --- | --- |
| SDNE | 73.78 | 69.85 |
| GCN | 82.04 | 80.52 |
| GAT | 79.13 | 77.36 |

Citation

Please cite our paper if you find LAMDA-SSL useful in your work:

@article{jia2022lamdassl,
      title={LAMDA-SSL: Semi-Supervised Learning in Python}, 
      author={Lin-Han Jia and Lan-Zhe Guo and Zhi Zhou and Yu-Feng Li},
      journal={arXiv preprint arXiv:2208.04610},
      year={2022}
}

Contribution

Feel free to contribute in any way you like; we are always open to new ideas and approaches.

  • Open a discussion if you have any questions.
  • Open an issue if you have spotted a bug or a performance issue.
  • Fork our project and create a pull request after committing your modifications.
  • Learn more about how to customize modules of LAMDA-SSL from the Usage section of the documentation.

Team

LAMDA-SSL is developed by LAMDA@NJU. Contributors are Lin-Han Jia, Lan-Zhe Guo, Zhi Zhou and Yu-Feng Li.

Contact

If you have any questions, please contact us: Lin-Han Jia[[email protected]].

lamda-ssl's Issues

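Typo

In the [Constrained Seed k-means] section of the Chinese tutorial page, the last sentence, “使用有标注数据参于聚类过成时聚类器更加可靠,……” (roughly, "the clusterer is more reliable when labeled data take part in the clustering process, ..."), contains a typo: “过成” should be “过程”.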

Add benchmark performance

It would be useful to also see the performance of each SSL model against the purely supervised backbone run on the labeled data.

For example, TSVM vs pure SVM:

import numpy as np
from LAMDA_SSL.Dataset.Tabular.BreastCancer import BreastCancer

dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
labeled_X = dataset.labeled_X
labeled_y = dataset.labeled_y
unlabeled_X = dataset.unlabeled_X
unlabeled_y = dataset.unlabeled_y

from sklearn import preprocessing

pre_transform = preprocessing.StandardScaler()
pre_transform.fit(np.vstack([labeled_X, unlabeled_X]))
labeled_X = pre_transform.transform(labeled_X)
unlabeled_X = pre_transform.transform(unlabeled_X)

from LAMDA_SSL.Algorithm.Classification.TSVM import TSVM

# I tried using a range of Cl and Cu, starting from 15 and 0.0001 and then gradually 
# upping Cu and decreasing Cl. It didn't seem to make a difference?
model = TSVM(Cl=1, Cu=1, kernel="linear")

model.fit(X=labeled_X, y=labeled_y, unlabeled_X=unlabeled_X)
pred_y = model.predict()

from LAMDA_SSL.Evaluation.Classifier.Accuracy import Accuracy

score = Accuracy().scoring(unlabeled_y, pred_y)
print(f"SSL TSVM score: {score}")
#> SSL TSVM score: 0.9609375

# Compare with pure SVM
from sklearn import svm
model_sl = svm.SVC()
model_sl.fit(labeled_X, labeled_y)
pred_sl = model_sl.predict(unlabeled_X)
score_sl = Accuracy().scoring(unlabeled_y, pred_sl)
print(f"SL SVM score: {score_sl}")
#> SL SVM score: 0.955078125

How to run my three-category tabular data

Thanks for the great work, I need your help.

If I want to solve a three-category problem, which code should I modify? For example, suppose there are three categories in the BreastCancer dataset. When I did not modify any code, the confusion matrix only contained predictions for the first two classes.

Result/Co_Training_BreastCancer.txt:
accuracy 0.324468085106383
precision 0.2598727091480715
Recall 0.3464646464646464
F1 0.2306878306878307
Confusion_matrix [[0.16666667 0.83333333 0. ]
[0.12727273 0.87272727 0. ]
[0.12727273 0.87272727 0. ]]
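To make the symptom in this report easier to read: each row of the matrix above sums to 1, which suggests a row-normalised confusion matrix, and the all-zero third column means the third class is never predicted. The sketch below illustrates how such a matrix is formed and how to detect never-predicted classes. It assumes rows are true classes and columns are predicted classes (the usual convention, not confirmed by the report), and the function names are hypothetical.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Row-normalised confusion matrix: rows are true classes, columns predictions."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        counts[t][p] += 1
    normalised = []
    for row in counts:
        total = sum(row)
        normalised.append([c / total if total else 0.0 for c in row])
    return normalised


def never_predicted(matrix):
    """Class indices whose column is all zeros, i.e. never output by the model."""
    size = len(matrix)
    return [c for c in range(size) if all(matrix[r][c] == 0 for r in range(size))]


# The matrix reported above: class index 2 is never predicted.
reported = [
    [0.16666667, 0.83333333, 0.0],
    [0.12727273, 0.87272727, 0.0],
    [0.12727273, 0.87272727, 0.0],
]
missing = never_predicted(reported)
```

A zero column like this is what one would see if some part of the pipeline were still configured for two classes, although the report itself does not establish the cause.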

I cannot reproduce some of the Examples with deep algorithms (maybe because of pytorch 2.X?)

I just installed LAMDA-SSL from GitHub. It installed the newest versions of all packages, including torch==2.0.1 (pip freeze below).

I cannot reproduce the examples that use deep learning.

Assemble and other non-deep algorithms work fine:

(luan) Atlas:LAMDA-SSL wainer$ python Examples/Assemble_BreastCancer.py  
(luan) Atlas:LAMDA-SSL wainer$

but

(luan) Atlas:LAMDA-SSL wainer$ python Examples/FixMatch_BreastCancer.py 
Traceback (most recent call last):
  File "/Users/wainer/Dropbox/alunos/luan/LAMDA-SSL/Examples/FixMatch_BreastCancer.py", line 64, in <module>
    model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 326, in fit
    self.init_train_dataloader()
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 243, in init_train_dataloader
    self._labeled_dataloader, self._unlabeled_dataloader = self._train_dataloader.init_dataloader(
                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataloader/TrainDataloader.py", line 344, in init_dataloader
    self.labeled_dataloader = self.labeled_dataloader.init_dataloader(dataset=self.labeled_dataset,
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataloader/LabeledDataloader.py", line 86, in init_dataloader
    self.dataloader= DataLoader(dataset=self.dataset,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 245, in __init__
    raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.

I have been altering the obvious things such as default prefetch_factor and num_workers but after 2 hours of doing this I still get some problem somewhere. Below is my last attempt, by creating Dataloaders with the appropriate num_workers and prefetch_factor for the FixMatch_BreastCancer.py code, but I am not sure my modifications are correct. Someone is probably much more competent to make these changes...

(luan) Atlas:progs wainer$ python a2.py
Traceback (most recent call last):
  File "/Users/wainer/Dropbox/alunos/luan/progs/a2.py", line 82, in <module>
    model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 335, in fit
    self.fit_epoch_loop(valid_X,valid_y)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 311, in fit_epoch_loop
    self.fit_batch_loop(valid_X,valid_y)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 280, in fit_batch_loop
    for (lb_idx, lb_X, lb_y), (ulb_idx, ulb_X, _) in zip(self._labeled_dataloader, self._unlabeled_dataloader):
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 217, in __getitem__
    Xi, yi = self.apply_transform(Xi, yi)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 185, in apply_transform
    _X = self._transform(X, item)
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 130, in _transform
    X=self._transform(X,item)
      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 132, in _transform
    X = transform(X)
        ^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/Transformer.py", line 18, in __call__
    return self.fit_transform(X,y,fit_params=fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/Transformer.py", line 30, in fit_transform
    return self.fit(X=X,y=y,fit_params=fit_params).transform(X)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Transform/ToTensor.py", line 51, in transform
    X=torch.Tensor(X)
      ^^^^^^^^^^^^^^^
TypeError: new(): data must be a sequence (got Image)

I would guess that the problem is with the torch 2.X version, but I am not sure.

pip freeze:

(luan) Atlas:progs wainer$ pip freeze
certifi==2023.5.7
charset-normalizer==3.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.12.0
flake8==6.0.0
fonttools==4.39.4
idna==3.4
Jinja2==3.1.2
joblib==1.2.0
kiwisolver==1.4.4
LAMDA-SSL @ file:///Users/wainer/Dropbox/alunos/luan/LAMDA-SSL
MarkupSafe==2.1.3
matplotlib==3.7.1
mccabe==0.7.0
mpmath==1.3.0
networkx==3.1
numpy==1.24.3
packaging==23.1
pandas==2.0.2
Pillow==9.5.0
psutil==5.9.5
pycodestyle==2.10.0
pyflakes==3.0.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2023.3
requests==2.31.0
scikit-learn==1.2.2
scipy==1.10.1
six==1.16.0
sympy==1.12
threadpoolctl==3.1.0
torch==2.0.1
torch-geometric==2.3.1
torchdata==0.6.1
torchtext==0.15.2
torchvision==0.15.2
tqdm==4.65.0
typing_extensions==4.6.3
tzdata==2023.3
urllib3==2.0.3
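The first traceback in this report comes from a check in `torch.utils.data.DataLoader`: as the error message states, `prefetch_factor` may only be specified when `num_workers > 0`, and should otherwise be left as `None`. One generic way to satisfy that check is to build the keyword arguments conditionally, as sketched below. This is a hypothetical helper for user code, not a patch to LAMDA-SSL's own dataloader classes, and it does not address the second traceback.

```python
def dataloader_kwargs(batch_size, num_workers=0, prefetch_factor=2, drop_last=False):
    """Build DataLoader keyword arguments that respect the prefetch_factor rule."""
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "drop_last": drop_last,
    }
    # Only forward prefetch_factor when worker processes are in use;
    # with num_workers == 0 the key is omitted so the library default applies.
    if num_workers > 0 and prefetch_factor is not None:
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


single_process = dataloader_kwargs(batch_size=64)
multi_process = dataloader_kwargs(batch_size=64, num_workers=4)
```

With PyTorch installed, the result would be passed along as `DataLoader(dataset, **dataloader_kwargs(batch_size=64))`.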

What is the appropriate hyperparameter of MeanTeacherReg / ICTReg / PiModelReg when using other dataset?

Hi, thanks for creating this library.
I am trying to use MeanTeacherReg, ICTReg, and PiModelReg with my custom dataset.
When I used MeanTeacherReg, ICTReg, and PiModelReg with the hyperparameters from the example code, which was written for the Boston dataset, the model runs but all predictions come out as zero.
This means that the model did not learn my custom dataset properly.

What are appropriate hyperparameters for MeanTeacherReg / ICTReg / PiModelReg when using another dataset?

I thought some hyperparameters need to change when using another dataset, such as:

  1. num_samples
# sampler
labeled_sampler=RandomSampler(replacement=True,num_samples=64*(2**20))
unlabeled_sampler=RandomSampler(replacement=True)
valid_sampler=SequentialSampler()
test_sampler=SequentialSampler()
  2. batch_size
#dataloader
labeled_dataloader=LabeledDataLoader(batch_size=64,num_workers=0,drop_last=True)
unlabeled_dataloader=UnlabeledDataLoader(num_workers=0,drop_last=True)
valid_dataloader=UnlabeledDataLoader(batch_size=64,num_workers=0,drop_last=False)
test_dataloader=UnlabeledDataLoader(batch_size=64,num_workers=0,drop_last=False)
  3. epoch / num_it_epoch / num_it_total / eval_it
model=MeanTeacherReg(lambda_u=0,warmup=0.4,
               mu=1,weight_decay=5e-4,ema_decay=0.999,
               epoch=1,num_it_epoch=4000,
               num_it_total=4000,
               eval_it=200,device='cpu',
               labeled_dataset=labeled_dataset,
               unlabeled_dataset=unlabeled_dataset,
               valid_dataset=valid_dataset,
               test_dataset=test_dataset,
               labeled_sampler=labeled_sampler,
               unlabeled_sampler=unlabeled_sampler,
               valid_sampler=valid_sampler,
               test_sampler=test_sampler,
               labeled_dataloader=labeled_dataloader,
               unlabeled_dataloader=unlabeled_dataloader,
               valid_dataloader=valid_dataloader,
               test_dataloader=test_dataloader,
               augmentation=augmentation,network=network,
               optimizer=optimizer,scheduler=scheduler,
               evaluation=evaluation,file=file,verbose=True)

The sizes of my custom dataset are:

  • labeled_X is (8760, 10)
  • labeled_y is (8760, 1)
  • unlabeled_X is (8760, 10)
  • unlabeled_y is (8760, 1)
  • test_X is (8760, 10)
  • test_y is (8760, 1)

With this setting, the model did not learn from my dataset properly.

Can you provide example code in which the model works on a dataset other than Boston?
Or are there any tips for tuning the hyperparameters of MeanTeacherReg, ICTReg, and Pi Model Reg?
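One way to reason about the sampler, batch size, and iteration settings quoted in this question is to work out how much data the configuration actually consumes. The sketch below is only arithmetic on the numbers given above. It assumes that training stops after `num_it_total` iterations and that each iteration draws one labeled batch, and the function name is hypothetical. It does not diagnose why the predictions are zero.

```python
def sampling_budget(num_labeled, batch_size, num_it_total):
    """Return (labeled samples drawn, equivalent passes over the labeled set)."""
    samples_drawn = batch_size * num_it_total
    passes = samples_drawn / num_labeled
    return samples_drawn, passes


# Settings quoted above: 8760 labeled rows, batch_size=64, num_it_total=4000.
drawn, passes = sampling_budget(num_labeled=8760, batch_size=64, num_it_total=4000)

# The sampler's num_samples=64*(2**20) is an upper limit on draws; under the
# stated assumption it is far larger than what 4000 iterations consume.
sampler_cap = 64 * (2 ** 20)
```

Here `drawn` is 256000, which corresponds to roughly 29 passes over the 8760 labeled rows, so under these assumptions the iteration count rather than `num_samples` bounds the training length.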

Pi Model for MNIST

The lack of colour channels in MNIST (and MNIST-like) datasets means we get errors. It looks like the network implementations aren't set up to handle single channel inputs? Could you parameterise them so that they can be used for MNIST? It would be ideal to just write:

# This:
model = PiModel(channels=1)

# Or this
model = PiModel(shape=(28,28,1))
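Until the networks accept a channel parameter, a common general workaround for single-channel images is to repeat the grayscale value across three channels before feeding a network that expects RGB input. The sketch below shows the idea on nested Python lists. The helper name is hypothetical, and whether LAMDA-SSL's transforms expect channels last (H x W x C) or channels first is an assumption to check against the documentation.

```python
def gray_to_rgb(image):
    """Convert an H x W grid of pixel values into H x W x 3 by repeating each value."""
    return [[[pixel, pixel, pixel] for pixel in row] for row in image]


# A 2 x 2 grayscale patch becomes a 2 x 2 x 3 patch with identical channels.
patch = [[0, 255], [128, 64]]
rgb_patch = gray_to_rgb(patch)
```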
