This repository was archived by the owner on May 21, 2025. It is now read-only.

GraphsTuple vs. batched GraphsTuple in get_graph_padding_mask #55

@thorben-frank

Description

Hey,

Thanks for this great package. I noticed that jraph.get_graph_padding_mask returns jnp.array([False]) when applied to a non-batched GraphsTuple, i.e. the only graph is marked as padding.

Why is that? Would it be possible to check the length of jraph.GraphsTuple.n_node and return jnp.array([True]) when it has length 1? Or would that break assumptions made elsewhere in jraph? A minimal example is below.

Thanks and best,
Thorben

import jraph
import jax.numpy as jnp


def get_number_of_graphs(graph):
    """Return the number of graphs in a (batched) GraphsTuple.

    For a batched GraphsTuple the padding graph(s) are also counted.
    """
    return len(graph.n_node)


def is_batched_bool(graph):
    # A GraphsTuple holding more than one graph is treated as batched.
    return get_number_of_graphs(graph) > 1


def modified_get_graph_padding_mask(graph):
    if is_batched_bool(graph):
        return jraph.get_graph_padding_mask(graph)
    # A single graph is assumed to be a real graph, not padding.
    return jnp.array([True])

    
graph = jraph.GraphsTuple(
    nodes=dict(
        atomic_numbers=jnp.ones((10, )),
        positions=jnp.ones((10, 3)),
        z=jnp.ones((10, 3))
    ),
    edges=None,
    receivers=jnp.arange(10),
    senders=jnp.arange(10),
    globals=dict(),
    n_node=jnp.array([10]),
    n_edge=jnp.array([10])
)


print('On unbatched graph')
print('Original version: graph_mask =', jraph.get_graph_padding_mask(graph))
print('Modified version: graph_mask =', modified_get_graph_padding_mask(graph))

batched_graph_iterator = jraph.dynamically_batch([graph, graph], n_node=11, n_edge=11, n_graph=3)
batched_graph = next(batched_graph_iterator)
print('\nOn batched graph')
print('Original version: graph_mask =', jraph.get_graph_padding_mask(batched_graph))
print('Modified version: graph_mask =', modified_get_graph_padding_mask(batched_graph))
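
To make the question more concrete, here is a rough sketch of the convention I suspect is behind this behaviour. It is an assumption on my side, not jraph's actual code: the mask seems to be derived from n_node alone, treating the last graph that can hold nodes, plus any empty graphs after it, as padding. The function name sketch_graph_padding_mask is hypothetical, and plain Python lists stand in for the n_node array.

```python
def sketch_graph_padding_mask(n_node):
    """Hypothetical stand-in for the suspected padding convention.

    Assumption: a padded batch ends with one padding graph that may hold
    nodes, followed only by empty (0-node) padding graphs.
    """
    n_node = list(n_node)

    # Count the graphs at the end of the batch that have zero nodes.
    trailing_empty = 0
    for n in reversed(n_node):
        if n != 0:
            break
        trailing_empty += 1

    # One more graph before the empty ones is also treated as padding.
    n_padding = min(trailing_empty + 1, len(n_node))
    n_real = len(n_node) - n_padding

    # True marks a real graph, False marks a padding graph.
    return [True] * n_real + [False] * n_padding


# A single unpadded graph: the only graph is classified as padding.
print(sketch_graph_padding_mask([10]))
# A batch with one real graph, one padding graph with a node, one empty one.
print(sketch_graph_padding_mask([10, 1, 0]))
```

If this is indeed the convention, then the mask cannot distinguish a lone real graph from a lone padding graph without extra information, which would explain the result on the unbatched GraphsTuple.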
