
adds generic training functions for trainer #548

Merged
morganridel merged 2 commits into develop from generic-training-functions on Jan 16, 2023
Conversation

@morganridel
Collaborator

@morganridel morganridel commented Nov 24, 2022

This allows passing a custom training function when creating a Trainer, which could be useful for:

  • Custom sampling of the dataset for specific training loops
  • A fully custom training loop (TF.js fitDataset mostly expects gradient-based optimizers)

This enables users of the library to write a client with completely custom training, and lets us eventually define more training loops that could be used in later tasks. Currently the library only provides the default training (TF.js fitDataset), but users can freely extend it.
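To make the idea concrete, here is a minimal sketch of what injecting a training function into a Trainer could look like. All names (Trainer, TrainingFunction, Model, fit, the sampling strategy) are illustrative stand-ins, not the library's actual API, and the "model" is reduced to a single number so the example is self-contained.

```typescript
// A batch is just an array of numbers in this toy setup.
type Batch = number[]

// Toy model: "training" on a batch accumulates the batch sum into the weights.
class Model {
  weights = 0
  fit(batch: Batch): void {
    this.weights += batch.reduce((a, b) => a + b, 0)
  }
}

// The training function is injected into the Trainer at construction time.
type TrainingFunction = (model: Model, dataset: Batch[]) => void

// Default training: iterate over every batch (stands in for TF.js fitDataset).
const defaultTraining: TrainingFunction = (model, dataset) => {
  for (const batch of dataset) model.fit(batch)
}

class Trainer {
  constructor(
    private model: Model,
    private trainingFunction: TrainingFunction = defaultTraining,
  ) {}

  train(dataset: Batch[]): void {
    this.trainingFunction(this.model, dataset)
  }
}

// A custom training function, e.g. one that subsamples the dataset
// by only using every other batch.
const sampledTraining: TrainingFunction = (model, dataset) => {
  for (const batch of dataset.filter((_, i) => i % 2 === 0)) model.fit(batch)
}

const model = new Model()
const trainer = new Trainer(model, sampledTraining)
trainer.train([[1, 2], [3, 4], [5, 6]]) // uses batches 0 and 2 only
```

Passing the function at construction keeps the Trainer agnostic about how training happens, which is the modularity the PR introduces.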

Some info about the training function:

The training function is the function called to train the model.
The model is trained locally with the given dataset, which means it only has access to the data on the user's local machine.
This all happens in the client, so this data is never sent to the server; only the resulting model is.

Every few batches, as set by the round duration (the trainer's 'onRoundEnd' function), the model is sent to the server and then updated with the new aggregated weights (in the case of the distributed trainer). Subsequent training is done with the updated model.
Training functions therefore need to account for the model being replaced at regular intervals.
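The round mechanism described above can be sketched as follows. This is a hypothetical illustration, not the library's code: the names (DistributedTrainer, onBatchEnd, onRoundEnd, fetchAggregated) mirror the description, the model weights are a single number, and the server round-trip is stubbed out as a plain function.

```typescript
// Toy model: weights simplified to a single number.
class LocalModel {
  weights = 0
}

class DistributedTrainer {
  private batchCount = 0

  constructor(
    private model: LocalModel,
    private roundDuration: number,
    // Stands in for the server round-trip: send local weights, receive the aggregate.
    private fetchAggregated: (localWeights: number) => number,
  ) {}

  // Called after every batch; triggers a round every `roundDuration` batches.
  onBatchEnd(): void {
    this.model.weights += 1 // placeholder for one local training step
    this.batchCount++
    if (this.batchCount % this.roundDuration === 0) this.onRoundEnd()
  }

  // The local model is swapped for the aggregated one, so subsequent
  // batches (and any optimizer state) see different weights.
  onRoundEnd(): void {
    this.model.weights = this.fetchAggregated(this.model.weights)
  }
}

const model = new LocalModel()
// Round every 3 batches; the "aggregation" here just halves the weights.
const trainer = new DistributedTrainer(model, 3, (w) => w / 2)
for (let i = 0; i < 6; i++) trainer.onBatchEnd()
```

After six batches and two rounds the local weights are no longer what local training alone would have produced, which is exactly why a custom training function must tolerate the swap.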

@morganridel morganridel force-pushed the generic-training-functions branch 3 times, most recently from fa4707c to 2a771c4 Compare November 24, 2022 13:32
@morganridel morganridel marked this pull request as ready for review November 24, 2022 14:35
@martinjaggi
Member

martinjaggi commented Nov 24, 2022

thanks, looks nice

tiny things:
in the PR description and training function docstring, could you add a comment describing the job of the training function? like saying which data and info it accesses etc.
for instance, say that it doesn't see the global distributed dataset but only the local data on a worker.

also say that at onRoundEnd, which happens every few onBatchEnd calls, the model is actually replaced with the model from the communication step (federated or decentralized). so the optimizer should be warned that its model gets swapped out regularly (which might slightly confuse adam for example; no prob for SGD)

also now that you made it modular on the training function level, is the trainer builder even needed, or a bit overkill? maybe the learning_rate access hack there could also be removed, as the training function (or the optimizer) could perhaps access it directly? i don't know

@morganridel
Collaborator Author

> thanks, looks nice
>
> tiny things: in the PR description and training function docstring, could you add a comment describing the job of the training function? like saying which data and info it accesses etc. for instance, say that it doesn't see the global distributed dataset but only the local data on a worker.
>
> also say that onRoundEnd which is every few onBatchEnd, the model is actually replaced with the model from the communication step (federated or decentralized). so the optimizer should be warned that its model gets swapped out regularly (which might slightly confuse adam for example. no prob for SGD)
>
> also now that you made it modular on the training function level, is the trainer builder even needed or a bit overkill? maybe the learning_rate access hack there could also be removed, as the training function (or the optimizer) could access it directly maybe? i don't know

I added a bit of doc to the training function regarding your comment.

For the trainer builder, I'm not really sure right now; I guess we could consider refactoring that in another issue, to keep this one self-contained.

@morganridel morganridel force-pushed the generic-training-functions branch from be9a174 to 7e44771 Compare December 6, 2022 12:10
@morganridel morganridel force-pushed the generic-training-functions branch from 7e44771 to 7064af1 Compare December 6, 2022 13:20
@morganridel morganridel merged commit 8fedde2 into develop Jan 16, 2023
@morganridel morganridel deleted the generic-training-functions branch January 16, 2023 10:29
