
Why does the built-in sum change the dtype of ndl.Tensor? #11

@xnuohz


I hit the following error when testing SGD:

    @data.setter
    def data(self, value):
        assert isinstance(value, Tensor)
>       assert value.dtype == self.dtype, "%s %s" % (
            value.dtype,
            self.dtype,
        )
E       AssertionError: float64 float32

I then traced the error to this line in compute_gradient_of_variables:

node.grad = sum(node_to_output_grads_list[node])

Changing it to the following fixes the error:

node_grads = node_to_output_grads_list[node]
node.grad = node_grads[0] if len(node_grads) == 1 else sum(node_grads)
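Another option is to combine the list with functools.reduce and operator.add. This is a sketch and has not been tested against needle. Python's built-in sum adds every item to an integer start value of 0. reduce without an initializer starts from the first element instead, so only gradient tensors are ever added to each other. The helper name sum_grads is hypothetical:

```python
import functools
import operator


def sum_grads(node_grads):
    """Add partial gradients without Python's implicit integer start value.

    Hypothetical helper: unlike the built-in sum(), functools.reduce with no
    initializer starts from the first element, so no int 0 enters the sum.
    A single-element list is returned unchanged (no addition at all).
    """
    if not node_grads:
        raise ValueError("node_grads must be non-empty")
    return functools.reduce(operator.add, node_grads)
```

The built-in sum can also do this if you pass the first gradient as the start value: sum(node_grads[1:], node_grads[0]). That way, every addition goes through Tensor.__add__ between two tensors.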

The dtypes reported in pdb look weird to me, though I may be misreading something:

(Pdb) node_grads
[needle.Tensor(1.0)]
(Pdb) node_grads[0].dtype
dtype('float32')
(Pdb) sum(node_grads).dtype
dtype('float64')
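The pdb output fits how the built-in sum works. sum(iterable, start=0) first evaluates 0 + node_grads[0]. int.__add__ does not know the tensor type, so Python falls back to the tensor's __radd__ and passes it the Python int 0. The toy class below is hypothetical and is not needle's Tensor. It only records which method gets called:

```python
class TracedTensor:
    """Toy stand-in for a tensor that logs how + is dispatched (hypothetical)."""

    def __init__(self, value):
        self.value = value
        self.calls = []

    def __add__(self, other):
        # Called for tensor + x
        self.calls.append(("__add__", other))
        return TracedTensor(self.value + _unwrap(other))

    def __radd__(self, other):
        # Called for x + tensor when type(x).__add__ returns NotImplemented
        self.calls.append(("__radd__", other))
        return TracedTensor(_unwrap(other) + self.value)


def _unwrap(x):
    return x.value if isinstance(x, TracedTensor) else x


g = TracedTensor(1.0)
total = sum([g])  # evaluates 0 + g under the hood
print(g.calls)    # [('__radd__', 0)]
```

In needle, this probably sends the 0 through a scalar-add op. Suppose that op computes a + scalar on a 0-d float32 NumPy array. NumPy releases before 2.0 promote a float32 scalar plus a Python int to float64, and NEP 50 changed this in 2.0 to keep float32. That would explain why float64 appears only when sum is used. Treat this as a likely cause, not a confirmed one.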
