# Pre-calculated per-class weights, keyed by dataset name. Stored as tuples so
# the shared table cannot be mutated by callers.
_SCANNET_CLASS_WEIGHTS = (
    0.32051547, 0.1980627, 0.2621471, 0.74563083, 0.52141879, 0.65918949,
    0.73560561, 1.03624985, 1.00063147, 0.90604468, 0.43435155, 3.91494446,
    1.94558718, 1., 0.54871637, 2.13587716, 1.13931665, 2.06423695,
    5.59103054, 1.08557339, 1.35027497,
)

_CLASS_WEIGHTS = {
    'S3DIS_A1': (0.27362621, 0.3134626, 0.18798782, 1.38965602, 1.44210271,
                 0.86639497, 1.07227331, 1., 1.05912352, 1.92726327,
                 0.52329938, 2.04783419, 0.5104427),
    'S3DIS_A2': (0.29036634, 0.34709631, 0.19514767, 1.20129272, 1.39663689,
                 0.87889087, 1.11586938, 1., 1.54599972, 1.87057415,
                 0.56458097, 1.87316536, 0.51576885),
    'S3DIS_A3': (0.27578885, 0.32039725, 0.19055443, 1.14914046, 1.46885687,
                 0.85450877, 1.05414776, 1., 1.09680025, 2.09280004,
                 0.59355243, 1.95746691, 0.50429199),
    'S3DIS_A4': (0.27667177, 0.32612854, 0.19886974, 1.18282174, 1.52145143,
                 0.8793782, 1.14202999, 1., 1.0857859, 1.89738584,
                 0.5964717, 1.95820557, 0.52113351),
    'S3DIS_A5': (0.28459923, 0.32990557, 0.1999722, 1.20798185, 1.33784535,
                 1., 0.93323316, 1.0753585, 1.00199521, 1.53657772,
                 0.7987055, 1.82384844, 0.48565471),
    'S3DIS_A6': (0.29442441, 0.37941846, 0.21360804, 0.9812721, 1.40968965,
                 0.88577139, 1., 1.09387107, 1.53238009, 1.61365643,
                 1.15693894, 1.57821041, 0.47342451),
    # NOTE(review): the train and trainval splits use identical weights here
    # (the original duplicated the same literal). Confirm whether trainval
    # should have its own statistics.
    'ScanNet_train': _SCANNET_CLASS_WEIGHTS,
    'ScanNet_trainval': _SCANNET_CLASS_WEIGHTS,
}


def get_class_weights(dataset_name):
    """Return the pre-calculated class-weight vector for a dataset.

    Args:
        dataset_name: one of the keys of ``_CLASS_WEIGHTS``: 'S3DIS_A1' ...
            'S3DIS_A6' (13 weights each), 'ScanNet_train' or
            'ScanNet_trainval' (21 weights each).

    Returns:
        A new 1-D float32 CPU tensor holding one weight per class. A fresh
        tensor is built on every call, so callers may modify it or move it
        to another device freely.

    Raises:
        ValueError: if no weights are prepared for ``dataset_name``.
            ValueError subclasses Exception, so existing
            ``except Exception`` handlers keep working.
    """
    try:
        weights = _CLASS_WEIGHTS[dataset_name]
    except KeyError:
        raise ValueError(
            f'No Prepared Class Weights of Dataset: {dataset_name!r} '
            f'(available: {", ".join(sorted(_CLASS_WEIGHTS))})'
        ) from None
    return torch.tensor(weights, dtype=torch.float32)