Fix backward implementation and remove setting grad in forward() #945

Open
@H-Huang

Description

"""
fwd_outputs all forced to have 'requires_grad=True' -- why? what's our design here? freqs_cis could be passed from stage0 to stage1 but is an input value from dataloader and should not require grads.

backward isn't implemented correctly afaiu. see rewrite in whc/pp branch, which fixes (a) .grad() won't set .grad on W's but .backward will; (b) funny issues with requires-gradness on inputs, which disappeared after i simplified
"""

The backward implementation is incorrect: it does not populate .grad on the parameters, so their gradients are never updated. Furthermore, forward() should not explicitly set requires_grad=True on its inputs.
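
To illustrate both points, here is a minimal sketch. It assumes the `.grad()` in the note above refers to `torch.autograd.grad`. `run_stage`, `weight`, `activation`, and the real-valued `freqs_cis` tensor are hypothetical stand-ins for one pipeline stage, not the project's actual code.

```python
import torch


def run_stage():
    """Hypothetical stand-in for one pipeline stage (not the project's code)."""
    torch.manual_seed(0)
    weight = torch.nn.Parameter(torch.randn(4, 4))
    # Activation received from the previous stage; it already requires grad.
    activation = torch.randn(2, 4, requires_grad=True)
    # Data-loader value passed along between stages; requires_grad stays False.
    freqs_cis = torch.randn(2, 4)
    output = (activation * freqs_cis) @ weight
    return weight, activation, freqs_cis, output


# (a) torch.autograd.grad returns gradients only for the tensors listed in
# `inputs` and does not accumulate anything into `.grad`.
weight, activation, freqs_cis, output = run_stage()
(input_grad,) = torch.autograd.grad(
    output, activation, grad_outputs=torch.ones_like(output)
)
weight_grad_after_grad = weight.grad  # still None

# (b) torch.autograd.backward accumulates into `.grad` of every leaf tensor
# that requires grad, including the parameter.
weight, activation, freqs_cis, output = run_stage()
torch.autograd.backward(output, grad_tensors=torch.ones_like(output))
weight_grad_after_backward = weight.grad

# Send gradients back only for inputs that required grad in the first place,
# instead of forcing requires_grad=True on every input.
stage_inputs = (activation, freqs_cis)
grads_for_prev_stage = tuple(
    t.grad if t.requires_grad else None for t in stage_inputs
)
```

In this sketch the parameter only receives a gradient in case (b), and `freqs_cis` yields `None` as its upstream gradient without ever being marked as requiring grad.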

Labels: bug
