Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphcast/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
particular, no noise is added to the predictions that are fed back
auto-regressively. Defaults to not adding noise.
gradient_checkpointing: If True, gradient checkpointing will be
used at each step of the computation to save on memory. Roughtly this
used at each step of the computation to save on memory. Roughly this
should make the backwards pass two times more expensive, and the time
per step counting the forward pass, should only increase by about 50%.
Note this parameter will be ignored with a warning if the scan sequence
Expand Down
2 changes: 1 addition & 1 deletion graphcast/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_day_progress(
longitude: 1D array of longitudes at which day progress is computed.

Returns:
2D array of day progress values normalized to be in the [0, 1) inverval
2D array of day progress values normalized to be in the [0, 1) interval
for each time point at each longitude.
"""

Expand Down
12 changes: 6 additions & 6 deletions graphcast/graphcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Reference:
https://arxiv.org/pdf/2202.07575.pdf

It assumes data across time and level is stacked, and operates only operates in
It assumes data across time and level is stacked, and only operates in
a 2D mesh over latitudes and longitudes.
"""

Expand Down Expand Up @@ -278,7 +278,7 @@ def __init__(self, model_config: ModelConfig, task_config: TaskConfig):

# Processor, which performs message passing on the multi-mesh.
self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
embed_nodes=False, # Node features already embdded by previous layers.
embed_nodes=False, # Node features already embedded by previous layers.
embed_edges=True, # Embed raw features of the multi-mesh edges.
node_latent_size=dict(mesh_nodes=model_config.latent_size),
edge_latent_size=dict(mesh=model_config.latent_size),
Expand All @@ -302,9 +302,9 @@ def __init__(self, model_config: ModelConfig, task_config: TaskConfig):
# Decoder, which moves data from the mesh back into the grid with a single
# message passing step.
self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
# Require a specific node dimensionaly for the grid node outputs.
# Require a specific node dimensionality for the grid node outputs.
node_output_size=dict(grid_nodes=num_outputs),
embed_nodes=False, # Node features already embdded by previous layers.
embed_nodes=False, # Node features already embedded by previous layers.
embed_edges=True, # Embed raw features of the mesh2grid edges.
edge_latent_size=dict(mesh2grid=model_config.latent_size),
node_latent_size=dict(
Expand Down Expand Up @@ -376,12 +376,12 @@ def __call__(self,
# [num_mesh_nodes, batch, latent_size]
updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)

# Transfer data frome the mesh to the grid.
# Transfer data from the mesh to the grid.
# [num_grid_nodes, batch, output_size]
output_grid_nodes = self._run_mesh2grid_gnn(
updated_latent_mesh_nodes, latent_grid_nodes)

# Conver output flat vectors for the grid nodes to the format of the output.
# Convert output flat vectors for the grid nodes to the format of the output.
# [num_grid_nodes, batch, output_size] ->
# xarray (batch, one time step, lat, lon, level, multiple vars)
return self._grid_node_outputs_to_prediction(
Expand Down
10 changes: 5 additions & 5 deletions graphcast/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def get_graph_spatial_features(
add_node_positions: Add unit norm absolute positions.
add_node_latitude: Add a feature for latitude (cos(90 - lat))
Note even if this is set to False, the model may be able to infer the
longitude from relative features, unless
latitude from relative features, unless
`relative_latitude_local_coordinates` is also True, or if there is any
bias on the relative edge sizes for different longitudes.
bias on the relative edge sizes for different latitudes.
add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
Note even if this is set to False, the model may be able to infer the
longitude from relative features, unless
Expand Down Expand Up @@ -246,7 +246,7 @@ def get_relative_position_in_receiver_local_coordinates(

The relative positions will be computed in a rotated space for a local
coordinate system as defined by the receiver. The relative positions are
simply obtained by subtracting sender position minues receiver position in
simply obtained by subtracting sender position minus receiver position in
that local coordinate system after the rotation in R^3.

Args:
Expand Down Expand Up @@ -386,7 +386,7 @@ def get_rotation_matrices_to_local_coordinates(
# We want to apply the polar rotation only, but we don't know the rotation
# axis to apply a polar rotation. The simplest way to achieve this is to
# first rotate all the way to longitude 0, then apply the polar rotation
# arond the y axis, and then rotate back to the original longitude.
# around the y axis, and then rotate back to the original longitude.
return transform_.Rotation.from_euler(
"zyz", np_.stack(
[azimuthal_rotation, polar_rotation, -azimuthal_rotation]
Expand Down Expand Up @@ -562,7 +562,7 @@ def get_bipartite_relative_position_in_receiver_local_coordinates(

The relative positions will be computed in a rotated space for a local
coordinate system as defined by the receiver. The relative positions are
simply obtained by subtracting sender position minues receiver position in
simply obtained by subtracting sender position minus receiver position in
that local coordinate system after the rotation in R^3.

Args:
Expand Down