
mg2033 / shufflenet

384 stars, 17 watchers, 124 forks, 196 KB

ShuffleNet Implementation in TensorFlow

License: Apache License 2.0

Python 100.00%
deep-learning shufflenet tensorflow group-convolution computer-vision classification realtime

shufflenet's People

Contributors

mg2033, msiam, pomonam


shufflenet's Issues

pretrained model

Hi @MG2033,

Did you finish training on ImageNet?
Could you share your pretrained model?

Thanks!

How is the ImageNet training going?

I'm trying to train this model on ImageNet, but the loss converges slowly after 60k iterations and sits at approximately 2.5.

Did you reach a lower loss, or did you see the same behavior?
I just want to check that I'm training this model correctly.
Thanks.
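As a rough sanity check (my own back-of-the-envelope arithmetic, not a number from the author): with 1000 ImageNet classes, a uniform random-guess prediction gives a cross-entropy loss of ln(1000) ≈ 6.91, so a loss around 2.5 at 60k iterations means the network is clearly learning something, even if convergence feels slow.

import math

# Cross-entropy of a uniform prediction over the ImageNet classes.
num_classes = 1000
baseline_loss = math.log(num_classes)
print("random-guess baseline loss: %.2f" % baseline_loss)  # ~6.91
print("an observed loss of ~2.5 is well below this baseline")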

HELP: why does the loss stay high on the flower dataset?

I used a flower dataset (5 classes, 2500 images for training and 500 for validation) to create a TFRecords file and used it as the training input, but the loss does not decrease and the validation accuracy stays at 20%. Could there be a bug in my code where it reads the TFRecords?

import tensorflow as tf
from tqdm import tqdm
import numpy as np
from utils import load_obj
import matplotlib.pyplot as plt

class Train:
    """Trainer class for the CNN.
    It's also responsible for loading/saving the model checkpoints from/to experiments/experiment_name/checkpoint_dir"""

    def __init__(self, sess, model, data, summarizer):
        self.sess = sess
        self.model = model
        self.args = self.model.args
        self.saver = tf.train.Saver(max_to_keep=self.args.max_to_keep,
                                    keep_checkpoint_every_n_hours=10,
                                    save_relative_paths=True)
        # Summarizer references
        self.data = data
        self.summarizer = summarizer

        # Initializing the model
        self.init = None
        self.__init_model()

        # Loading the model checkpoint if it exists
        self.__load_imagenet_weights()
        self.__load_model()
        IMAGE_SIZE = 224
        NUM_CLASSES = 5

    ############################################################################################################
    # Model related methods
    def __init_model(self):
        print("Initializing the model...")
        self.init = tf.group(tf.global_variables_initializer())
        self.sess.run(self.init)
        print("Model initialized\n\n")

    def save_model(self):
        """
        Save Model Checkpoint
        :return:
        """
        print("Saving a checkpoint")
        self.saver.save(self.sess, self.args.checkpoint_dir, self.model.global_step_tensor)
        print("Checkpoint Saved\n\n")

    def __load_model(self):
        latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_dir)
        if latest_checkpoint:
            print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(self.sess, latest_checkpoint)
            print("Checkpoint loaded\n\n")
        else:
            print("First time to train!\n\n")

    def __load_imagenet_weights(self):
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        print("No pretrained ImageNet weights exist. Skipping...\n\n")

    ############################################################################################################
    # Train and Test methods

    def read_and_decode(self, filename_queue):
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            # Defaults are not specified since both keys are required.
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'target': tf.FixedLenFeature([], tf.int64),
            })

        # Decode the JPEG-encoded string into a uint8 image tensor, convert it to
        # float32 in [0, 1], and crop/pad it to the fixed 224x224 input size.
        image = tf.image.decode_jpeg(features['image'], channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
        image = tf.clip_by_value(image, 0.0, 1.0)

        # Convert the label from a scalar int64 tensor to an int32 scalar.
        label = tf.cast(features['target'], tf.int32)

        return image, label

    def train(self):

        filename_queue = tf.train.string_input_producer(["/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/train1.tfrecords"])
        # train data
        image, label = self.read_and_decode(filename_queue)

        images, labels = tf.train.shuffle_batch([image, label], batch_size=50, num_threads=2, capacity=2500, min_after_dequeue=250)

        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)

        for cur_epoch in range(self.model.global_epoch_tensor.eval(self.sess) + 1, self.args.num_epochs + 1, 1):
            # Initialize tqdm

            num_iterations = self.args.train_data_size // self.args.batch_size
            print("num_iterations:", num_iterations, "train_data_size:", self.args.train_data_size, "batch_size:", self.args.batch_size)
            # tqdm_batch = tqdm([self.data.X_train, self.data.y_train], total=num_iterations,
            #                   desc="Epoch-" + str(cur_epoch) + "-")

            # Initialize the current iterations
            cur_iteration = 0

            # Initialize classification accuracy and loss lists
            loss_list = []
            acc_list = []

            # Loop by the number of iterations
            print("cur_epoch ==", cur_epoch)
            # for self.data.X_train, self.data.y_train in tqdm_batch:
            for step in tqdm(range(0, num_iterations), initial=1, total=num_iterations):
                # Get the current iteration for summarizing it
                cur_step = self.model.global_step_tensor.eval(self.sess)

                image_train, label_train = self.sess.run([images, labels])
                # print(image_train)
                # Feed the fetched numpy batches (not the tensor handles) to the network
                feed_dict = {self.model.X: image_train,
                             self.model.y: label_train,
                             self.model.is_training: True
                             }
                # Run the feed_forward
                _, loss, acc = self.sess.run(
                    [self.model.train_op, self.model.loss, self.model.accuracy],
                    feed_dict=feed_dict)
                # Append loss and accuracy
                loss_list += [loss]
                acc_list += [acc]

                # Update the global step
                self.model.global_step_assign_op.eval(session=self.sess,
                                                      feed_dict={self.model.global_step_input: cur_step + 1})

                # self.summarizer.add_summary(cur_step, summaries_merged=summaries_merged)

                if step >= num_iterations - 1:
                    avg_loss = np.mean(loss_list)
                    avg_acc = np.mean(acc_list)
                    # summarize
                    # summaries_dict = dict()
                    # summaries_dict['loss'] = avg_loss
                    # summaries_dict['acc'] = avg_acc
                    # self.summarizer.add_summary(cur_step, summaries_dict=summaries_dict)

                    # Update the current epoch tensor
                    self.model.global_epoch_assign_op.eval(session=self.sess,
                                                           feed_dict={self.model.global_epoch_input: cur_epoch + 1})

                    # Print in console
                    # tqdm_batch.close()
                    print("Epoch-" + str(cur_epoch) + " | loss: " + str(avg_loss) + " - acc: " + str(avg_acc)[:7])
                    # Break the loop to finalize this epoch
                    # break

                # Update the current iteration
                cur_iteration += 1

            # Save the current checkpoint
            if cur_epoch % self.args.save_model_every == 0 and cur_epoch != 0:
                self.save_model()

            # Test the model on validation or test data
            if cur_epoch % self.args.test_every == 0:
                self.test('val')

        coord.request_stop()
        coord.join(threads)

    def test(self, test_type='val'):

        filename_queue = tf.train.string_input_producer(["/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/val1.tfrecords"])
        # val data
        image, label = self.read_and_decode(filename_queue)

        images, labels = tf.train.shuffle_batch([image, label], batch_size=50, num_threads=2, capacity=200, min_after_dequeue=50)

        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)

        num_iterations = self.args.test_data_size // self.args.batch_size
        # tqdm_batch = tqdm(self.data.generate_batch(type=test_type), total=num_iterations,
        #                   desc='Testing')

        # Initialize classification accuracy and loss lists
        loss_list = []
        acc_list = []
        cur_iteration = 0

        # for X_batch, y_batch in tqdm_batch:
        for step in tqdm(range(0, num_iterations), initial=1, total=num_iterations):

            image_val, label_val = self.sess.run([images, labels])
            # Feed the fetched numpy batches to the network
            feed_dict = {self.model.X: image_val,
                         self.model.y: label_val,
                         self.model.is_training: False
                         }
            # Run the feed_forward
            loss, acc = self.sess.run(
                [self.model.loss, self.model.accuracy],
                feed_dict=feed_dict)

            # Append loss and accuracy
            loss_list += [loss]
            acc_list += [acc]

            if step >= num_iterations - 1:
                avg_loss = np.mean(loss_list)
                avg_acc = np.mean(acc_list)
                print('Test results | test_loss: ' + str(avg_loss) + ' - test_acc: ' + str(avg_acc)[:7])
                # break

            cur_iteration += 1
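One thing worth double-checking is the writer side of the TFRecords: read_and_decode above expects a feature named 'image' holding an encoded JPEG string and a feature named 'target' holding an int64 label, so the file must have been written with exactly those keys and types. A minimal writer sketch under that assumption (the file name, image paths, and labels below are made up for illustration):

import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Hypothetical (jpeg_path, class_index) pairs for the flower images.
examples = [("flowers/daisy/001.jpg", 0), ("flowers/rose/001.jpg", 1)]

with tf.python_io.TFRecordWriter("train1.tfrecords") as writer:
    for jpeg_path, label in examples:
        with open(jpeg_path, "rb") as f:
            encoded_jpeg = f.read()  # raw JPEG bytes; decode_jpeg handles them at read time
        example = tf.train.Example(features=tf.train.Features(feature={
            'image': _bytes_feature(encoded_jpeg),  # same key/type expected by read_and_decode
            'target': _int64_feature(label),        # int64 here, cast to int32 at read time
        }))
        writer.write(example.SerializeToString())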

Number of FLOPs

Hi!
Does anyone know why the number of FLOPs is about twice the number reported by the authors?

Thanks
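One common source of exactly a factor-of-two gap (an assumption here, not something confirmed in this thread): the ShuffleNet paper counts multiply-accumulate operations (MACs) and reports them as FLOPs, while profilers that count a multiply and an add as two separate operations report roughly twice that number. A small sketch of the two counting conventions for a single, hypothetical convolution layer:

def conv_macs(h_out, w_out, c_in, c_out, k, groups=1):
    # Multiply-accumulate operations for one conv layer (what the paper calls FLOPs).
    return h_out * w_out * c_out * (c_in // groups) * k * k

def conv_flops(h_out, w_out, c_in, c_out, k, groups=1):
    # Counting a multiply and an add separately, as many profilers do: 2x the MACs.
    return 2 * conv_macs(h_out, w_out, c_in, c_out, k, groups)

# Example: a made-up 3x3 convolution with a 112x112 output and 24 -> 24 channels.
print(conv_macs(112, 112, 24, 24, 3))   # 65,028,096 MACs (~65M)
print(conv_flops(112, 112, 24, 24, 3))  # 130,056,192 "FLOPs", the factor-of-two gap

If the profiler's number is about twice the paper's figure, the discrepancy is most likely just this counting convention rather than a modeling error.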
