This allows passing a custom training function when creating a Trainer, which could be useful for:
Custom sampling of the dataset for specific training loops
Custom training loops (TF.js fitDataset seems to mostly expect gradient-based optimizers)
This enables users of the library to write a client with completely custom training, and it lets us eventually define more training loops that could be used in later tasks.
Currently the library only offers the default training (TF.js fitDataset), but users can freely extend it.
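To make the idea concrete, here is a minimal sketch of what injecting a training function into a Trainer could look like. All names (TrainingFunction, Trainer, Model, Dataset) are illustrative assumptions for this sketch, not the library's actual API:

```typescript
// Hypothetical shapes standing in for the library's real types.
type Model = { weights: number[] };
type Dataset = number[][];

// The training function receives the model and the *local* dataset and
// runs the loop however it likes (gradient-based or not).
type TrainingFunction = (model: Model, dataset: Dataset) => void;

class Trainer {
  constructor(
    private model: Model,
    private train: TrainingFunction, // injected custom loop
  ) {}

  fit(dataset: Dataset): void {
    this.train(this.model, dataset);
  }
}

// Example of the "custom sampling" use case: only visit every other batch.
const customTrain: TrainingFunction = (model, dataset) => {
  for (let i = 0; i < dataset.length; i += 2) {
    // Placeholder update; a real loop would compute gradients here.
    model.weights = model.weights.map((w) => w + 0.01);
  }
};

const trainer = new Trainer({ weights: [0, 0] }, customTrain);
trainer.fit([[1], [2], [3], [4]]);
```

The default TF.js fitDataset loop would then just be one particular TrainingFunction shipped with the library.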
Some info about the training function:
The training function is the function that will be called to train the model.
The model is trained locally with the given dataset, which means it only has access to the data on the user's local machine.
This all happens in the client, so this data is never sent to the server; only the resulting model is.
Every few batches, as determined by the round duration (the trainer's 'onRoundEnd' function), the model is sent to the server and then updated with the new aggregated weights (in the case of the distributed trainer). Subsequent training is done with the updated model.
Be aware that the model is therefore subject to regular replacement inside your training function.
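A small sketch of the round cycle described above; the names (RoundTrainer, Aggregator) and the averaging rule are illustrative assumptions, not the library's actual API:

```typescript
type Weights = number[];

interface Aggregator {
  // Stand-in for the server-side communication step
  // (e.g. federated averaging).
  aggregate(local: Weights): Weights;
}

class RoundTrainer {
  constructor(
    private weights: Weights,
    private roundDuration: number, // batches per round
    private aggregator: Aggregator,
  ) {}

  onBatchEnd(batch: number): void {
    if ((batch + 1) % this.roundDuration === 0) this.onRoundEnd();
  }

  // Every roundDuration batches, the local weights are *replaced* by the
  // aggregated ones. A stateful optimizer (e.g. Adam's moment estimates)
  // still refers to the pre-swap weights; plain SGD is unaffected.
  private onRoundEnd(): void {
    this.weights = this.aggregator.aggregate(this.weights);
  }

  getWeights(): Weights {
    return this.weights;
  }
}
```

Custom training functions should account for this swap, e.g. by resetting or re-syncing optimizer state after each round.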
tiny things:
in the PR description and the training function docstring, could you add a comment describing the job of the training function? e.g. which data and info it accesses.
for instance, say that it doesn't see the global distributed dataset but only the local data on a worker.
also say that on onRoundEnd (which fires every few onBatchEnd calls), the model is actually replaced with the model from the communication step (federated or decentralized). so the optimizer should be warned that its model gets swapped out regularly (which might slightly confuse Adam, for example; no problem for SGD)
also, now that you made it modular at the training-function level, is the trainer builder even needed, or a bit overkill? maybe the learning_rate access hack there could also be removed, since the training function (or the optimizer) could access it directly. i don't know
I added a bit of doc to the training function regarding your comment.
For the trainer builder, I'm not really sure right now; I guess we could consider refactoring it in another issue, to keep this one self-contained.