
FedDyn's Issues

loss for feddyn

The FedDyn loss is currently calculated as follows (in utils_general):

loss_algo = alpha_coef * torch.sum(local_par_list * (-avg_mdl_param + local_grad_vector))

I think what you wanted to do (unless I'm missing something) is:

loss_algo = alpha_coef * torch.sum(local_par_list * (0.5*local_par_list - avg_mdl_param + local_grad_vector))

Maybe it is much clearer (matching the paper's equation) to write something like:

loss_algo = alpha_coef * torch.sum(local_par_list * local_grad_vector) + 0.5 * alpha_coef * torch.sum((local_par_list - avg_mdl_param)**2)

Feel free to correct me if I'm wrong.
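
To double-check, here is a minimal sketch (with hypothetical random tensors, not the repository's training loop) confirming that the two corrected expressions differ only by the constant $0.5\,\alpha\,\|\theta^{t-1}\|^2$ and therefore give identical gradients:

import torch

torch.manual_seed(0)
alpha_coef = 0.01
local_par_list = torch.randn(10, requires_grad=True)   # theta
avg_mdl_param = torch.randn(10)                        # theta^{t-1}
local_grad_vector = torch.randn(10)                    # running gradient state

# Compact form and expanded (paper-style) form of the penalty.
loss_a = alpha_coef * torch.sum(local_par_list * (0.5 * local_par_list - avg_mdl_param + local_grad_vector))
loss_b = alpha_coef * torch.sum(local_par_list * local_grad_vector) \
         + 0.5 * alpha_coef * torch.sum((local_par_list - avg_mdl_param) ** 2)

grad_a, = torch.autograd.grad(loss_a, local_par_list)
grad_b, = torch.autograd.grad(loss_b, local_par_list)
print(torch.allclose(grad_a, grad_b))  # True: the losses differ only by a constant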

About the server state h?

Thanks for your awesome work. I'm trying to re-implement your algorithm, but when I read the source code I cannot find the server state $h^t$.

If I understand correctly, in this line https://github.com/alpemreacar/FedDyn/blob/48a19fac440ef079ce563da8e0c2896f8256fef9/utils_methods.py#L389, local_param_list_curr is the local gradient $\nabla L_k (\theta_k^t)$ and cld_mdl_param_tensor is the global model parameter $\theta^{t-1}$.
In this line https://github.com/alpemreacar/FedDyn/blob/48a19fac440ef079ce563da8e0c2896f8256fef9/utils_methods.py#L397, cld_mdl_param is the new global model parameter $\theta^t$, and it seems that np.mean(local_param_list, axis=0) is $-\frac{1}{\alpha} h^t$.

Thus, the code means that $h^t = -\frac{\alpha}{m} \sum_{k \in \left[ m \right] } \nabla L_k (\theta_k) $,
where I drop the $t$ on $\theta_k$ because the summation runs over all clients: with random client selection, we cannot know the timestamp of each $\nabla L_k (\theta_k^t)$.

So here the actual $h^t$ is not strictly calculated as
$h^t = h^{t-1} - \alpha \frac{1}{m} \sum_{k\in \mathcal{P}_t} (\theta_k^t - \theta^{t-1})$.
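
For comparison, here is a minimal sketch of the paper's recursive server update (hypothetical flattened numpy arrays, not the repository's code):

import numpy as np

m, dim, alpha = 100, 10, 0.01
h = np.zeros(dim)                                     # server state h^{t-1}
theta_prev = np.random.randn(dim)                     # global model theta^{t-1}
active = [3, 7, 42]                                   # sampled client indices P_t
theta_k = {k: np.random.randn(dim) for k in active}   # local models theta_k^t

# h^t = h^{t-1} - (alpha / m) * sum_{k in P_t} (theta_k^t - theta^{t-1})
h = h - (alpha / m) * sum(theta_k[k] - theta_prev for k in active)

# theta^t = (1 / |P_t|) * sum_{k in P_t} theta_k^t - (1 / alpha) * h^t
theta = np.mean([theta_k[k] for k in active], axis=0) - h / alpha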

a minor mistake in the calculation of the FedProx loss

The FedProx loss is currently calculated as follows (in utils_general):

loss_algo = mu/2 * torch.sum(local_par_list * local_par_list)
loss_algo = -mu * torch.sum(local_par_list * avg_model_param_)

I think what you wanted to do (unless I'm missing something) is:

loss_algo = mu/2 * torch.sum(local_par_list * local_par_list)
loss_algo = loss_algo - mu * torch.sum(local_par_list * avg_model_param_)
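
Equivalently, up to the constant $\frac{\mu}{2}\|\theta^{t-1}\|^2$ that does not affect the gradients, the proximal term could be written directly (a sketch reusing the variable names above, not the repository's code):

# FedProx proximal term mu/2 * ||theta - theta^{t-1}||^2
loss_algo = mu / 2 * torch.sum((local_par_list - avg_model_param_) ** 2)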

Cheers,
F. Varno

The preprocessing order of {Horizontal flip, Random cropping} and {normalize}.

Subject: I find that the order of {Horizontal flip, Random cropping} and {normalize} is somewhat different from the typical operation order. I think this will not affect the experimental conclusions, but it is interesting.

Detail: Take CIFAR-10 (iid and FedAvg) as an example.

{Normalize.} In utils_dataset.py, lines 32-41, we generate trn_load with normalized data, which is then used to generate self.trn_x and self.trn_y (lines 61-65) and further clnt_x, clnt_y (lines 204-205 when self.rule == 'iid'). This means that clnt_x, clnt_y contain the normalized data samples.

transform = transforms.Compose([transforms.ToTensor(),
    transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])])

trnset = torchvision.datasets.CIFAR10(root='%s/Raw' % self.data_path,
                                      train=True, download=True, transform=transform)
tstset = torchvision.datasets.CIFAR10(root='%s/Raw' % self.data_path,
                                      train=False, download=True, transform=transform)

trn_load = torch.utils.data.DataLoader(trnset, batch_size=50000, shuffle=False, num_workers=1)
tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1)

{Horizontal flip, Random cropping.} For each client, we train the model with the local data (i.e., trn_x, trn_y; lines 78-79 in utils_methods.py). We then generate the data loader with Dataset() on trn_x, trn_y (line 79 in utils_general.py).

trn_gen = data.DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), batch_size=batch_size, shuffle=True) 

Turning back to utils_dataset.py, we apply {Horizontal flip, Random cropping} to the local data if train=True. This means that we train the model on the flipped/cropped (already normalized) samples.

if self.train:
    # Horizontal flip
    img = np.flip(img, axis=2).copy() if (np.random.rand() > .5) else img
    if (np.random.rand() > .5):
        # Random cropping
        pad = 4
        extended_img = np.zeros((3, 32 + pad * 2, 32 + pad * 2)).astype(np.float32)
        extended_img[:, pad:-pad, pad:-pad] = img
        dim_1, dim_2 = np.random.randint(pad * 2 + 1, size=2)
        img = extended_img[:, dim_1:dim_1 + 32, dim_2:dim_2 + 32]

Conclusion. In summary, the preprocessing order in the code is 1) normalize, then 2) {random crop, random horizontal flip}. In contrast, the typical order is 1) {random crop, random horizontal flip}, then 2) normalize. The typical order matches the usual intuition (and I guess would be slightly better?). In any case, I think it does not affect the experimental conclusions, but it is interesting.
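
For reference, a typical torchvision pipeline applies the geometric augmentations to the PIL image first and normalizes last; here is a sketch of that conventional order (not this repository's code):

import torchvision.transforms as transforms

# Conventional CIFAR-10 training transform: augment first, normalize last.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.491, 0.482, 0.447],
                         std=[0.247, 0.243, 0.262]),
])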

Hoping for your reply.

Best regards.

Clarification of Code and Algorithm

To further clarify, here are the algorithm steps and their corresponding updates in the code.

  • utils_methods L361-369 performs the active device selection that determines $\mathcal{P}_t$ in the algorithm.
  • The local objective consists of different parts:
    • utils_general L208 calculates the current stochastic loss, which corresponds to $L_k(\theta)$.
    • local_param_list_curr or local_grad_vector corresponds to $-\frac{1}{\alpha}\nabla L_k(\theta_k^{t})$. utils_general L219 calculates the 'loss_algo' as $\langle \theta,-\theta^{t-1}-\frac{1}{\alpha}\nabla L_k(\theta_k^{t-1})\rangle$. This loss is multiplied by $\alpha$ in the total loss calculation.
    • Finally, the quadratic loss is added through L2 regularization in utils_general L190, which corresponds to $\frac{\alpha}{2}\|\theta\|^2$.
  • The total loss becomes $L_k(\theta)+\langle \theta,-\alpha \theta^{t-1}-\nabla L_k(\theta_k^{t-1})\rangle + \frac{\alpha}{2}\|\theta\|^2$. Adding $\frac{\alpha}{2}\|\theta^{t-1}\|^2$, which does not change the gradients, we obtain the loss as $L_k(\theta)-\langle \theta,\nabla L_k(\theta_k^{t-1})\rangle + \frac{\alpha}{2}\|\theta-\theta^{t-1}\|^2$ (see the sketch after this list).
  • utils_methods L393 updates $-\frac{1}{\alpha}\nabla L_k(\theta_k^{t})$ using the linear equation shown in the algorithm.
  • utils_methods L397 updates the cloud server model $\theta^t$ using the 'h' state and the instantaneous averaging.
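
Putting the local-side steps together, here is a minimal sketch (hypothetical flattened tensors that mirror the repository's variable names, not its actual training loop):

import torch

alpha_coef = 0.01
local_par_list = torch.randn(10, requires_grad=True)  # theta, optimized locally
avg_mdl_param = torch.randn(10)                       # theta^{t-1} from the server
local_grad_vector = torch.randn(10)                   # -(1/alpha) * nabla L_k(theta_k^{t-1})
loss_f_i = torch.tensor(0.0)                          # stand-in for the task loss L_k(theta)

# Linear FedDyn term (utils_general L219), plus the quadratic term that the
# repository realizes via L2 weight decay (utils_general L190).
loss_algo = alpha_coef * torch.sum(local_par_list * (-avg_mdl_param + local_grad_vector))
loss = loss_f_i + loss_algo + 0.5 * alpha_coef * torch.sum(local_par_list ** 2)

# After local training, the client state is refreshed (utils_methods L393): it
# accumulates theta_k^t - theta^{t-1}, i.e. it tracks -(1/alpha) * nabla L_k(theta_k^t).
local_grad_vector = local_grad_vector + (local_par_list.detach() - avg_mdl_param)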
