From 8bc53ac62028c9714b9b709ca88c06bda7e5b8ce Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 10 Dec 2024 09:01:01 +0100 Subject: [PATCH 1/5] draft --- rosetta/rosetta/projects/mlip/mace/README.md | 7 + .../projects/mlip/mace/requirements.txt | 3 + rosetta/rosetta/projects/mlip/mace/train.py | 355 ++++++++++++++++++ 3 files changed, 365 insertions(+) create mode 100644 rosetta/rosetta/projects/mlip/mace/README.md create mode 100644 rosetta/rosetta/projects/mlip/mace/requirements.txt create mode 100644 rosetta/rosetta/projects/mlip/mace/train.py diff --git a/rosetta/rosetta/projects/mlip/mace/README.md b/rosetta/rosetta/projects/mlip/mace/README.md new file mode 100644 index 000000000..c9711e858 --- /dev/null +++ b/rosetta/rosetta/projects/mlip/mace/README.md @@ -0,0 +1,7 @@ +# MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields + +## Dependencies + +```bash +pip install -r requirements.txt +``` \ No newline at end of file diff --git a/rosetta/rosetta/projects/mlip/mace/requirements.txt b/rosetta/rosetta/projects/mlip/mace/requirements.txt new file mode 100644 index 000000000..a955e6c89 --- /dev/null +++ b/rosetta/rosetta/projects/mlip/mace/requirements.txt @@ -0,0 +1,3 @@ +cuequivariance-jax +flax +optax \ No newline at end of file diff --git a/rosetta/rosetta/projects/mlip/mace/train.py b/rosetta/rosetta/projects/mlip/mace/train.py new file mode 100644 index 000000000..a8c697492 --- /dev/null +++ b/rosetta/rosetta/projects/mlip/mace/train.py @@ -0,0 +1,355 @@ +from typing import Callable + +import cuequivariance as cue +import cuequivariance_jax as cuex +import flax +import flax.linen +import jax +import jax.numpy as jnp +import numpy as np +from cuequivariance.experimental.mace import symmetric_contraction +from cuequivariance_jax.experimental.utils import MultiLayerPerceptron, gather +import optax + + +class MACELayer(flax.linen.Module): + first: bool + last: bool + num_species: int + num_features: int # typically 128 + interaction_irreps: cue.Irreps # typically 0e+1o+2e+3o + hidden_irreps: cue.Irreps # typically 0e+1o + activation: Callable # typically silu + epsilon: float # typically 1/avg_num_neighbors + max_ell: int # typically 3 + correlation: int # typically 3 + output_irreps: cue.Irreps # typically 1x0e + readout_mlp_irreps: cue.Irreps # typically 16x0e + replicate_original_mace_sc: bool = True + skip_connection_first_layer: bool = False + + @flax.linen.compact + def __call__( + self, + vectors: cuex.IrrepsArray, # [n_edges, 3] + node_feats: cuex.IrrepsArray, # [n_nodes, irreps] + node_species: jax.Array, # [n_nodes] int between 0 and num_species-1 + radial_embeddings: jax.Array, # [n_edges, radial_embedding_dim] + senders: jax.Array, # [n_edges] + receivers: jax.Array, # [n_edges] + ): + dtype = node_feats.dtype + + def lin(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str): + e = cue.descriptors.linear(input.irreps(), irreps) + w = self.param(name, jax.random.normal, (e.inputs[0].irreps.dim,), dtype) + return cuex.equivariant_tensor_product(e, w, input, precision="HIGH") + + def linZ(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str): + e = cue.descriptors.linear(input.irreps(), irreps) + w = self.param( + name, + jax.random.normal, + (self.num_species, e.inputs[0].irreps.dim), + dtype, + ) + # Dividing by num_species for consistency with the 1-hot implementation + return cuex.equivariant_tensor_product( + e, w[node_species], input, precision="HIGH" + ) / jnp.sqrt(self.num_species) + + if True: + i = jnp.argsort(receivers) + senders = senders[i] + receivers = receivers[i] + vectors = vectors[i] + radial_embeddings = radial_embeddings[i] + indices_are_sorted = True + del i + else: + indices_are_sorted = False + + if self.last: + hidden_out = self.hidden_irreps.filter(keep=self.output_irreps) + else: + hidden_out = self.hidden_irreps + + self_connection = None + if not self.first or self.skip_connection_first_layer: + self_connection = linZ( + self.num_features * hidden_out, node_feats, "linZ_skip_tp" + ) + + node_feats = lin(node_feats.irreps(), node_feats, "linear_up") + + messages = node_feats[senders] + sph = cuex.spherical_harmonics(range(self.max_ell + 1), vectors) + e = cue.descriptors.channelwise_tensor_product( + messages.irreps(), sph.irreps(), self.interaction_irreps + ) + e = e.squeeze_modes().flatten_coefficient_modes() + + mix = MultiLayerPerceptron( + [64, 64, 64, e.inputs[0].irreps.dim], + self.activation, + output_activation=False, + with_bias=False, # biases would break smoothness + )( + # embeddings are smooth at cutoff + radial_embeddings + ) + + messages = cuex.equivariant_tensor_product( + e, + mix, + messages, + sph, + algorithm="compact_stacked" if e.all_same_segment_shape() else "sliced", + ) + + node_feats = gather( + receivers, + messages, + node_feats.shape[0], + indices_are_sorted=indices_are_sorted, + ) + node_feats *= self.epsilon + + node_feats = lin( + self.num_features * self.interaction_irreps, node_feats, "linear_down" + ) + + # This is only used in the first layer if it has no skip connection + if self.first and not self.skip_connection_first_layer: + # Selector TensorProduct + node_feats = linZ( + self.num_features * self.interaction_irreps, + node_feats, + "linZ_skip_tp_first", + ) + + e, projection = symmetric_contraction( + node_feats.irreps(), + self.num_features * hidden_out, + range(1, self.correlation + 1), + ) + # The largest STP has 1076 paths. + # Maybe set CUEQUIVARIANCE_MAX_PATH_UNROLL to something larger. + # os.environ["CUEQUIVARIANCE_MAX_PATH_UNROLL"] = "2000" + + n = projection.shape[0 if self.replicate_original_mace_sc else 1] + w = self.param( + "symmetric_contraction", + jax.random.normal, + (self.num_species, n, self.num_features), + dtype, + ) + if self.replicate_original_mace_sc: + w = jnp.einsum("zau,ab->zbu", w, projection) + w = jnp.reshape(w, (self.num_species, -1)) + + node_feats = cuex.equivariant_tensor_product( + e, w[node_species], node_feats, algorithm="sliced" + ) + + node_feats = lin(self.num_features * hidden_out, node_feats, "linear_post_sc") + + if self_connection is not None: + node_feats = ( + node_feats + self_connection + ) # [n_nodes, feature * hidden_irreps] + + node_outputs = node_feats + if self.last: # Non linear readout for last layer + assert self.readout_mlp_irreps.is_scalar() + assert self.output_irreps.is_scalar() + node_outputs = cuex.scalar_activation( + lin(self.readout_mlp_irreps, node_outputs, "linear_mlp_readout"), + self.activation, + ) + node_outputs = lin(self.output_irreps, node_outputs, "linear_readout") + + return node_outputs, node_feats + + +# Just Bessel for now +class radial_basis(flax.linen.Module): + r_max: float + num_radial_basis: int + num_polynomial_cutoff: int = 5 + + def envelope(self, x: jax.Array) -> jax.Array: + p = float(self.num_polynomial_cutoff) + xs = x / self.r_max + xp = jnp.power(xs, self.num_polynomial_cutoff) + return ( + 1.0 + - 0.5 * (p + 1.0) * (p + 2.0) * xp + + p * (p + 2.0) * xp * xs + - 0.5 * p * (p + 1.0) * xp * xs * xs + ) + + def bessel(self, x: jax.Array) -> jax.Array: + bessel_weights = ( + jnp.pi / self.r_max * jnp.arange(1, self.num_radial_basis + 0.5) + ) + prefactor = jnp.sqrt(2.0 / self.r_max) + numerator = jnp.sin(bessel_weights * x) + return prefactor * (numerator / x) + + @flax.linen.compact + def __call__(self, edge: jax.Array) -> jax.Array: + assert edge.ndim == 0 + cutoff = jnp.where(edge < self.r_max, self.envelope(edge), 0.0) + radial = self.bessel(edge) + return radial * cutoff + + +class MACEModel(flax.linen.Module): + num_layers: int + num_features: int + num_species: int + max_ell: int + correlation: int + num_radial_basis: int + interaction_irreps: cue.Irreps + hidden_irreps: cue.Irreps + offsets: np.ndarray + cutoff: float + + @flax.linen.compact + def __call__( + self, + vecs: jax.Array, # [num_edges, 3] + species: jax.Array, # [num_nodes] int between 0 and num_species-1 + senders: jax.Array, # [num_edges] + receivers: jax.Array, # [num_edges] + graph_index: jax.Array, # [num_nodes] + num_graphs: int, + ) -> tuple[jax.Array, jax.Array]: + def model(vecs, node_species, senders, receivers): + with cue.assume(cue.O3, cue.ir_mul): + node_attrs = cuex.as_irreps_array( + jax.nn.one_hot(node_species, self.num_species) + ) + e = cue.descriptors.linear( + node_attrs.irreps(), cue.Irreps(f"{self.num_features}x0e") + ) + w = self.param( + "linear_embedding", + jax.random.normal, + (e.inputs[0].irreps.dim,), + vecs.dtype, + ) + node_feats = cuex.equivariant_tensor_product( + e, w, node_attrs, precision="HIGH" + ) + radial_embeddings = jax.vmap( + radial_basis(self.cutoff, self.num_radial_basis) + )(jnp.linalg.norm(vecs, axis=1)) + vecs = cuex.IrrepsArray("1o", vecs) + + Es = 0 + for i in range(self.num_layers): + first = i == 0 + last = i == self.num_layers - 1 + output, node_feats = MACELayer( + first=first, + last=last, + num_species=self.num_species, + num_features=self.num_features, + interaction_irreps=self.interaction_irreps, + hidden_irreps=self.hidden_irreps, + activation=jax.nn.silu, + epsilon=0.05, + max_ell=self.max_ell, + correlation=self.correlation, + output_irreps=cue.Irreps(cue.O3, "1x0e"), + readout_mlp_irreps=cue.Irreps(cue.O3, "16x0e"), + )( + vecs, + node_feats, + node_species, + radial_embeddings, + senders, + receivers, + ) + Es += jnp.squeeze(output.array, 1) + return jnp.sum(Es), Es + + Fterms, Ei = jax.grad(model, has_aux=True)(vecs, species, senders, receivers) + offsets = jnp.asarray(self.offsets, dtype=Ei.dtype) + Ei = Ei + offsets[species] + + E = jnp.zeros((num_graphs,)).at[graph_index].add(Ei) + + nats = jnp.shape(species)[0] + F = jnp.zeros((nats, 3)).at[senders].add(Fterms).at[receivers].add(-Fterms) + + return E, F + + +def main(): + num_species = 50 + + model = MACEModel( + offsets=np.zeros(num_species), + cutoff=5.0, + num_layers=2, + num_features=192, + num_species=num_species, + interaction_irreps=cue.Irreps(cue.O3, "0e+1o+2e+3o"), + hidden_irreps=cue.Irreps(cue.O3, "0e+1o+2e"), + max_ell=3, + correlation=3, + num_radial_basis=8, + ) + + num_graphs = 128 + num_atoms = 4_000 + num_edges = 70_000 + + vecs = jax.random.normal(jax.random.key(0), (num_edges, 3)) + species = jax.random.randint(jax.random.key(0), (num_atoms,), 0, num_species) + senders, receivers = jax.random.randint( + jax.random.key(0), (2, num_edges), 0, num_atoms + ) + graph_index = jax.random.randint(jax.random.key(0), (num_atoms,), 0, num_graphs) + graph_index = jnp.sort(graph_index) + + w = jax.jit(model.init, static_argnums=(6,))( + jax.random.key(0), vecs, species, senders, receivers, graph_index, num_graphs + ) + opt = optax.adam(1e-2) + opt_state = opt.init(w) + + @jax.jit + def step( + w, opt_state, vecs, species, senders, receivers, graph_index, target_E, target_F + ): + def loss_fn(w): + E, F = model.apply( + w, vecs, species, senders, receivers, graph_index, num_graphs + ) + return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2) + + grad = jax.grad(loss_fn)(w) + updates, opt_state = opt.update(grad, opt_state) + w = optax.apply_updates(w, updates) + return w, opt_state + + w, opt_state = step( + w, + opt_state, + vecs, + species, + senders, + receivers, + graph_index, + jax.random.normal(jax.random.key(0), (num_graphs,)), + jax.random.normal(jax.random.key(0), (num_atoms, 3)), + ) + + +if __name__ == "__main__": + main() From 489102d6c675dfa88520566bdff4ed5d75f112e8 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 10 Dec 2024 10:04:33 +0100 Subject: [PATCH 2/5] clean --- rosetta/rosetta/projects/mlip/mace/train.py | 120 +++++++++++--------- 1 file changed, 67 insertions(+), 53 deletions(-) diff --git a/rosetta/rosetta/projects/mlip/mace/train.py b/rosetta/rosetta/projects/mlip/mace/train.py index a8c697492..e3f3382dd 100644 --- a/rosetta/rosetta/projects/mlip/mace/train.py +++ b/rosetta/rosetta/projects/mlip/mace/train.py @@ -1,3 +1,4 @@ +import time from typing import Callable import cuequivariance as cue @@ -7,9 +8,9 @@ import jax import jax.numpy as jnp import numpy as np +import optax from cuequivariance.experimental.mace import symmetric_contraction from cuequivariance_jax.experimental.utils import MultiLayerPerceptron, gather -import optax class MACELayer(flax.linen.Module): @@ -31,12 +32,12 @@ class MACELayer(flax.linen.Module): @flax.linen.compact def __call__( self, - vectors: cuex.IrrepsArray, # [n_edges, 3] - node_feats: cuex.IrrepsArray, # [n_nodes, irreps] - node_species: jax.Array, # [n_nodes] int between 0 and num_species-1 - radial_embeddings: jax.Array, # [n_edges, radial_embedding_dim] - senders: jax.Array, # [n_edges] - receivers: jax.Array, # [n_edges] + vectors: cuex.IrrepsArray, # [num_edges, 3] + node_feats: cuex.IrrepsArray, # [num_nodes, irreps] + node_species: jax.Array, # [num_nodes] int between 0 and num_species-1 + radial_embeddings: jax.Array, # [num_edges, radial_embedding_dim] + senders: jax.Array, # [num_edges] + receivers: jax.Array, # [num_edges] ): dtype = node_feats.dtype @@ -93,11 +94,8 @@ def linZ(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str): [64, 64, 64, e.inputs[0].irreps.dim], self.activation, output_activation=False, - with_bias=False, # biases would break smoothness - )( - # embeddings are smooth at cutoff - radial_embeddings - ) + with_bias=False, + )(radial_embeddings) messages = cuex.equivariant_tensor_product( e, @@ -133,10 +131,6 @@ def linZ(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str): self.num_features * hidden_out, range(1, self.correlation + 1), ) - # The largest STP has 1076 paths. - # Maybe set CUEQUIVARIANCE_MAX_PATH_UNROLL to something larger. - # os.environ["CUEQUIVARIANCE_MAX_PATH_UNROLL"] = "2000" - n = projection.shape[0 if self.replicate_original_mace_sc else 1] w = self.param( "symmetric_contraction", @@ -149,15 +143,16 @@ def linZ(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str): w = jnp.reshape(w, (self.num_species, -1)) node_feats = cuex.equivariant_tensor_product( - e, w[node_species], node_feats, algorithm="sliced" + e, + w[node_species], + node_feats, + algorithm="compact_stacked" if e.all_same_segment_shape() else "sliced", ) node_feats = lin(self.num_features * hidden_out, node_feats, "linear_post_sc") if self_connection is not None: - node_feats = ( - node_feats + self_connection - ) # [n_nodes, feature * hidden_irreps] + node_feats = node_feats + self_connection node_outputs = node_feats if self.last: # Non linear readout for last layer @@ -172,7 +167,6 @@ def linZ(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str): return node_outputs, node_feats -# Just Bessel for now class radial_basis(flax.linen.Module): r_max: float num_radial_basis: int @@ -216,6 +210,7 @@ class MACEModel(flax.linen.Module): hidden_irreps: cue.Irreps offsets: np.ndarray cutoff: float + epsilon: float @flax.linen.compact def __call__( @@ -229,21 +224,16 @@ def __call__( ) -> tuple[jax.Array, jax.Array]: def model(vecs, node_species, senders, receivers): with cue.assume(cue.O3, cue.ir_mul): - node_attrs = cuex.as_irreps_array( - jax.nn.one_hot(node_species, self.num_species) - ) - e = cue.descriptors.linear( - node_attrs.irreps(), cue.Irreps(f"{self.num_features}x0e") - ) w = self.param( "linear_embedding", jax.random.normal, - (e.inputs[0].irreps.dim,), + (self.num_species, self.num_features), vecs.dtype, ) - node_feats = cuex.equivariant_tensor_product( - e, w, node_attrs, precision="HIGH" + node_feats = cuex.as_irreps_array( + w[node_species] / jnp.sqrt(self.num_species) ) + radial_embeddings = jax.vmap( radial_basis(self.cutoff, self.num_radial_basis) )(jnp.linalg.norm(vecs, axis=1)) @@ -261,7 +251,7 @@ def model(vecs, node_species, senders, receivers): interaction_irreps=self.interaction_irreps, hidden_irreps=self.hidden_irreps, activation=jax.nn.silu, - epsilon=0.05, + epsilon=self.epsilon, max_ell=self.max_ell, correlation=self.correlation, output_irreps=cue.Irreps(cue.O3, "1x0e"), @@ -290,25 +280,29 @@ def model(vecs, node_species, senders, receivers): def main(): + # Dataset specifications num_species = 50 + num_graphs = 100 + num_atoms = 4_000 + num_edges = 70_000 + avg_num_neighbors = 20 + # Model specifications model = MACEModel( - offsets=np.zeros(num_species), - cutoff=5.0, num_layers=2, - num_features=192, + num_features=128, num_species=num_species, - interaction_irreps=cue.Irreps(cue.O3, "0e+1o+2e+3o"), - hidden_irreps=cue.Irreps(cue.O3, "0e+1o+2e"), max_ell=3, correlation=3, num_radial_basis=8, + interaction_irreps=cue.Irreps(cue.O3, "0e+1o+2e+3o"), + hidden_irreps=cue.Irreps(cue.O3, "0e+1o"), + offsets=np.zeros(num_species), + cutoff=5.0, + epsilon=1 / avg_num_neighbors, ) - num_graphs = 128 - num_atoms = 4_000 - num_edges = 70_000 - + # Dummy data vecs = jax.random.normal(jax.random.key(0), (num_edges, 3)) species = jax.random.randint(jax.random.key(0), (num_atoms,), 0, num_species) senders, receivers = jax.random.randint( @@ -316,20 +310,32 @@ def main(): ) graph_index = jax.random.randint(jax.random.key(0), (num_atoms,), 0, num_graphs) graph_index = jnp.sort(graph_index) + target_E = jax.random.normal(jax.random.key(0), (num_graphs,)) + target_F = jax.random.normal(jax.random.key(0), (num_atoms, 3)) + # Initialization w = jax.jit(model.init, static_argnums=(6,))( jax.random.key(0), vecs, species, senders, receivers, graph_index, num_graphs ) opt = optax.adam(1e-2) opt_state = opt.init(w) + # Training @jax.jit def step( - w, opt_state, vecs, species, senders, receivers, graph_index, target_E, target_F + w, + opt_state, + vecs: jax.Array, + species: jax.Array, + senders: jax.Array, + receivers: jax.Array, + graph_index: jax.Array, + target_E: jax.Array, + target_F: jax.Array, ): def loss_fn(w): E, F = model.apply( - w, vecs, species, senders, receivers, graph_index, num_graphs + w, vecs, species, senders, receivers, graph_index, target_E.shape[0] ) return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2) @@ -338,17 +344,25 @@ def loss_fn(w): w = optax.apply_updates(w, updates) return w, opt_state - w, opt_state = step( - w, - opt_state, - vecs, - species, - senders, - receivers, - graph_index, - jax.random.normal(jax.random.key(0), (num_graphs,)), - jax.random.normal(jax.random.key(0), (num_atoms, 3)), - ) + for i in range(100): + t0 = time.perf_counter() + w, opt_state = step( + w, + opt_state, + vecs, + species, + senders, + receivers, + graph_index, + target_E, + target_F, + ) + + jax.block_until_ready(w) + t1 = time.perf_counter() + print(f"{i:04}: {t1-t0:.3}s") + + break if __name__ == "__main__": From 43eb48dd6da9aac0544ae2f1f4a2364d5d5488a6 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 10 Dec 2024 10:08:52 +0100 Subject: [PATCH 3/5] readme --- rosetta/rosetta/projects/mlip/mace/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/rosetta/rosetta/projects/mlip/mace/README.md b/rosetta/rosetta/projects/mlip/mace/README.md index c9711e858..f4fe6d351 100644 --- a/rosetta/rosetta/projects/mlip/mace/README.md +++ b/rosetta/rosetta/projects/mlip/mace/README.md @@ -4,4 +4,10 @@ ```bash pip install -r requirements.txt -``` \ No newline at end of file +``` + +## Run + +```bash +python train.py +``` From 504bc2cd3672c010c005aa5791e6a1599a24fe59 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 10 Dec 2024 14:21:28 +0100 Subject: [PATCH 4/5] clean --- rosetta/rosetta/projects/mlip/mace/train.py | 25 ++++++++++----------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/rosetta/rosetta/projects/mlip/mace/train.py b/rosetta/rosetta/projects/mlip/mace/train.py index e3f3382dd..c888f7ff3 100644 --- a/rosetta/rosetta/projects/mlip/mace/train.py +++ b/rosetta/rosetta/projects/mlip/mace/train.py @@ -222,7 +222,7 @@ def __call__( graph_index: jax.Array, # [num_nodes] num_graphs: int, ) -> tuple[jax.Array, jax.Array]: - def model(vecs, node_species, senders, receivers): + def model(vecs): with cue.assume(cue.O3, cue.ir_mul): w = self.param( "linear_embedding", @@ -231,7 +231,7 @@ def model(vecs, node_species, senders, receivers): vecs.dtype, ) node_feats = cuex.as_irreps_array( - w[node_species] / jnp.sqrt(self.num_species) + w[species] / jnp.sqrt(self.num_species) ) radial_embeddings = jax.vmap( @@ -256,25 +256,24 @@ def model(vecs, node_species, senders, receivers): correlation=self.correlation, output_irreps=cue.Irreps(cue.O3, "1x0e"), readout_mlp_irreps=cue.Irreps(cue.O3, "16x0e"), - )( - vecs, - node_feats, - node_species, - radial_embeddings, - senders, - receivers, - ) + )(vecs, node_feats, species, radial_embeddings, senders, receivers) Es += jnp.squeeze(output.array, 1) return jnp.sum(Es), Es - Fterms, Ei = jax.grad(model, has_aux=True)(vecs, species, senders, receivers) + Fterms, Ei = jax.grad(model, has_aux=True)(vecs) offsets = jnp.asarray(self.offsets, dtype=Ei.dtype) Ei = Ei + offsets[species] - E = jnp.zeros((num_graphs,)).at[graph_index].add(Ei) + E = jnp.zeros((num_graphs,), Ei.dtype).at[graph_index].add(Ei) nats = jnp.shape(species)[0] - F = jnp.zeros((nats, 3)).at[senders].add(Fterms).at[receivers].add(-Fterms) + F = ( + jnp.zeros((nats, 3), Ei.dtype) + .at[senders] + .add(Fterms) + .at[receivers] + .add(-Fterms) + ) return E, F From 48e0a902a3d0808c94b351392dff9851fc4fd4e0 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 11 Dec 2024 12:31:23 -0800 Subject: [PATCH 5/5] fix version --- rosetta/rosetta/projects/mlip/mace/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rosetta/rosetta/projects/mlip/mace/requirements.txt b/rosetta/rosetta/projects/mlip/mace/requirements.txt index a955e6c89..8465d3b52 100644 --- a/rosetta/rosetta/projects/mlip/mace/requirements.txt +++ b/rosetta/rosetta/projects/mlip/mace/requirements.txt @@ -1,3 +1,3 @@ -cuequivariance-jax +cuequivariance-jax==0.1.0 flax optax \ No newline at end of file