The following is my code, but I don't know why my model just train one epoch, even i have set a loop:
import sys
import numpy as np
from radio import dataset as ds
from radio import CTImagesMaskedBatch as CTIMB
from radio.pipelines import combine_crops
from radio.models import Keras3DUNet
from radio.models.keras.losses import dice_loss
from tqdm import tqdm
# Todo: Pre data
DIR_CANCER = 'data/cancer/*'
DIR_NCANCER = 'data/ncancer/*'
train_time = 1
my_epoch = 20
bs = 4
loss_his = []
cix = ds.FilesIndex(path=DIR_CANCER, dirs=True)
ncix = ds.FilesIndex(path=DIR_NCANCER, dirs=True)
cancerset = ds.Dataset(index=cix, batch_class=CTIMB)
ncancerset = ds.Dataset(index=ncix, batch_class=CTIMB)
print("Len:", len(cancerset), len(ncancerset))
# Todo: Build & Trian model
unet_config = dict(
input_shape=(1, 32, 64, 64),
num_targets=1,
loss=dice_loss)
from radio.dataset import F, V
train_unet_pipeline = (
combine_crops(cancerset, ncancerset, batch_sizes=(bs, bs))
.init_variable('loss_acc', 0)
.init_variable('current_loss', 0)
.init_variable('loss_history', init_on_each_run=list)
.init_variable('cancer_len', len(cancerset))
.init_variable('ncancer_len', len(ncancerset))
.init_model(
name='3dunet', model_class=Keras3DUNet,
config=unet_config, mode='static'
)
.train_model(
name='3dunet', fetches=[V('loss_acc'), V('cancer_len'), V('ncancer_len')], save_to=V('loss_acc'),
x=F(CTIMB.unpack, component='images', data_format='channels_first'),
y=F(CTIMB.unpack, component='masks', data_format='channels_first')
)
# .run(batch_size=bs, n_epochs=4, drop_last=True, bar=True)
.print("loss and acc is:", V('loss_acc'))
.update_variable('loss_history', value=V('loss_acc'), mode='a')
# Notice: here we use train_on_batch to train our model, train_on_batch return 2 metrics ['loss', 'acc'].
)
for i in range(my_epoch):
print(f"epoch {i}")
t = train_unet_pipeline.run(epoch=1, batch_size=bs)
# for i in tqdm(range(my_epoch)):
# train_unet_pipeline.run(batch_size=bs, n_epochs=1, shuffle=True) # cancer + ncancer = 8
# # here epoch is nothing, loop train_pip.run to get multi epoch
# loss_his.append(train_unet_pipeline.get_variable('loss_history')) # list: [ [loss1, acc1], [loss2, acc2], ... ]
# keras_unet = train_unet_pipeline.get_model_by_name('3dunet')
# keras_unet.save('data/weighsts_'+str(train_time))
# loss_his = np.array(loss_his)
# np.save('data/loss_his'+str(train_time)+'.npy', loss_his)