{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Equivariant Neural Networks\n",
"\n",
"The previous chapter {doc}`data` discussed data transformation and network architecture decisions that can be made to make a neural network equivariant with respect to translation, rotation, and permutations. However, those ideas limit the expressibility of our networks and are constructed ad-hoc. Now we will take a more systematic approach to defining equivariances and prove that there is only one layer type that can preserve a given equivariance. The result of this section will be layers that can be equivariant with respect to any transform, even for more esoteric cases like points on a sphere or mirror operations. To achieve this, we will need tools from group theory, representation theory, harmonic analysis, and deep learning. Equivariant neural networks are part of a broader topic of **geometric deep learning**, which is learning with data that has some underlying geometric relationships. Geometric deep learning is thus a broad-topic and includes the \"5Gs\": grids, groups, graphs, geodesics, and gauges. However, you'll see papers with that nomenclature concentrated on point clouds (gauges), whereas graph learning and grids are usually called graph neural networks and convolutions neural networks respectively. \n",
"\n",
"```{admonition} Audience & Objectives\n",
"This chapter builds on {doc}`data` and a strong background in math. Although not required, a background on Hilbert spaces, group theory, representation theory, Fourier series, and Lie algebra will help. After completing this chapter, you should be able to \n",
"\n",
" * Derive and understand the mathematical foundations of equivariant neural networks\n",
" * Reason about equivariances of neural networks\n",
" * Know common symmetry groups\n",
" * Implement G-equivariant neural network layers \n",
" * Understand the shape, purpose, and derivation of irreducible function representations\n",
" * Know how weight-constraints can be used as an alternative\n",
"```\n",
"\n",
"```{danger}\n",
"This chapter teaches how to add equivariance for point clouds, but not permutations. To work with multiple molecules of different size/shape, we need to combine ideas from this chapter with permtuation equivariance from the {doc}`gnn` chapter. That combination is explored in {doc}`molnets`. If you're always working with atoms/points in the same order, you can ignore permutation equivariance. \n",
"```\n",
"\n",
"\n",
"## Do you need equivariance?\n",
"\n",
"\n",
"```{margin}\n",
"I'm being a bit unfair, these papers have some slightly different application areas (lie vs compact vs finite groups) and differ mostly in their nonlinearity.\n",
"```\n",
"\n",
"Before we get too far, let me first try to talk you out of equivariant networks. The math required is advanced, especially because the theory of these is still in flux. There are five papers in the last few years that propose a general theory for equivariant networks and they each take a slightly different approach {cite}`finzi2020generalizing,cohen2019general,kondor2018generalization,lang2020wigner,finzi2021emlp`. It is also easy to make mistakes in implementations due to the complexity of the methods. You must also do some of the implementation details yourself, because general efficient implementations of groups is still not solved (although we are [getting close now for specifically SE(3)](https://developer.nvidia.com/blog/accelerating-se3-transformers-training-using-an-nvidia-open-source-model-implementation/)). You will also find that equivariant networks are not in general state of the art on point clouds -- although that is starting to change with recent benchmarks set in point cloud segmentation {cite}`wang2020equivariant`, molecular force field prediction {cite}`batzner2021se3equivariant`, molecular energy predictions {cite}`klicpera2020directional`, and 3D molecular structure generation {cite}`satorras2021en`.\n",
"\n",
"Alternatives to equivariant networks are to just invariant features as discussed in {doc}`data`. Another approach is training and testing augmentation. Both are powerful methods for many domains and are easy to implement {cite}`shorten2019survey`. You can find details in the {doc}`data` chapter. However, augmentation does not work for locally compact symmetry groups (e.g., SO(3)) --- so you cannot use them for rotationally equivariant data. You can do data transformations like discussed in {doc}`data` to avoid equivariance and only work with invariance.\n",
"\n",
"So why would you study this chapter? I think these ideas are important and incorporating the equivariant layers into other network architectures can dramatically reduce parameter numbers and increase training efficiency."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the above to launch this page as an interactive Google Colab. See details below on installing packages.\n",
"\n",
"````{tip} My title\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you find install problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/master/package/requirements.txt)\n",
"\n",
"````"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Outline\n",
"\n",
"We have to lay some mathematical foundations before we can grasp the equations and details of equivariant networks. We'll start with a brief overview of group theory so we can define the principle of equivariance generally. Then we'll show how any equivariance can be enforced in a neural network via a generalization of convolutions. Then we'll visit representation theory to see how to encode groups into matrices. Then we'll see how these convolutions can be more easily represented using the generalization of Fourier transforms. Finally, we'll examine some implementations. Throughout this chapter we'll see three examples that capture some of the different settings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Group Theory\n",
"\n",
"A modern treatment of group theory can be found in {cite}`zee2016`. You can watch a short fun primer video on group theory from [3Blue1Brown here](https://www.youtube.com/watch?v=mH0oCDa74tE).\n",
"\n",
"A group is a general object in mathematics. A group is a set of elements that can be combined in a binary operation whose output is another member of the group. The most common example are the integers. If you combine two integers in a binary operation, the output is another integer. Of course, it depends on the operation ($1 \\div 2$ does not give an integer), so specifically consider addition. Integers are not the example we care about though. We're interested in groups of **transformations** that move points in a space. Operations like rotation, scaling, mirroring, or translating of single points. As you read about groups here, remember that the elements of the groups are *not* numbers or points. The group elements are transformations that act on points in the space. Notice I'm being a bit nebulous on what the space is for now. Let's first define a group: \n",
"\n",
"```{admonition} Group Definition\n",
"A group $G$ is a set of elements (e.g., $\\{a, b, c, i, e\\}$) equipped with a binary operation ($a\\cdot{}b = c$) whose output is another group element and the following conditions are satisfied:\n",
"\n",
"1. **Closure** The output of the binary operation is always a member of the group\n",
"2. **Associativity** $(a\\cdot{}b)\\cdot{}c = a\\cdot{}(b\\cdot{}c)$\n",
"3. **Identity** There is a single identity element $e$ such that $ex = x \\forall x \\in G$\n",
"4. **Inverse** There exists exactly one inverse element $i$ for each $x$ such that $xi = e$\n",
"```\n",
"\n",
"This is quite a bit of nice structure. We always have an inverse available. Applying the binary operations never accidentally leaves our group. One important property that is missing from this list is **commutativity**. In general, a group is not commutative so that $a\\cdot{}b \\neq b\\cdot{}a$. If the group does have this extra property, we call the group **abelian**. Another detail is how big the set is. It can indeed be infinite, which is why the integers or all possible rotations of rotations of points on a sphere can be represented as a group. One notational convenience we'll make is that the binary operation \"$\\cdot{}$\" will just be referred to as \"dot\" or sometimes multiplication if I get sloppy. The number of elements in a group $|G|$ is known as the **order**.\n",
"\n",
"```{margin}\n",
"If you multiply two transforms $a\\cdot{}b$, we always apply $b$ first and then $a$. This is important to remember for non-commutative groups (non-abelian). \n",
"```\n",
"\n",
"The point of introducing the groups is so that they can transform elements of our space. This is done through a **group action**\n",
"\n",
"```{admonition} Group Action\n",
"A group action $\\pi(g, v)$ is a mapping from a group $G$ and a space $\\mathcal{X}$ to the space $\\mathcal{X}$:\n",
"\n",
"\\begin{equation}\n",
"\\pi: G\\times \\mathcal{X}\\rightarrow \\mathcal{X}\n",
"\\end{equation}\n",
"```\n",
"\n",
"```{margin} function arrow\n",
"$G\\times \\mathcal{X}$ means there are two input arguments, one from group $G$ and one from space $\\mathcal{X}$. $\\rightarrow \\mathcal{X}$ shows our function outputs a value in the space $\\mathcal{X}$.\n",
"```\n",
"\n",
"So a group action takes in two arguments (binary): a group element and a point in a space $\\mathcal{X}$ and transforms the point to a new one: $\\pi(g, x_0) = x_1$. This is just a more systematic way of saying it transforms a point. The group action is neither unique to the space nor group. Often we'll omit the function notation for the group action and just write $gx = x'$."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's introduce our three example groups that we'll refer to throughout this chapter.\n",
"\n",
"\n",
"### ⬡ Finite Group $Z_6$ \n",
"\n",
"The first group is about rotations of a hexagon {glue:}`hex-6`. Our basic group member will be rotating the hexagon enough to shift all the vertices: {glue:}`hex-0` $\\rightarrow$ {glue:}`hex-1`. Notice I've colored the vertices and added a line so we can easily distinguish the orientation of the hexagon. Remember the hexagon, its colors, and if it is actually symmetric have nothing to do with the group. *The group elements are transformations we apply to the hexagon*.\n",
"\n",
"One group action for this example can use modular arithmetic. If we represent a point in our space as $\\left\\{0,\\ldots, 5\\right\\}$ then the rotation transformation is $x' = x + 1 \\;(\\bmod\\; 6)$. For example, if we start at $5$ and rotate, we get back to $0$. \n",
"\n",
"Our group must contain our rotation transformation $r$ and the identity: $\\{e, r\\}$. This set is not closed though: rotating twice $r\\cdot{}r$ {glue:}`hex-0` $\\rightarrow$ {glue:}`hex-1` $\\rightarrow$ {glue:}`hex-2` gives a new group element $r^2$. To close the group we need to have $\\{e, r, r^2, r^3, r^4, r^5\\}$. \n",
"\n",
"Is this closed? Consider rotating twice and then five times $r^5\\cdot{}r^2$ {glue:}`hex-0` $\\rightarrow$ {glue:}`hex-2` $\\rightarrow$ {glue:}`hex-1` You can see that this is the same as $r$, so $r^5\\cdot{}r^2 = r$. What about the inverses element? The inverse of $r$ is $r^5$. $r\\cdot{}r^5 = e$. You can indeed see that each element has an inverse ($e$ is its own inverse).\n",
"\n",
"In general, we can write out the group as a multiplication table that conveys all group elements and defines the output of all binary outputs:\n",
"\n",
"$$\n",
"\\begin{array}{l|cccccr}\n",
"& e & r & r^2 & r^3 & r^4 & r^5\\\\\n",
"\\hline\n",
"e & e & r & r^2 & r^3 & r^4 & r^5\\\\\n",
"r & r & r^2 & r^3 & r^4 & r^5 & e\\\\\n",
"r^2 & r^2 & r^3 & r^4 & r^5 & e & r\\\\\n",
"r^3 & r^3 & r^4 & r^5 & e & r & r^2\\\\\n",
"r^4 & r^4 & r^5 & e & r & r^2 & r^3\\\\\n",
"r^5 & r^5 & e & r & r^2 & r^3 & r^4\\\\\n",
"\\end{array}\n",
"$$\n",
"\n",
"You can also see that the group is abelian (commutative). For example, $r\\cdot{}r^3 = r^3\\cdot{}r$.\n",
"\n",
"This kind of table is called a [**Cayley table**](https://en.wikipedia.org/wiki/Cayley_table). Although it doesn't matter for this example, we'll see later that the order of look-up matters. Specifically if our group is non-abelian. *The row factor comes first and the column factor second*. So $r\\cdot{}r^5$ means we look at row $r$ and column $r^5$ to get the group element, which in this case is $e$.\n",
"\n",
"This group of rotations is an example of a **cyclic group** and is isomorphic (same transformations, but operates on different objects) to integers modulo 6. Meaning, you could view rotation $r^n$ as operating on integers $(x + n) \\textrm{mod}\\, 6$. Cyclic groups are written as $Z_n$, so this group is $Z_6$. \n",
"\n",
"\n",
"### ▩ p4m\n",
" \n",
"```{margin}\n",
"p4m strictly speaking only includes integer translations but many of the principles apply for continuous infinite groups (locally compact) and integer (countably) infinite groups\n",
"```\n",
"\n",
"The second group contains translation, 90° rotations, and horizontal/vertical mirroring. We're now operating on real numbers $x,y$, so we're in $\\mathbb{R}^2$. Let's ignore the translation for now and just consider mirroring ($s$) and rotation by 90° ($r$) about the origin. What powers of $r$ and $s$ do we need to have a closed group? Considering rotations alone first, like last time we should only need up to $r^3$. Here are the rotations visually: {glue:}`quad-0-0`, {glue:}`quad-1-0`, {glue:}`quad-2-0`, {glue:}`quad-3-0` What about mirroring on horizontal/vertical? Mirroring along the horizontal axis: {glue:}`quad-0-0` $\\rightarrow$ {glue:}`quad-2-1` is actually the same as rotating twice and then mirroring along the vertical. In fact, you only need to have mirroing along one axis. We'll choose the vertical axis by convention and denote that as $s$.\n",
"\n",
"We can build the group action piece by piece. The group action for rotation can be represented as a 2D rotation matrix acting a point $(x, y)$:\n",
"\n",
"$$\n",
"\\left[\\begin{array}{lr}\n",
"\\cos\\frac{k2\\pi}{4} & -\\sin\\frac{k2\\pi}{4}\\\\\n",
"\\sin\\frac{k2\\pi}{4} & \\cos\\frac{k2\\pi}{4}\\\\\n",
"\\end{array}\\right]\n",
"\\left[\\begin{array}{c}\n",
"x\\\\\n",
"y\\\\\n",
"\\end{array}\\right]\n",
",\\, k \\in \\left\\{0, 1, 2, 3\\right\\}\n",
"$$\n",
"\n",
"where $k$ can allow us to do two rotations at once ($k = 2$) or the identity ($k = 0$). The vertical axis mirror action can be represented by \n",
"\n",
"$$\n",
"\\left[\\begin{array}{lr}\n",
"-1 & 0\\\\\n",
"0 & 1\\\\\n",
"\\end{array}\\right]\n",
"\\left[\\begin{array}{c}\n",
"x\\\\\n",
"y\\\\\n",
"\\end{array}\\right]\n",
"$$\n",
"\n",
"These two group actions can be ordered to correctly represent rotation then mirroring or vice-versa.\n",
"\n",
"\n",
"Now is this closed with the group elements $\\{e, r, r^2, r^3, s\\}$? Visually we have {glue:}`quad-0-0`, {glue:}`quad-1-0`, {glue:}`quad-2-0`, {glue:}`quad-3-0`, {glue:}`quad-0-1`? No. Consider $r^2\\cdot{}s$ {glue:}`quad-0-0` $\\rightarrow$ {glue:}`quad-0-1` $\\rightarrow$ {glue:}`quad-2-1` which is not an element. To close the group, we need $\\{e, r, r^2, r^3, s, rs, r^2s, r^3s\\}$. The multiplication table (which defines the elements too) is:\n",
"\n",
"$$\n",
"\\begin{array}{l|cccccccr}\n",
"& e & r & r^2 & r^3 & s & rs & r^2s & r^3s\\\\\n",
"\\hline\n",
"e & e & r & r^2 & r^3 & s & rs & r^2s & r^3s\\\\\n",
"r & r & r^2 & r^3 & e & rs & r^2s & r^3s & s\\\\\n",
"r^2 & r^2 & r^3 & e & r & r^2s & r^3s & s & rs\\\\\n",
"r^3 & r^3 & e & r & r^2 & r^3s & s & rs & r^2s\\\\\n",
"s & s & r^3s & r^2s & rs & e & r^3 & r^2 & r\\\\\n",
"rs & rs & s & r^3s & r^2s & r & e & r^3 & r^2\\\\\n",
"r^2s & r^2s & rs & s & r^3s & r^2 & r & e & r^3\\\\\n",
"r^3s & r^3s & r^2s & rs & s & r^3 & r^2 & r & e\\\\\n",
"\\end{array}\n",
"$$\n",
"\n",
"This is a [**Cayley table**](https://en.wikipedia.org/wiki/Cayley_table). Remember *The row factor comes first and the column factor second*. So $rs\\cdot{}r^3$ means we look at row $rs$ and column $r^3$ to get the group element, which in this case is $r^2s$.\n",
"\n",
"As you can see from the Cayley table, the group is closed. Remember, elements like $rs$ are not a binary operation. They are group elements, hence the missing binary operation symbol. We also see that the group is not commutative. $r\\cdot{}s$ is {glue:}`quad-0-0` $\\rightarrow$ {glue:}`quad-0-1` $\\rightarrow$ {glue:}`quad-3-1`, so $r\\cdot{}s = rs$ as expected. However, $s\\cdot{}r$ is {glue:}`quad-0-0` $\\rightarrow$ {glue:}`quad-1-0` $\\rightarrow$ {glue:}`quad-1-1`, which is the group element $r^3s$.\n",
"\n",
"We can also read the inverses off the table. For example, the inverse of $r$ is the column which gives the identity element: $r^3$. This group is known as the dihedral group 4 $D_4$. It has order 8. \n",
"\n",
"Now consider the translation group elements. For simplicity, let's only consider integer translations. We can label them as $t_{w,h}$. So $t_{3,4}$ means translate by $x + 3$ and $y + 4$. Is this a proper group? Certainly it associative, there is an identity $t_{0,0}$ and an inverse for each element $t_{-x, -y}$. What about closure? Yes, since translating twice is equivalent to one larger translation: $t_{w,h}\\cdot{}t_{w', h'} = t_{w + w', h + h'}$. This expression also shows group action for translation. \n",
"\n",
"What about when we combine with our other elements from the $D_4$ group? Consider the product $r\\cdot{}t_{3,4}$. This means translating by $(3,4)$ and then rotating by 90° about the origin. If you consider this acting on a single point $(0,0)$, you could get $(0,0) \\rightarrow (3,4) \\rightarrow (-4,3)$. What element of our group would this represent? At first it seems like it could be $t_{-3,4}$. However, $t_{-3,4}$ would only work specifically for starting at $(0,0)$. If you started at $(1,1)$, you would get to $(-4,5)$ with $r\\cdot{}t_{3,4}$ and $(-2,5)$ with $t_{-3,4}$. To be correct for *any point*, we need a different group element. So the product $r\\cdot{}t_{3,4}$ actually cannot be a product but instead must be a group element. In fact, our new combined group is just going to be $ab$ where $a$ is an element from $D_4$ and $b$ is a translation. Thus $r\\cdot{}t_{3,4} = rt_{3,4}$. \n",
"\n",
"Combing these two groups, the translation and $D_4$, is an example of a **semidirect product**. A semidirect product just means that we create a new group by combining all possible group elements. There is some machinery for this, like the identity element in our new group is something like $et_{0,0}$, and it has some other structure. It is called semidirect, instead of direct, because we can actually mix our group elements. The elements both act on points in the same space ($x,y$ plane), so this makes sense. Another condition is that we can only have a semidirect product when one subgroup is normal and the translation subgroup is the normal subgroup. It is coincidentally abelian, but these two properties are not always identical. This semidirect product group is called p4m.\n",
"\n",
"Below, is an optional section that formalizes the idea of combining these two groups into one larger group.\n",
"\n",
"\n",
"```{admonition} Normal Subgroup\n",
"A normal subgroup is a group of elements $n$ from the group $G$ called $N$. Each $n \\in N$ should have the property that $g\\cdot{}n\\cdot{}g^{-1}$ gives an element in $N$ for any $g$. \n",
"```\n",
"\n",
"This does not mean $g\\cdot{}n\\cdot{}g^{-1} = n$, but instead that $g\\cdot{}n\\cdot{}g^{-1} = n'$ where $n'$ is some other element in $N$. For example, in p4m the translations form a normal subgroup. Rotating, translating, then doing the inverse of the rotation is equivalent to some translation. Notice that $D_4$ is not a normal subgroup of p4m. If you do an inverse translation, rotate, then do a translation you may not have something equivalent to a rotation. It may be strange that we're talking about the group p4m when we haven't yet described how it's defined (identity, inverse, binary op). We'll do that with the semidirect product and then we could go back and verify that the translations are a normal subgroup more rigorously. I do not know the exact connection, but it seems that normal subgroups are typically abelian. \n",
"\n",
"```{admonition} Semidirect Product\n",
"Given a normal subgroup of $G$ called $N$ and a subgroup $H$, we can define $G$ using the semidirect product. Each element in $G$ is a tuple of two elements in $N, H$ written as $(n, h)$. The identity is $(e_n, e_h)$ and the binary operation is:\n",
"\n",
"\\begin{equation}\n",
"(n_1, h_1) \\cdot (n_2, h_2) = (n_1\\cdot\\phi(h_1)(n_2), h_1\\cdot{}h_2)\n",
"\\end{equation}\n",
"\n",
"where $\\phi(h)(n)$ is the conjugation of $n$ $\\phi(h)(n) = h\\cdot{}n\\cdot{}h^{-1}$. When a transform $(n,h)$ is applied, we follow the normal convention that $h$ is applied first followed by $n$. \n",
"```\n",
"\n",
"We are technically doing an outer semidirect product: combining them under the assumption that both $D_4$ and $T$ are part of a larger group which contains both. This is a bit of a semantic detail, but they are actually both part of $p4m$ and a larger group called the affine group which includes, rotation, shear, translation, mirror, and scale operations on points. You could also argue they are part of groups which can be represented by 3x3 invertible matrices. Thus, you can combine these and get something that is still smaller than their larger containing group ($p4m$ is smaller than all affine transformations).\n",
"\n",
"One consequence of the semidirect product is that if you have a group element $(n,h)$ but want to instead apply $n$ first (instead of $h$), you can use the binary operation:\n",
"\n",
"\n",
"\\begin{equation}\n",
"(e_n, h) \\cdot (n, e_h) = (e_n\\cdot\\phi(h)(n), h\\cdot{}e_h) = (\\phi(h)(n), h)\n",
"\\end{equation}\n",
"\n",
"\n",
"so $\\phi(h)(n)$ somehow captures the effect of switching the order applying elements from $H$ and $N$. In our case, this means swapping the order of rotation/mirroring and translation. \n",
"\n",
"To show what effect the semidirect product has in p4m, we can clean-up our example above about $r\\cdot{}t_{3,4}$. We should write the first element of this binary product $r$ as a tuple of group elements: one from the $D_4$ and one from the translations. Since there is no translation for $r$, we use the identity. Thus we write $r$ as $(t_{0,0}, r)$ in our semidirect product group p4m. Note that the normal subgroup comes first (applied last) by convention. Similarly, $t_{3,4}$ is written as $(t_{3,4}, e)$. Our equation becomes:\n",
"\n",
"$$\n",
"(t_{0,0}, r)\\cdot(t_{3,4}, e) = (t_{0,0}\\cdot\\phi(r)(t_{3,4}), r\\cdot{}e) = (t_{0,0}\\cdot\\phi(r)(t_{3,4}), r)\n",
"$$\n",
"\n",
"where $\\phi$ is the automorphism that distinguishes a semidirect product from a direct product. The direct product has $\\phi(h)(n) = n$ so that the binary operation for the direct product group is just the element-wise binary products. $\\phi(h)(n) = hnh^{-1}$ for semidirect products. In our equation, this means $\\phi(r)(t_{3,4}) = r\\cdot{}t_{3,4}\\cdot{}r^3$. Substituting this and using the fact that both groups have the same binary operation (matrix multiplication, as we'll see shortly):\n",
"\n",
"$$\n",
"(t_{0,0}\\phi(r)(t_{3,4}), r) = (r\\cdot{}t_{3,4}\\cdot{}r^3, r) = r\\cdot{}t_{3,4}\\cdot{}r^3\\cdot r = r\\cdot{}t_{3,4}\n",
"$$\n",
"\n",
"Thus we've proved that translating by $3,4$ followed by rotating can be expressed as $r\\cdot{}t_{3,4}$, which seems like a lot of work for an obvious result. I won't cover the semidirect product of the group action, but we'll see that we do not necessarily need to build a group action encapsulating both translation and rotation/mirroring. \n",
"\n",
"\n",
"\n",
"### ⚽ SO(3) Group\n",
"\n",
"SO(3) is the group for analyzing 3D point clouds like trajectories or crystal structures (with no other symmetries). SO(3) is the group of all rotations about the origin in 3D. The group is non-abelian because rotations in 3D are not commutative. The group order is infinite, because you can rotate in this group by any angle (or sets of angles). If you are interested in allowing translations, you can use SE(3) which is the semidirect product of SO(3) and the translation group (like p4m), which is a normal subgroup. \n",
"\n",
"The SO(3) name is a bit strange. SO stands for \"special orthogonal\" which are two properties of square matrices. In this case, the matrices are $3\\times3$. Orthogonal means the columns sum to one and special means the determinant is 1. Interestingly, all rotations in 3D around the origin are also the SO(3) matrices. \n",
"\n",
"One detail is that since we're rotating (no scale or translation) the distance to origin will not change. We cannot move the radius. The group action is the product of 3 3D rotation matrices (using [Euler angles](https://en.wikipedia.org/wiki/Euler_angles)) $R_z(\\alpha)R_y(\\beta)R_z(\\gamma)$ where $\\alpha,\\gamma \\in [0, 2\\pi], \\beta \\in [0, \\pi]$ and\n",
"\n",
"$$\n",
"R_z(\\theta) = \\left[\\begin{array}{lcr}\n",
"\\cos\\theta & -\\sin\\theta & 0\\\\\n",
"\\sin\\theta & \\cos\\theta & 0\\\\\n",
"0 & 0 & 1\\\\\n",
"\\end{array}\\right]\n",
"$$\n",
"\n",
"$$\n",
"R_y(\\theta) = \\left[\\begin{array}{lcr}\n",
"\\cos\\theta & 0 & \\sin\\theta \\\\\n",
"0 & 1 & 0\\\\\n",
"-\\sin\\theta & 0 & \\cos\\theta\\\\\n",
"\\end{array}\\right]\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Groups on Spaces\n",
"\n",
"We've defined transforms and their relationships to one another via group theory. Now we need to actually connect the transforms to a space. It is helpful to think about the space as Euclidean with a concept of distance and coordinates, but we'll see that this is not required. Our space could be vertices on a graph or integers or classes. There are *some* requirements though. The first is that our space must be **homogeneous** (for the purposes of this chapter). Homogeneous means that from any point in our space $x$ we can reach any other point with a transform $g$ from our group $G$. The second requirement is that if our group is infinite, the space must **locally compact**. This is a concept from topology and we won't really ever be troubled by it. Most spaces we'll see in chemistry or materials science (Euclidean spaces) are locally compact.\n",
"\n",
"```{margin} Lie group\n",
" If the group transforms are further smooth and have smooth inverses, the group (and associated space) are called a **lie group**. \n",
"\n",
"```\n",
"\n",
"```{tabbed} ⬡ Finite Group $Z_6$ \n",
"\n",
"The space is homogeneous because our group includes \"compound\" rotations like $r^4$. This is a finite group, so we do not require the space to be compact. \n",
"\n",
"```\n",
"\n",
"```{tabbed} ▩ Locally Compact p4m\n",
"\n",
"The space is homogeneous since we can use a translation to get to any other point. The space is locally compact because we are in 2D Euclidean geometry.\n",
"\n",
"```\n",
"\n",
"```{tabbed} ⚽ SO(3) Group\n",
"\n",
"The space is homogeneous because we restrict ourselves to being on the sphere. The space is locally compact because we are in 3D Euclidean geometry.\n",
"\n",
"```\n",
"\n",
"The requirement of space being homogeneous is fairly strict. It means we cannot work with $\\mathbb{R}^2$ with a finite group like mirror and fixed rotations (i.e., p4m without translations). For example, going from $x = (0,0)$ to $x = (1,1)$ cannot be done with rotations/mirror group elements alone. As you can see, working in a Euclidean space thus requires a locally compact group. Similarly, a finite group implies a finite space because of the homogeneous requirement. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This may seem like a ton of work. We could have just started with $xyz$ coordinates and rotation matrices. Please continue to wait though, we're about to see something incredible.\n",
"\n",
"## Equivariance Definition\n",
"\n",
"You should be thinking now about how we can define equivariance using our new groups. That's where we're headed. We need to do a bit of work now to \"lift\" neural networks and our features into the framework we're building. First, in {doc}`data` we defined our features as being composed of tuples $(\\vec{r}_i, \\vec{x}_i)$ where $\\vec{r}_i$ is a spatial point and $\\vec{x}_i$ are the features at that point. Let's now view these input data as functions, defined as $f(\\vec{r}) = \\vec{x}$ and assume if a point $\\vec{r}'$ isn't in our training data then $f(\\vec{r}') = \\vec{0}$. More formally, our training data is a function $f:\\mathcal{X} \\rightarrow \\mathbb{R}^n$ that maps from our homogeneous space $\\mathcal{x}$ to real vector (or complex vectors) of dimension $n$. \n",
"\n",
"We have promoted our data into a function and now a neural network can no longer be just function since its input is a function. Our neural network will be also promoted to a **linear map**, which has an input of a function and an output of a function. Formally, our neural network is now $\\psi: f(\\mathcal{X}) \\rightarrow f'(\\mathcal{X})$. Notice the input and output spaces of the functions may not be the same (we may input a molecule 3D and output a 1D scalar for energy). Linear maps are also called **operators**, depending on which branch of mathematics you're in.\n",
"\n",
"The last piece of equivariance is to promote our group elements, which transform points, to work on functions.\n",
"\n",
"```{admonition} G-Function Transform Definition\n",
"An element $g$ of group $G$ on the homogeneous space $\\mathcal{X}$ can act on a function $f:\\mathcal{X}\\rightarrow \\mathbb{R}^n$ via the group transform linear map $\\mathbb{T}_g: f(\\mathcal{X}) \\rightarrow f'(\\mathcal{X})$ defined as \n",
"\n",
"\\begin{equation}\n",
"f'(gx) = f(x) \\Rightarrow f'(x) = f(g^{-1}x)\n",
"\\end{equation}\n",
"```\n",
"\n",
"This definition takes a moment to think about. Consider a translation of an image. You want to move an image to the left by 10 pixels, so $g = t_{10,0}$. The image is defined by the function $f(x,y) = (r, g, b)$, where $r,g,b$ is the color. We want $T_g f(x, y)$. Without knowing about groups, you can intuit that translating can be done by creating a new function $f'(x', y') = f(x - 10, y)$. Notice that the inverse of $g^{-1} = t_{-10, 0}$ acts on the points, not $g$. Recall that a group requires there to be an inverse for any group element.\n",
"\n",
"Now we have all the pieces to define an equivariant neural network:\n",
"\n",
"```{admonition} Equivariant Neural Network Definition\n",
"Given a group $G$ that has actions on two homogeneous space $\\mathcal{X_1}$ and $\\mathcal{X_2}$, a G-equivariant neural network is a linear map $\\psi: f(\\mathcal{X_1}) \\rightarrow f'(\\mathcal{X_2})$ that has the property{cite}`kondor2018generalization`:\n",
"\n",
"\\begin{equation}\n",
"\\psi\\left[\\mathbb{T}_g f(x)\\right] = \\mathbb{T'}_{g}\\psi\\left[f(x)\\right]\\;\\forall\\, f(x)\n",
"\\end{equation}\n",
"\n",
"where $\\mathbb{T}_g$ and $\\mathbb{T}'_g$ are G-function transforms on the two spaces. If $\\mathbb{T}'_g = \\textrm{id}$, meaning the transform is the identity in the output space regardless of $g$, then $\\psi$ is a G-invariant neural network.\n",
"```\n",
"\n",
"The definition means that we get the same output if we transform the input function to the neural network or transform the output (in the equivariant case). In a specific example, if we rotate the input by 90 degrees, that's the same result as rotating the output by 90 degrees. Take a moment to ensure that matches your idea of what equivariance means. After all this math, we've generalized equivariance to arbitrary spaces and groups.\n",
"\n",
"What are the input and output spaces? For equivariant neural networks, it's easiest to think of them as the same space. For an invariant network, the output space is typically a scalar. Another example of an invariant network is one that aligns a molecular structure to a reference: it should produce the same aligned structure regardless of how the input is transformed. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## G-Equivariant Convolution Layers\n",
"\n",
"Recall that a neural network layer is made up of a linear part (e.g., $\\vec{h} = \\mathbf{W}\\vec{x} + \\vec{b}$) and a non-linearity. Kondor and Trivedi {cite}`kondor2018generalization` showed that the *only* way to make a G-equivariant neural network is to make the linear part a convolution:\n",
"\n",
"```{admonition} G-Equivariant Convolution Theorem\n",
"A neural network layer (linear map) $\\psi$ is G-equivariant if and only if its form is a convolution operator $*$ \n",
"\n",
"$$\n",
"\\psi(f) = (f * \\omega)(u) = \\sum_{g \\in G} f\\uparrow^G\\left(ug^{-1}\\right)\\omega\\uparrow^G\\left(g\\right)\n",
"$$ (disc-conv)\n",
"\n",
"where $f: H \\rightarrow \\mathbb{R}^n$ and $\\omega: H' \\rightarrow \\mathbb{R}^n$ are functions on the quotient spaces $H$ and $H'$. If the group $G$ has infinitely many elements but is locally compact, then the convolution operator is\n",
"\n",
"\\begin{equation}\n",
"\\label{cont-conv}\n",
"\\psi(f) = (f * \\omega)(u) = \\int_G f\\uparrow^G\\left(ug^{-1}\\right)\\omega\\uparrow^G\\left(g\\right)\\,d\\mu(g)\n",
"\\end{equation}\n",
"\n",
"where $\\mu$ is the group Haar measure. A [Haar measure](https://en.wikipedia.org/wiki/Haar_measure) is a generalization of the familiar Jacobian factor you see when doing integrals in polar or spherical coordinates (e.g., the $r$ in $r\\,dr\\,d\\theta$).\n",
"\n",
"```\n",
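"\n",
"For a finite group, Equation {eq}`disc-conv` is a plain sum over group elements. As a minimal sketch (the choice of the non-abelian group $S_3$, the dictionary encoding of functions, and all helper names are illustrative assumptions), the following implements that sum and checks $\\psi[\\mathbb{T}_a f] = \\mathbb{T}_a \\psi[f]$ for every $a \\in G$, using $\\mathbb{T}_a f(u) = f(a^{-1}u)$:\n",
"\n",
"```python\n",
"import itertools\n",
"import random\n",
"\n",
"# S3: all permutations of (0, 1, 2), a small non-abelian group\n",
"G = list(itertools.permutations(range(3)))\n",
"\n",
"\n",
"def prod(a, b):\n",
"    # group product a . b: apply b first, then a\n",
"    return tuple(a[b[i]] for i in range(3))\n",
"\n",
"\n",
"def inv(a):\n",
"    # inverse permutation: if a sends i to a[i], inv(a) sends a[i] back to i\n",
"    out = [0, 0, 0]\n",
"    for i, ai in enumerate(a):\n",
"        out[ai] = i\n",
"    return tuple(out)\n",
"\n",
"\n",
"def conv(f, w):\n",
"    # (f * w)(u) = sum_g f(u g^-1) w(g); f and w are dicts over G\n",
"    return {u: sum(f[prod(u, inv(g))] * w[g] for g in G) for u in G}\n",
"\n",
"\n",
"def transform(a, f):\n",
"    # (T_a f)(u) = f(a^-1 u)\n",
"    return {u: f[prod(inv(a), u)] for u in G}\n",
"\n",
"\n",
"random.seed(0)\n",
"f = {g: random.random() for g in G}\n",
"w = {g: random.random() for g in G}\n",
"for a in G:\n",
"    lhs = conv(transform(a, f), w)\n",
"    rhs = transform(a, conv(f, w))\n",
"    assert all(abs(lhs[u] - rhs[u]) < 1e-12 for u in G)\n",
"print('equivariant for all', len(G), 'group elements')\n",
"```\n",
"\n",
"The order inside $f(ug^{-1})$ matters here: because $S_3$ is non-abelian, replacing it with $f(g^{-1}u)$ would generally break this left equivariance.\n",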
"\n",
"```{margin}\n",
"This is a strong theorem. It says there is only one way to achieve equivariance in a neural network. This may seem counter-intuitive since there are many competing approaches to convolutions. These other approaches are actually equivalent to a convolution; it can just be hard to notice. \n",
"```\n",
"\n",
"As you can see from the theorem, we must introduce more new concepts. The first important detail is that all our functions are over our group elements (technically the quotient space $G / H_0$), not our space. This should seem strange. We will easily fix this because there is a bijective way to assign one coset of group elements to each point in the space. The second detail is the $f \\uparrow^G$. The order of the group $G$ is greater than or equal to the number of points in our space, so if the function is defined on our space, we must \"lift\" it up to the group $G$, which has more elements. The last detail is the point about **quotient spaces**. Quotient spaces are how we cut up our group $G$ into disjoint sets (cosets) so that there is exactly one set per point in our space. Below I detail these new concepts just enough so that we can implement and understand these convolutions.\n",
"\n",
"```{warning}\n",
"To actually learn, you need to put in a nonlinearity after the convolution. A simple (and often used) case is to just use a standard non-linear function like ReLU pointwise, applied to the output of the convolution at each group element individually. We'll look at more complex examples below for the continuous case.\n",
"```\n",
"\n",
"## Converting between Space and Group\n",
"\n",
"Let's see how we can convert between functions on the space $\\mathcal{X}$ and functions on the group $G$. Because the space is homogeneous (every point can be reached from every other point by a group element), $|G| \\geq |\\mathcal{X}|$ ($|G|$ is the number of elements), so it is rare that we can uniquely replace each point in space with an element of $G$. Instead, we will construct a partitioning of $G$ into $|\\mathcal{X}|$ sets called a quotient space $G / H$ such that $|G / H| = |\\mathcal{X}|$. It turns out there is a well-studied approach to arranging elements in a group called **cosets**. Constructing cosets is a two-step process. First we define a subgroup $H$. A **subgroup** is a subset of $G$ that is itself a group: it contains the identity and inverses, and it is closed, meaning we cannot accidentally leave $H$ because $h_1\\cdot{} h_2 \\in H$ for any $h_1, h_2 \\in H$. For example, translation transformations are a subgroup because you cannot accidentally create a rotation when combining two translations. \n",
"\n",
"```{margin}\n",
"This process of constructing cosets and then using that to lift our function is closely related to the process of finding an induced representation on $G$ via a representation on $H$.\n",
"```\n",
"\n",
"After constructing a subgroup $H$, we can apply an element $g$ to every element in $H$, written as \n",
"\n",
"\\begin{equation}\n",
"gH = \\left\\{g \\cdot h \\forall h \\in H\\right\\}\n",
"\\end{equation}\n",
"\n",
"If this sounds strange, wait for an example. $gH$ is called a **left coset**. We mention the direction because $G$'s binary operation may not be commutative (non-abelian). What happens if $g$ is in $H$? No problem; $H$ is a group, so applying an element to every element in $H$ just gives back $H$ (i.e. $hH = H$). Other cosets are not groups: when $g \\notin H$, $gH$ does not even contain the identity. What's the point of making all these cosets? Remember our goal is to partition $G$ into a bunch of smaller sets so that we have one for each point in $\\mathcal{X}$. Every $g$ lies in its own coset $gH$, so the cosets cover $G$, and two cosets $g_1H$ and $g_2H$ are either identical or share no elements, so they never partially overlap. But do we get the right number of them? \n",
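"\n",
"A minimal sketch of this, using a small example chosen for illustration: the subgroup $H = \\{0, 3\\}$ of $Z_6$ (the identity and the 180° rotation of the hexagon). Six choices of $g$ produce only three distinct cosets, and those cosets cover $G$ without overlapping:\n",
"\n",
"```python\n",
"# Z6 = {0, ..., 5} under addition mod 6, with the subgroup H = {0, 3}\n",
"G = list(range(6))\n",
"H = [0, 3]\n",
"\n",
"\n",
"def coset(g):\n",
"    # left coset gH = {g + h mod 6 for h in H}\n",
"    return frozenset((g + h) % 6 for h in H)\n",
"\n",
"\n",
"cosets = {coset(g) for g in G}\n",
"print(sorted(sorted(c) for c in cosets))  # [[0, 3], [1, 4], [2, 5]]\n",
"\n",
"# the cosets are disjoint and cover G: a partition into |G| / |H| = 3 pieces\n",
"assert sorted(x for c in cosets for x in c) == G\n",
"```\n",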
"\n",
"```{margin}\n",
"If your group involves rotations, make life easy on yourself and always choose $x_0$ as the origin (or center of the rotations). \n",
"```\n",
"\n",
"It turns out that if our space is homogeneous, we can construct our cosets in a special way so that we have exactly one coset for each point in the space $\\mathcal{X}$. To get our subgroup, we pick an arbitrary point $x_0$ in the space. Often this will be the origin. Then we choose our subgroup $H$ to be all group elements that leave $x_0$ unchanged. This is called a stabilizer subgroup $H_0$ and is defined as \n",
"\n",
"\\begin{equation}\n",
"H_0 = \\left\\{ g \\in G \\,\\textrm{such that}\\, g x_0 = x_0\\right\\}\n",
"\\end{equation}\n",
"\n",
"We will not prove that this is itself a group, but it is one, and it defines our subgroup. Here's the remarkable thing: we get exactly as many cosets of this stabilizer as there are points in $\\mathcal{X}$. However, multiple $g$s will give the same coset (as expected whenever $|G| > |\\mathcal{X}|$).\n",
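"\n",
"As a sketch of the stabilizer construction, here is a small example chosen for illustration: $D_4$, as $2 \\times 2$ integer matrices, acting on the four corners of a square with $x_0 = (1, 1)$. Only the identity and the mirror across the diagonal $y = x$ fix $x_0$, so $|H_0| = 2$ and the 8 group elements fall into $8 / 2 = 4$ cosets, one per corner:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# D4 generated by a 90 degree rotation r and a mirror s (y -> -y)\n",
"r = np.array([[0, -1], [1, 0]])\n",
"s = np.array([[1, 0], [0, -1]])\n",
"G = [np.linalg.matrix_power(r, k) @ np.linalg.matrix_power(s, j)\n",
"     for k in range(4) for j in range(2)]\n",
"\n",
"# the space is the four corners of a square; pick x0 = (1, 1)\n",
"x0 = np.array([1, 1])\n",
"\n",
"# stabilizer subgroup H0 = {g in G such that g x0 = x0}\n",
"H0 = [g for g in G if np.array_equal(g @ x0, x0)]\n",
"\n",
"\n",
"def key(g):\n",
"    # hashable label for a matrix so cosets can be stored as sets\n",
"    return tuple(g.flatten().tolist())\n",
"\n",
"\n",
"cosets = {frozenset(key(g @ h) for h in H0) for g in G}\n",
"print(len(G), len(H0), len(cosets))  # 8 2 4\n",
"\n",
"# every element of a coset sends x0 to the same corner\n",
"for g in G:\n",
"    assert len({tuple((g @ h @ x0).tolist()) for h in H0}) == 1\n",
"```\n",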
"\n",
"```{margin} \n",
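"This set of all cosets is written as $G / H_0$. If $H_0$ happens to be a normal subgroup, $G / H_0$ is itself a group (the quotient group): the identity coset is $H_0$ and the binary operation is $(g_1H_0)(g_2H_0) = g_1g_2H_0$. A stabilizer is often not normal (e.g., $D_4$ inside p4m), in which case $G / H_0$ is only a set. Luckily, we do not need the group structure here, but it is fascinating.\n",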
"```\n",
"\n",
"Now come the details: how do we match up points in $\\mathcal{X}$ to the cosets? We know that the space is homogeneous, so each point $x$ can be reached from our arbitrary origin by a group element, $gx_0 = x$. That's one way to connect points to group elements, but which coset will $g$ be in? There also may be multiple $g$s that satisfy the equation. It turns out that all the group elements that satisfy the equation will be in the same coset. Every element of $gH_0$ satisfies it, because $g\\cdot h x_0 = gx_0$ for all elements $h$ of the stabilizer group (they do not move $x_0$). Conversely, if $g_1x_0 = g_2x_0$, then $g_1^{-1}g_2x_0 = x_0$, so $g_1^{-1}g_2 \\in H_0$ and $g_2 \\in g_1H_0$. Quite elegant. \n",
"\n",
"How do we find which coset we need? Since the identity $e$ is in $H_0$ (by definition), the coset $gH_0$ will contain $g$ itself. Thus, we can convert a function $f(x)$ from the space to be a function on the quotient space $f(g)$ via what we call **lifting**:\n",
"\n",
"\\begin{equation}\n",
"f\\uparrow^G(g) = f(gx_0)\n",
"\\end{equation}\n",
"\n",
"All that discussion and thinking for such a simple equation. One point to note is that you can plug any element $g$ from the group into $f\\uparrow^G(g)$, but the lifted function is constant on each coset: $f\\uparrow^G(gh) = f\\uparrow^G(g)$ for every $h \\in H_0$. It is only a one-to-one relabeling of $f$ over $G / H_0$ (the cosets).\n",
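"\n",
"A minimal sketch of lifting, using an illustrative setup ($D_4$ acting on the corners of a square, $x_0 = (1, 1)$); the corner values of `f_space` are arbitrary. The lifted function is defined on all 8 group elements and agrees on both members of each coset:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"r = np.array([[0, -1], [1, 0]])  # 90 degree rotation\n",
"s = np.array([[1, 0], [0, -1]])  # mirror y -> -y\n",
"G = [np.linalg.matrix_power(r, k) @ np.linalg.matrix_power(s, j)\n",
"     for k in range(4) for j in range(2)]\n",
"x0 = np.array([1, 1])\n",
"H0 = [g for g in G if np.array_equal(g @ x0, x0)]\n",
"\n",
"# a function on the space: one value per corner\n",
"f_space = {(1, 1): 0.1, (-1, 1): 0.2, (-1, -1): 0.3, (1, -1): 0.4}\n",
"\n",
"\n",
"def lift(f):\n",
"    # lifted function on G: g -> f(g x0)\n",
"    return lambda g: f[tuple((g @ x0).tolist())]\n",
"\n",
"\n",
"f_up = lift(f_space)\n",
"# constant on cosets: f_up(g h) == f_up(g) for every h in H0\n",
"assert all(f_up(g @ h) == f_up(g) for g in G for h in H0)\n",
"# each corner value appears twice, once per element of its coset\n",
"print(sorted(f_up(g) for g in G))  # [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4]\n",
"```\n",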
"\n",
"```{margin}\n",
"A coset can have multiple labels in this system. $g_1H_0$ and $g_2H_0$ could be the same coset. There are no consequences of this, but just be aware. \n",
"```\n",
"\n",
"Going the opposite direction, from a function on the group to the space, is called **projecting** because it will have a smaller domain. We can use the same process as above. We create the quotient space and then just take the average over a single coset to get a single value for the point $x$:\n",
"\n",
"\\begin{equation}\n",
"f\\downarrow_\\mathcal{X}(x) = \\frac{1}{|H_0|}\\sum_{u \\in gH_0} f(u), \\: gx_0 = x\n",
"\\end{equation}\n",
"\n",
"where we've used the fact that $|gH_0| = |H_0|$. Note that the coset generating element $g$ is found by solving $gx_0 = x$; unless $x = x_0$, $g$ is not a stabilizing element (otherwise $gx_0 = x_0$ by definition). Let's see some examples now to make all of this easier to understand. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{tabbed} ⬡ Finite Group $Z_6$ \n",
"\n",
"Our function is the color of the vertices in our picture {glue:}`hex-0` $f(x) = (r, g, b)$ where $r,g,b$ are fractions of the colors red, green, and blue. If we define the vertices to start at the line pointing up, we can label them $0,\\ldots,5$. So for example $f(0) =(0.11, 0.74, 0.61)$, which is the color of the top vertex. \n",
"\n",
"We can define the origin as $x_0 = 0$. $|G| = |\\mathcal{X}|$ for this finite group and thus our stabilizer subgroup only contains the identity $H_0 = \\{e\\}$. Our cosets and their associated points will be $(eH_0, x = 0), (rH_0, x = 1), (r^2H_0, x = 2), (r^3H_0, x = 3), (r^4H_0, x = 4), (r^5H_0, x = 5)$. The lifted $f\\uparrow^G(g)$ can be easily defined using these cosets. \n",
"\n",
"```\n",
"\n",
"```{tabbed} ▩ Locally Compact p4m\n",
"p4m is intended for images, so our example will be a function $f: \\mathbb{R}^2 \\rightarrow \\mathbb{R}^3$ that represents a color image. This group contains rotations about the origin, so if we choose the origin as $x_0$, its stabilizer is exactly the rotations and mirrors ($D_4$), which cleanly separates our group. Namely:\n",
"\n",
"$$\n",
"H_0 = \\left\\{e_ne_r, e_nr, e_nr^2, e_nr^3 , e_ns , e_nrs , e_nr^2s , e_nr^3s\\right\\}\n",
"$$\n",
"\n",
"where our elements have been written out as the semidirect product of translations and $D_4$ as discussed previously. Let's compute a coset to get a sense of this process. Consider the group element $t_{1,0}e_r$ creating the coset $t_{1,0}e_rH_0$. The first element of the coset is $t_{1,0}e_r \\cdot e_ne_r = t_{1,0}e_r$. The second element is $t_{1,0}e_r \\cdot t_{0,0}r = t_{1,0}r$. The rest of the elements of this coset are:\n",
"\n",
"$$\n",
"t_{1,0}e_rH_0 = \\left\\{t_{1,0}e_r , t_{1,0}r , t_{1,0}r^2 ,t_{1,0}r^3 , t_{1,0}s , t_{1,0}rs , t_{1,0}r^2s , t_{1,0}r^3s\\right\\}\n",
"$$\n",
"\n",
"Note these were simple to compute because $\\phi(g)(e_n) = ge_ng^{-1} = e_n$. Now what point is this associated with? Consider the first non-identity coset element $t_{1,0}r$ acting on the origin: $(0,0)\\rightarrow(0,0)\\rightarrow(1,0)$. You'll see similarly that all elements in the coset will follow the same pattern: the first element from $H_0$ doesn't move the origin (by definition) and the second element is the same in the coset (translation by $x + 1$). Thus, the first coset $t_{1,0}e_rH_0$ is associated with the point $(1,0)$.\n",
"\n",
"Now consider a coset that involves a $D_4$ element: $t_{1,0}rsH_0$. You can compute its elements as:\n",
"\n",
"$$\n",
"t_{1,0}rsH_0 = \\left\\{t_{1,0}rs, t_{1,0}s, t_{1,0}r^3s, t_{1,0}r^2s, t_{1,0}r, t_{1,0}e_r , t_{1,0}r^3 , t_{1,0}r^2\\right\\}\n",
"$$\n",
"\n",
"This contains all the same elements as the coset $t_{1,0}e_rH_0$! This is because we have more group elements than points in $\\mathcal{X}$; multiple $g$'s result in the same coset. This doesn't change our intuition though: the translation transform still defines the connection between our coset and the space. Our lifting function will be \n",
"\n",
"$$\n",
"f\\uparrow^G(g) = f\\uparrow^G\\left((t_{x,y}, h)\\right) = f(x,y)\n",
"$$\n",
"\n",
"\n",
"```\n",
"\n",
"```{tabbed} ⚽ SO(3) Lie Group\n",
"\n",
"For this example, our function will be points on the sphere $f(x) = \\sum_i \\delta(x - x_i)$. We can represent the group element rotations (among other choices) as being the product of three rotations about the $y$ and $z$ axes: $R_z(\\alpha)R_y(\\beta)R_z(\\gamma)$. If that seems surprising, remember that rotations are not commutative. Santa lives in the north pole, so let's choose the north pole $(0, 0, 1)$ as our $x_0$. Remember that you cannot choose $(0,0,0)$ because it is not in the space. Our subgroup is rotations that only involve $\\gamma$; for example, $R_z(0)R_y(0)R_z(90)$ is in our subgroup $H_0$. Let's generate a coset, say for the group element $g = R_z(120)R_y(0)R_z(60)$. The coset $gH_0$ will be rotations of $R_z(120)R_y(0)R_z(60)R_z(0)R_y(0)R_z(\\gamma)$, which can be simplified to $R_z(120)R_y(0)R_z(60 + \\gamma)$. Thus the coset is $gH_0 = \\left\\{R_z(120)R_y(0)R_z(60 + \\gamma)\\, \\forall \\gamma \\in [0, 360)\\right\\}$ \n",
" \n",
"Now what point is associated with this coset? It will be this rotation applied to $x_0$: $R_z(120)R_y(0)R_z(60 + \\gamma)x_0$. The rightmost rotation acts first and has no effect on the north pole, by definition, so it becomes $R_z(120)R_y(0)x_0$. The general form is that the coset for a point $x$ is the rotation such that $R_z(\\alpha)R_y(\\beta)x_0 = x$. This quotient space SO(3)/SO(2) is the sphere $S^2$, which is why it is defined by two angles; the stabilizer $H_0$ is the SO(2) subgroup of rotations about the $z$ axis. The lifting function is defined as:\n",
"\n",
"\n",
"$$\n",
"f\\uparrow^G(g) = f\\uparrow^G\\left(R_z(\\alpha)R_y(\\beta)R_z(\\gamma)\\right) = f\\left(R_z(\\alpha)R_y(\\beta)x_0\\right)\n",
"$$\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## G-Equivariant Convolutions on Finite Groups\n",
"\n",
"We now have all the tools to build an equivariant network for a finite group. We'll continue with our example group $Z_6$ on vertices of a hexagon. The cell below does our imports.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import dmol\n",
"from dmol import color_cycle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by defining our input function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make our colors (nothing to do with the model)\n",
"\n",
"vertex_colors = []\n",
"for c in color_cycle:\n",
" hex_color = int(c[1:], 16)\n",
" r = hex_color // 256**2\n",
" hex_color = hex_color - r * 256**2\n",
" g = hex_color // 256\n",
" hex_color = hex_color - g * 256\n",
" b = hex_color\n",
" vertex_colors.append((r / 256, g / 256, b / 256))\n",
"vertex_colors = np.array(vertex_colors)\n",
"\n",
"\n",
"def z6_fxn(x):\n",
" return vertex_colors[x]\n",
"\n",
"\n",
"z6_fxn(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we assume our group is indexed already by our vertex coordinates $\\{0,\\ldots, 5\\}$ then our function is already defined on the group. Now we need our trainable kernel function. It will be defined like our other function.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make weights be 3x3 matrices at each group element\n",
"# 3x3 so that we have 3 color channels in and 3 out\n",
"weights = np.random.normal(size=(6, 3, 3))\n",
"\n",
"\n",
"def z6_omega(x):\n",
" return weights[x]\n",
"\n",
"\n",
"z6_omega(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can define our group convolution operator from Equation {eq}`disc-conv`. We need two small helper functions: one for the inverse of a group element and one for the group product. Remember too that `conv` returns a *function* of the group element $u$, not an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def z6_inv(g):\n",
" return (6 - g) % 6\n",
"\n",
"\n",
"def z6_prod(g1, g2):\n",
" return (g1 + g2) % 6\n",
"\n",
"\n",
"def conv(f, p):\n",
" def out(u):\n",
" g = np.arange(6)\n",
"        # einsum multiplies each color vector f(u g^-1) by its 3x3 weight matrix omega(g),\n",
"        # one matrix per group element g, before summing over g\n",
" c = np.sum(np.einsum(\"ij,ijk->ik\", f(z6_prod(u, z6_inv(g))), p(g)), axis=0)\n",
" return c\n",
"\n",
" return out\n",
"\n",
"\n",
"conv(z6_fxn, z6_omega)(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point, we can now verify that the CNN is equivariant by comparing transforming the input function and the output function. We'll need to define our function transforms as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def z6_fxn_trans(g, f):\n",
" return lambda h: f(z6_prod(z6_inv(g), h))\n",
"\n",
"\n",
"z6_fxn(0), z6_fxn_trans(2, z6_fxn)(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we'll compute $\\psi\\left[\\mathbb{T}_2 f(x)\\right]$ -- the network acting on the transformed input function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trans_element = 2\n",
"trans_input_fxn = z6_fxn_trans(trans_element, z6_fxn)\n",
"trans_input_out = conv(trans_input_fxn, z6_omega)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we compute $\\mathbb{T}_2\\psi\\left[f(x)\\right]$ -- the transform acting on the network output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_fxn = conv(z6_fxn, z6_omega)\n",
"trans_output_out = z6_fxn_trans(trans_element, output_fxn)\n",
"\n",
"print(\"g -> psi[f(g)], g -> psi[Tgf(g)], g-> Tg psi[f(g)]\")\n",
"for i in range(6):\n",
" print(\n",
" i,\n",
" np.round(conv(z6_fxn, z6_omega)(i), 2),\n",
" np.round(trans_input_out(i), 2),\n",
" np.round(trans_output_out(i), 2),\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the outputs indeed match and therefore our network is G-equivariant. One last detail is that it would be nice to visualize this, so we can add a nonlinearity to remap our output back to color space. Our colors should be between 0 and 1, so we can use a sigmoid to put the activations back to valid colors. I'll hide the input since it contains irrelevant code, but here is the visualization of the previous numbers showing the equivariance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"c1 = conv(z6_fxn, z6_omega)\n",
"c2 = trans_input_out\n",
"c3 = trans_output_out\n",
"titles = [\n",
" r\"$\\psi\\left[f(g)\\right]$\",\n",
" r\"$\\psi\\left[\\mathbb{T}_2f(g)\\right]$\",\n",
" r\"$\\mathbb{T}_2\\psi\\left[f(g)\\right]$\",\n",
"]\n",
"\n",
"\n",
"def sigmoid(x):\n",
" return 1 / (1 + np.exp(-x))\n",
"\n",
"\n",
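"def convert_color(r, g, b):\n",
"    # map activations to 8-bit channels and format as a hex string like '#1F77B4'\n",
"    h = int(sigmoid(r) * 255) * 256**2 + int(sigmoid(g) * 255) * 256 + int(sigmoid(b) * 255)\n",
"    return \"#{:06X}\".format(h)\n",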
"\n",
"\n",
"c1 = [sigmoid(c1(i)) for i in range(6)]\n",
"c2 = [sigmoid(c2(i)) for i in range(6)]\n",
"c3 = [sigmoid(c3(i)) for i in range(6)]\n",
"\n",
"fig, axs = plt.subplots(1, 3, squeeze=True)\n",
"points = np.array(\n",
" [\n",
" (0, 1),\n",
" (0.5 * np.sqrt(3), 0.5),\n",
" (0.5 * np.sqrt(3), -0.5),\n",
" (0, -1),\n",
" (-0.5 * np.sqrt(3), -0.5),\n",
" (-0.5 * np.sqrt(3), 0.5),\n",
" ]\n",
")\n",
"for i in range(3):\n",
" axs[i].scatter(points[:, 0], points[:, 1], color=[c1, c2, c3][i])\n",
" # plt.plot([0, points[0,0]], [0, points[0, 1]], color='black', zorder=0)\n",
" axs[i].set_xticks([])\n",
" axs[i].set_yticks([])\n",
" axs[i].set_xlim(-1.4, 1.4)\n",
" axs[i].set_ylim(-1.4, 1.4)\n",
" axs[i].set_aspect(\"equal\")\n",
" axs[i].set_title(titles[i], fontsize=8)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, our output looks the same if we apply the rotation either before or after, so our network is G-equivariant."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## G-Equivariant Convolutions with Translation\n",
"\n",
"How can we treat the p4m group? We cannot directly use the continuous convolution definition because the rotations/mirror subgroup is finite, and we cannot use the finite convolution because the translation subgroup is locally compact with infinitely many elements. Instead, we will exploit the structure of the group: it is constructed via a semidirect product, so each group element is a pair of elements. Namely, we can rewrite Equation {eq}`disc-conv` using the constituent subgroups $N \\rtimes H$ and writing elements $g = hn, g^{-1} = n^{-1}h^{-1}$.\n",
"\n",
"```{margin}\n",
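"Writing $g = hn$ is fine because each group element has exactly one such decomposition, so the sum still visits every element once. Writing $g = nh$ would match the pair notation directly, since $(n, e_r)\\cdot (e_n, h) = (n, h)$, whereas moving $n$ to the other side of $h$ requires the conjugation $\\phi(h)(n)$.\n",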
"```\n",
"\n",
"\\begin{equation}\n",
"(f * \\omega)(u) = \\sum_{n \\in N}\\sum_{h \\in H} f\\uparrow^G\\left(un^{-1}h^{-1}\\right)\\omega(hn)\n",
"\\end{equation}\n",
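"\n",
"A small numerical check of this decomposition, written as a sketch with $3 \\times 3$ augmented matrices. The helpers follow the same convention as the `make_h` and `make_n` functions used later in this section but are redefined so the snippet is self-contained, and the particular $h$ and $n$ are arbitrary. It confirms that $g = hn$ has inverse $n^{-1}h^{-1}$, and that moving $n$ to the left of $h$ requires the conjugation $\\phi(h)(n) = hnh^{-1}$, which is again a pure translation:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def make_h(theta, mirror):\n",
"    # rotation by theta, optionally after a mirror x -> -x, in augmented coordinates\n",
"    m = np.diag([-1.0, 1.0, 1.0]) if mirror else np.eye(3)\n",
"    c, s = np.cos(theta), np.sin(theta)\n",
"    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) @ m\n",
"\n",
"\n",
"def make_n(dx, dy):\n",
"    # translation by (dx, dy) in augmented coordinates\n",
"    return np.array([[1.0, 0, dx], [0, 1, dy], [0, 0, 1]])\n",
"\n",
"\n",
"h = make_h(np.pi / 2, True)\n",
"n = make_n(2, -1)\n",
"g = h @ n\n",
"\n",
"# the inverse of g = h n is n^-1 h^-1\n",
"assert np.allclose(np.linalg.inv(g), np.linalg.inv(n) @ np.linalg.inv(h))\n",
"\n",
"# h n = phi(h)(n) h, where phi(h)(n) = h n h^-1 translates by h applied to (dx, dy)\n",
"n_conj = h @ n @ np.linalg.inv(h)\n",
"assert np.allclose(n_conj[:2, :2], np.eye(2))  # no rotation part left\n",
"assert np.allclose(g, n_conj @ h)\n",
"print(np.round(n_conj[:2, 2], 6))  # translation part of phi(h)(n)\n",
"```\n",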
"\n",
"Now we must treat the fact that there are an infinite number of elements in $N$ (the translations). We can simply choose the kernel function ($\\omega$) to only have support ($\\omega(g) \\neq 0$) at locations we want, and that will simplify the sum. This may seem ad-hoc, but remember we already made choices like not including 45° rotations. Systematic ways to restrict kernels to \"neighborhoods\" of the group are given in {cite}`finzi2020generalizing`, and you can find a rigorous derivation specifically for p4 in {cite}`romero2020attentive` or {cite}`cohen2016group`.\n",
"\n",
"```{margin}\n",
"I have a hidden cell below which does a bit of magic. It makes the group elements (numpy arrays) hashable. That in turn allows me to cache functions, enabling much faster speeds. This code would be unusable otherwise due to all the nested loops.\n",
"```\n",
"\n",
"Our goal for the p4m group is image data, so we'll limit the support of the kernel to only integer translations (like pixels) within a $5\\times5$ window. This simply reduces our sum over the normal subgroup ($N$). We can now begin our implementation. We'll start by loading an image which will serve as our function. It is a $32\\times32$ RGB image. Remember that we need to allow points to have 3 dimensions, where the third dimension is always 1 to accommodate our augmented space."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"# From https://gist.github.com/Susensio/61f4fee01150caaac1e10fc5f005eb75\n",
"\n",
"from functools import lru_cache, wraps\n",
"\n",
"\n",
"def np_cache(*args, **kwargs):\n",
" \"\"\"LRU cache implementation for functions whose FIRST parameter is a numpy array\n",
" >>> array = np.array([[1, 2, 3], [4, 5, 6]])\n",
" >>> @np_cache(maxsize=256)\n",
" ... def multiply(array, factor):\n",
" ... print(\"Calculating...\")\n",
" ... return factor*array\n",
" >>> multiply(array, 2)\n",
" Calculating...\n",
" array([[ 2, 4, 6],\n",
" [ 8, 10, 12]])\n",
" >>> multiply(array, 2)\n",
" array([[ 2, 4, 6],\n",
" [ 8, 10, 12]])\n",
" >>> multiply.cache_info()\n",
" CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)\n",
"\n",
" \"\"\"\n",
"\n",
" def decorator(function):\n",
" @wraps(function)\n",
" def wrapper(np_array, *args, **kwargs):\n",
" hashable_array = array_to_tuple(np_array)\n",
" return cached_wrapper(hashable_array, *args, **kwargs)\n",
"\n",
" @lru_cache(*args, **kwargs)\n",
" def cached_wrapper(hashable_array, *args, **kwargs):\n",
" array = np.array(hashable_array)\n",
" return function(array, *args, **kwargs)\n",
"\n",
" def array_to_tuple(np_array):\n",
" \"\"\"Iterates recursivelly.\"\"\"\n",
" try:\n",
" return tuple(array_to_tuple(_) for _ in np_array)\n",
" except TypeError:\n",
" return np_array\n",
"\n",
" # copy lru_cache attributes over too\n",
" wrapper.cache_info = cached_wrapper.cache_info\n",
" wrapper.cache_clear = cached_wrapper.cache_clear\n",
"\n",
" return wrapper\n",
"\n",
" return decorator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load image and drop alpha channel\n",
"W = 32\n",
"try:\n",
" func_vals = plt.imread(\"quadimg.png\")[..., :3]\n",
"except FileNotFoundError as e:\n",
" # maybe on google colab\n",
" import urllib.request\n",
"\n",
" urllib.request.urlretrieve(\n",
" \"https://raw.githubusercontent.com/whitead/dmol-book/master/dl/quadimg.png\",\n",
" \"quadimg.png\",\n",
" )\n",
" func_vals = plt.imread(\"quadimg.png\")[..., :3]\n",
"# we pad it with zeros to show boundary\n",
"func_vals = np.pad(\n",
" func_vals, ((1, 1), (1, 1), (0, 0)), mode=\"constant\", constant_values=0.2\n",
")\n",
"\n",
"\n",
"def pix_func(x):\n",
" # clip & squeeze & round to account for transformed values\n",
" xclip = np.squeeze(np.clip(np.round(x), -W // 2 - 1, W // 2)).astype(int)\n",
" # points are centered, fix that\n",
" xclip += [W // 2, W // 2, 0]\n",
" # add 1 to account for padding\n",
" return func_vals[xclip[..., 0] + 1, xclip[..., 1] + 1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_func(f, ax=None):\n",
" if ax is None:\n",
" plt.figure(figsize=(2, 2))\n",
" ax = plt.gca()\n",
" gridx, gridy = np.meshgrid(\n",
" np.arange(-W // 2, W // 2), np.arange(-W // 2, W // 2), indexing=\"ij\"\n",
" )\n",
" # make it into batched x,y indices and add dummy 1 indices for augmented space\n",
" batched_idx = np.vstack(\n",
" (gridx.flatten(), gridy.flatten(), np.ones_like(gridx.flatten()))\n",
" ).T\n",
" ax.imshow(f(batched_idx).reshape(W, W, 3), origin=\"upper\")\n",
" ax.axis(\"off\")\n",
"\n",
"\n",
"plot_func(pix_func)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's define our G-function transform so that we can transform our function with group elements. We'll apply the element $r^2st_{12,-8}$ to our function: a translation by $(12, -8)$, followed by a mirror and a 180° rotation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_h(rot, mirror):\n",
" \"\"\"Make h subgroup element\"\"\"\n",
" m = np.eye(3)\n",
" if mirror:\n",
" m = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])\n",
" r = np.array(\n",
" [[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), np.cos(rot), 0], [0, 0, 1]]\n",
" )\n",
" return r @ m\n",
"\n",
"\n",
"def make_n(dx, dy):\n",
" \"\"\"Make normal subgroup element\"\"\"\n",
" return np.array([[1, 0, dx], [0, 1, dy], [0, 0, 1]])\n",
"\n",
"\n",
"def g_func_trans(g, f):\n",
" \"\"\"compute g-function transform\"\"\"\n",
"\n",
" @np_cache(maxsize=W**3)\n",
" def fxn(x, g=g, f=f):\n",
" ginv = np.linalg.inv(g)\n",
" return f(ginv.reshape(1, 3, 3) @ x.reshape(-1, 3, 1))\n",
"\n",
" return fxn\n",
"\n",
"\n",
"g = make_h(np.pi, 1) @ make_n(12, -8)\n",
"tfunc = g_func_trans(g, pix_func)\n",
"plot_func(tfunc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we need to create our lifting and projecting maps to go from functions over points to functions over group elements. Remember, our lifting function just evaluates $f$ at $gx_0$, the point to which the group element moves the origin."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# enumerate stabilizer subgroup (rotations/mirrors)\n",
"stabilizer = []\n",
"for i in range(4):\n",
" for j in range(2):\n",
" stabilizer.append(make_h(i * np.pi / 2, j))\n",
"\n",
"\n",
"def lift(f):\n",
" \"\"\"lift f into group\"\"\"\n",
" # create new function from original\n",
" # that is f(gx_0)\n",
" @np_cache(maxsize=W**3)\n",
" def fxn(g, f=f):\n",
" return f(g @ np.array([0, 0, 1]))\n",
"\n",
" return fxn\n",
"\n",
"\n",
"def project(f):\n",
" \"\"\"create projected function over space\"\"\"\n",
"\n",
" @np_cache(maxsize=W**3)\n",
" def fxn(x, f=f):\n",
" # x may be batched so we have to allow it to be N x 3\n",
" x = np.array(x).reshape((-1, 3))\n",
" out = np.zeros((x.shape[0], 3))\n",
" for i, xi in enumerate(x):\n",
" # find coset gH\n",
" g = make_n(xi[0], xi[1])\n",
" # loop over coset\n",
" for h in stabilizer:\n",
" ghi = g @ h\n",
" out[i] += f(ghi)\n",
" out[i] /= len(stabilizer)\n",
" return out\n",
"\n",
" return fxn\n",
"\n",
"\n",
"# try them out\n",
"print(\"lifted\", lift(pix_func)(g))\n",
"print(\"projected\", project(lift(pix_func))([12, -8, 1]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to create our kernel function $\\omega$. Rather than make a function of the group elements, we'll use indices to represent the different group elements. Remember we need to apply a sigmoid at the end so that we stay in color space. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kernel_width = 5  # must be odd\n",
"# make some random values for the kernel (untrained)\n",
"# kernel shape is |H| x 5 x 5 x 3 x 3: one 5 x 5 window of translations for each of the\n",
"# |H| = 8 stabilizer elements, and a 3 x 3 matrix for 3 color channels in and 3 out\n",
"kernel = np.random.uniform(\n",
"    -0.5, 0.5, size=(len(stabilizer), kernel_width, kernel_width, 3, 3)\n",
")\n",
"# translation offsets -2, ..., 2 (note that -kernel_width // 2 would give -3)\n",
"half = kernel_width // 2\n",
"\n",
"\n",
"def conv(f, p=kernel):\n",
"    @np_cache(maxsize=W**4)\n",
"    def fxn(u):\n",
"        # It is possible to do this without inner for\n",
"        # loops over convolution (use a standard conv),\n",
"        # but we do this for simplicity.\n",
"        result = 0\n",
"        for hi, h in enumerate(stabilizer):\n",
"            for nix in range(-half, half + 1):\n",
"                for niy in range(-half, half + 1):\n",
"                    # term f(u n^-1 h^-1) omega(h n) of the sum over g = h n\n",
"                    result += (\n",
"                        f(u @ make_n(-nix, -niy) @ np.linalg.inv(h))\n",
"                        @ p[hi, nix + half, niy + half]\n",
"                    )\n",
"        return sigmoid(result)\n",
"\n",
"    return fxn\n",
"\n",
"\n",
"# compute convolution\n",
"cout = conv(lift(pix_func))\n",
"# try it out an a group element\n",
"cout(g)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"At this point our convolution layer has returned a function over all group elements. We can visualize this by viewing each stabilizer element individually across the normal subgroup. This is like plotting each coset with a choice of representative element."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_coset(h, f, ax):\n",
" \"\"\"plot a function over group elements on cosets given representative g\"\"\"\n",
" gridx, gridy = np.meshgrid(\n",
" np.arange(-W // 2, W // 2), np.arange(-W // 2, W // 2), indexing=\"ij\"\n",
" )\n",
" # make it into batched x,y indices and add dummy 1 indices for augmented space\n",
" batched_idx = np.vstack(\n",
" (gridx.flatten(), gridy.flatten(), np.ones_like(gridx.flatten()))\n",
" ).T\n",
" values = np.zeros((W**2, 3))\n",
" for i, bi in enumerate(batched_idx):\n",
" values[i] = f(h @ make_n(bi[0], bi[1]))\n",
" ax.imshow(values.reshape(W, W, 3), origin=\"upper\")\n",
" ax.axis(\"off\")\n",
"\n",
"\n",
"# try it with mirror\n",
"plt.figure(figsize=(2, 2))\n",
"plot_coset(make_h(0, 1), lift(pix_func), ax=plt.gca())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will plot our convolution for each possible coset representative. This code is *incredibly* inefficient because we have so many loops in plotting and the convolution. This is where the `np_cache` from above helps. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"stabilizer_names = [\"$e$\", \"$r$\", \"$r^2$\", \"$r^3$\", \"$s$\", \"$rs$\", \"$r^2s$\", \"$r^3s$\"]\n",
"fig, axs = plt.subplots(2, 4, figsize=(8, 4))\n",
"axs = axs.flatten()\n",
"for i, (n, h) in enumerate(zip(stabilizer_names, stabilizer)):\n",
" ax = axs[i]\n",
" plot_coset(h, cout, ax)\n",
" ax.set_title(n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These convolutions are untrained, so the output is a diffuse, random combination of pixels. You can see each piece of the function broken out by stabilizer group element (the rotation/mirroring). We can stack multiple layers of these convolutions if we want. At the end, we want to get back to our space with the projection. \n",
"Let us now show our layers are equivariant by applying a G-function transform to input and output. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, 3, squeeze=True)\n",
"plot_func(project(cout), ax=axs[0])\n",
"axs[0].set_title(r\"$\\psi\\left[f(g)\\right]$\")\n",
"\n",
"# make a transformation for visualization purposes\n",
"g = make_h(np.pi, 0) @ make_n(-10, 16)\n",
"tfunc = g_func_trans(g, project(cout))\n",
"plot_func(tfunc, ax=axs[1])\n",
"axs[1].set_title(r\"$\\mathbb{T}\\psi\\left[f(g)\\right]$\")\n",
"\n",
"tcout = project(conv(lift(g_func_trans(g, pix_func))))\n",
"\n",
"plot_func(tcout, ax=axs[2])\n",
"axs[2].set_title(r\"$\\psi\\left[\\mathbb{T}f(g)\\right]$\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This shows that the convolution layer is indeed equivariant. Details not covered here are how to do pooling (if desired) and the choice of nonlinearity. You can find more details on this for the p4m group in Cohen et al. {cite}`cohen2016group`. This implementation is also quite slow! Kondor et al. {cite}`kondor2018generalization` show how you can reduce the number of operations by identifying sparsity in the convolutions. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Group Representation\n",
"\n",
"p4m was an infinite group but we restricted ourselves to a finite subset. Before we can progress to truly infinite locally compact groups, like SO(3), we need to learn how to systematically represent the group element binary operation. You can find a detailed description of representation theory in Serre {cite}`serre1977linear` and it is covered in Zee {cite}`zee2016`. Thus far, we've discussed the group actions -- how they affect a point. Now we need to describe how to represent them with matrices. This will be a very quick overview of this topic, but representation of groups is a large area with well-established references. There is specifically a great amount of literature about building up these representations, but we'll try to focus on using them since you generally can look-up the representations for most groups we'll operate in. \n",
"\n",
"Let us first define a representation on a group:\n",
"\n",
"```{admonition} Linear Representation of a Group\n",
"Let $G$ be a group on an $n$-dimensional vector space $\\mathcal{X}$ . A linear representation of $G$ is a group homomorphism: $\\rho: G \\rightarrow GL(m,\\mathbb{C})$ where $GL(m, \\mathbb{C})$ is the space of $m\\times m$ square invertible matrices with complex numbers. The representation $\\rho$ should satisfy the following equation\n",
"\n",
"\\begin{equation}\n",
"\\label{rep-def}\n",
"\\rho\\left(g_1\\cdot g_2\\right) = \\rho\\left(g_1\\right) \\rho\\left(g_2\\right)\\; \\forall\\, g_1,g_2 \\in G\n",
"\\end{equation}\n",
"\n",
"where the term $\\rho\\left(g_1\\right) \\rho\\left(g_2\\right)$ is a matrix product. \n",
"```\n",
"\n",
"There are a few things to note about this definition. First, the representation assigns matrices to group elements in such a way that multiplying matrices gives the same representation as getting the representation of the binary operation ($\\rho\\left(g_1\\cdot g_2\\right)$). Second, the matrices have to be square and invertible. This follows from the requirement that group elements must have an inverse, so naturally we need invertible matrices. The invertible requirement also means we often need to allow complex numbers. Third, the **degree** of the representation ($m$) need not be the same size as the vector space. \n",
"\n",
"There is a big detail missing from this definition. Does this have anything to do with how the group element affect a point? No. Consider that $\\rho(g_i) = 1$ is a valid representation, as in it satisfies the definition. Yet $1$ is not the correct way to transforms points with group elements. If we go further and say that the representation is *injective* (one to one), then we must have a unique representation for every group element. That is called a **faithful representation**. This is better, but it turns out there are still multiple faithful representations for a group. \n",
"\n",
"Remember the way a group affects a point is a **group action**, which maps from the direct product of $G, \\mathcal{X}$ (i.e., a tuple like $(g_2, x)$ to $\\mathcal{X}$). A group action, if it is linear, can also be a representation. Consider that we write the group action $\\pi$ (how a group element affects a point) as $\\pi(g)(x) = x'$. You can convert this into a square matrix in $\\mathcal{X}\\times\\mathcal{X}$ by considering how each element of $x'$ is affected the element in $x$. This matrix can be further shown to be in $GL(m, \\mathcal{X})$ and a representation by relying on its linearity. There isn't a special word for this, but often groups are defined in terms of these special matrices that both transforms points and are valid representations (e.g., SO(3)). They are then called the **defining representation** or **fundamental representation**. \n",
"\n",
"\n",
"Let's now see group representations on the examples above that are both group actions and representations. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{tabbed} ⬡ Finite Group $Z_6$ \n",
"\n",
"Our group action defined above was modular arithmetic, which is not linear and so we cannot use it to construct representation. There are multiple representation for cyclic groups like $Z_6$. If you're comfortable with complex numbers, you can build circulant matrices of $6$th roots of unity. If that confuses you, like it does me, then you can also just view this group like a rotation group. Just like how if you rotate enough times you get back to the beginning, you can also use rotation matrices of $360 / 6 = 60^{\\circ}$. This requires a 2D vector representation though for the space. With this choice, a representation is:\n",
"\n",
"$$\n",
"\\left[\\begin{array}{lr}\n",
"\\cos\\frac{k2\\pi}{6} & -\\sin\\frac{k2\\pi}{6}\\\\\n",
"\\sin\\frac{k2\\pi}{6} & \\cos\\frac{k2\\pi}{6}\\\\\n",
"\\end{array}\\right],\\, k \\in \\left\\{0, 1, 2, 3, 4, 5\\right\\}\n",
"$$\n",
"\n",
"Let's verify that this is a representation by checking that $r^2\\cdot\\,r^4 = e$\n",
"\n",
"$$\n",
"\\left[\\begin{array}{lr}\n",
"\\cos\\frac{4\\pi}{6} & -\\sin\\frac{4\\pi}{6}\\\\\n",
"\\sin\\frac{4\\pi}{6} & \\cos\\frac{4\\pi}{6}\\\\\n",
"\\end{array}\\right]\\left[\\begin{array}{lr}\n",
"\\cos\\frac{8\\pi}{6} & -\\sin\\frac{k2\\pi}{6}\\\\\n",
"\\sin\\frac{8\\pi}{6} & \\cos\\frac{8\\pi}{6}\\\\\n",
"\\end{array}\\right] = \\left[\\begin{array}{lr}\n",
"\\cos\\frac{12\\pi}{6} & -\\sin\\frac{12\\pi}{6}\\\\\n",
"\\sin\\frac{12\\pi}{6} & \\cos\\frac{12\\pi}{6}\\\\\n",
"\\end{array}\\right] = \\left[\\begin{array}{lr}\n",
"1 & 0\\\\\n",
"0 & 1\\\\\n",
"\\end{array}\\right]\n",
"$$\n",
"\n",
"You can also verify that this is a group action by repeated application to the point $(1,0)$, which will rotate around the unit circle. \n",
"\n",
"```\n",
"\n",
"```{tabbed} ▩ Locally Compact p4m\n",
"\n",
"Our group action defined above for the translation elements is not linear. To define a representation, we can use [**Affine Matrices**]() which are $3\\times3$ invertible square matrices. That means even though our goal is 2D data, we need to introduce a 3rd dimension: $(x, y, 1)$. The 3rd dimension is always $1$ and is called the augmented dimension. To specify a group representation we simply need to multiply an affine matrix for rotation, reflection, and translation (*in that order!*). These are:\n",
"\n",
"Rotation:\n",
"\n",
"$$\n",
"\\left[\\begin{array}{lcr}\n",
"\\cos\\frac{k\\pi}{4} & -\\sin\\frac{k\\pi}{4} & 0\\\\\n",
"\\sin\\frac{k\\pi}{4} & \\cos\\frac{k\\pi}{4} & 0\\\\\n",
"0 & 0 & 1\\\\\n",
"\\end{array}\\right] ,\\, k \\in \\left\\{0, 1, 2, 3\\right\\}\n",
"$$\n",
"\n",
"Reflection:\n",
"\n",
"$$\n",
"\\left[\\begin{array}{lcr}\n",
"1 & 0 & 0\\\\\n",
"0 & -1 & 0\\\\\n",
"0 & 0 & 1\\\\\n",
"\\end{array}\\right]\n",
"$$\n",
"\n",
"Translation:\n",
"\n",
"$$\n",
"\\left[\\begin{array}{lcr}\n",
"1 & 0 & \\Delta x\\\\\n",
"0 & 1 & \\Delta y\\\\\n",
"0 & 0 & 1\\\\\n",
"\\end{array}\\right]\n",
"$$\n",
"\n",
"It is a bit more involved to verify this is a group representation, but you can try a few group element products to convince yourself. Do not forget the special homomorphism (conjugate $\\phi(h)(n)$) for semidirect products when multiplying group element, which ensures the correct behavior if rearrange the order of the matrices. \n",
"\n",
"```\n",
"\n",
"\n",
"```{tabbed} ⚽ SO(3) Group\n",
"\n",
"A representation of the SO(3) group is just its usual group action: the product of 3 3D rotation matrices $R_z(\\alpha)R_y(\\beta)R_z(\\gamma)$ where $\\alpha, \\gamma \\in [0, 2\\pi], \\beta \\in [0, \\pi]$ and the matrices are defined above.\n",
"\n",
"```\n",
"\n",
"### Unitary Representations\n",
"\n",
"One minor detail is that if we have some representation $\\rho(g_1)\\rho(g_2) = \\rho(g_1\\cdot g_2)$, then we could make a \"new\" representation $\\rho'$ by inserting some invertible matrix $\\mathbf{S}$:\n",
"\n",
"$$\n",
"\\rho'(g) = \\mathbf{S}^{-1}\\rho(g)\\mathbf{S}\n",
"$$\n",
"\n",
"because \n",
"\n",
"$$\n",
"\\rho'(g_1)\\rho'(g_2) = \\mathbf{S}^{-1}\\rho(g_1)\\mathbf{S}\\mathbf{S}^{-1}\\rho(g_2)\\mathbf{S}\n",
"$$\n",
"\n",
"$$\n",
"= \\mathbf{S}^{-1}\\rho(g_1)\\rho(g_2)\\mathbf{S} = \\rho'(g_1\\cdot g_2)\n",
"$$\n",
"\n",
"There is a theorem, the Unitarity Theorem, that says we can always choose an $\\mathbf{S}$ (for finite groups) such that we make our representation **unitary**. Unitary means that $\\rho(g)^{-1} = \\rho^{\\dagger}(g)$ for any $g$. Remember that $\\rho(g)$ is a matrix, so $\\rho^{\\dagger}(g)$ is the adjoint (transpose and complex conjugate) of the matrix. Thus, without any loss of generality we can assume all representations we use are unitary or can be trivially converted to unitary. "
]
},
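{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a finite group, the change of basis in the Unitarity Theorem can be built by averaging over the group: $H = \\sum_g \\rho(g)^{\\dagger}\\rho(g)$ satisfies $\\rho(g)^{\\dagger} H \\rho(g) = H$, so conjugating by $H^{1/2}$ gives a unitary representation (for compact groups the sum becomes an integral over the Haar measure). Below is a minimal NumPy sketch of that construction on the $Z_6$ example; the random $\\mathbf{S}$ and all variable names are only for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rho_z6(k):\n",
"    t = 2 * np.pi * k / 6\n",
"    return np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])\n",
"\n",
"\n",
"def is_unitary(M):\n",
"    return np.allclose(M.conj().T @ M, np.eye(len(M)))\n",
"\n",
"\n",
"# a random invertible S gives an equivalent but (generally) non-unitary representation\n",
"rng = np.random.default_rng(0)\n",
"S = rng.normal(size=(2, 2)) + 3 * np.eye(2)\n",
"S_inv = np.linalg.inv(S)\n",
"rho_prime = [S_inv @ rho_z6(k) @ S for k in range(6)]\n",
"print('rho_prime unitary:', all(is_unitary(r) for r in rho_prime))\n",
"\n",
"# group-averaged Gram matrix: H = sum over g of rho(g)^dagger rho(g)\n",
"H = sum(r.conj().T @ r for r in rho_prime)\n",
"# T = H^(1/2) from the eigendecomposition of the Hermitian matrix H\n",
"evals, evecs = np.linalg.eigh(H)\n",
"T = evecs @ np.diag(np.sqrt(evals)) @ evecs.conj().T\n",
"T_inv = np.linalg.inv(T)\n",
"rho_unitary = [T @ r @ T_inv for r in rho_prime]\n",
"print('T rho_prime T^-1 unitary:', all(is_unitary(r) for r in rho_unitary))\n",
"```\n",
"\n",
"Conjugation preserves the homomorphism property, so `rho_unitary` is still a representation of $Z_6$, now with unitary matrices."
]
},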
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Irreducible representations\n",
"\n",
"These representations that both describe the group action and how group elements affect on another are typically **reducible**, meaning if you drop the requirement that they also describe group action they can be simplified. The process of reducing representations is again a topic better explored in other references {cite}`serre1977linear`, but here I will sketch out the important ideas. The main idea is that we can form decomposable unitary representation matrices that are composed of smaller block matrices and zero blocks. These smaller blocks, $\\rho_i(g)$, are *irreducible* --- they cannot be broken into smaller blocks and zeros\n",
"\n",
"\\begin{equation}\n",
" \\rho(g) = \\mathbf{S}^{-1} \\begin{pmatrix} \n",
"\\rho_0(g) & 0 & \\cdots & 0 \\\\\n",
"0 & \\rho_1(g) & \\cdots & 0 \\\\\n",
"\\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
"0 & 0 & \\cdots & \\rho_k(g) \\\\\n",
"\\end{pmatrix} \\mathbf{S}\n",
"\\end{equation}\n",
"\n",
"This block notation is consistent, regardless of $g$. That is a strong statement because $\\rho(g_1)\\rho(g_2)$ should give back an element in $\\mathcal{G}$ --- $\\rho(g')$ --- with the same block structure. What is interesting about this notation is that each block is then *itself* a representation. We could just pick $\\rho_0(g)$ as a representation, and if this block structure is true for all $g$, then $\\rho_0(g_1)\\rho_0(g_2)$ should give back something with non-zero elements only in the rows/columns of the $\\rho_0(g)$ block. We could also combine $\\rho_0(g)$ and $\\rho_1(g)$ or even $\\rho_0(g)$ and $\\rho_2(g)$. Thus, these irreducible representations (**irreps**) are the pieces that we use to build any other representation. The irreducible representations are all dimension 1 if $\\mathcal{G}$ is abelian, but otherwise irreducible representations are square matrices. \n",
"\n",
"To add some notation, we use **direct sums** to write the bigger unitary representation:\n",
"\n",
"\\begin{equation}\n",
"\\rho(g) = \\mathbf{S}^{-1} \\begin{pmatrix} \n",
"\\rho_0(g) & 0 & \\cdots & 0 \\\\\n",
"0 & \\rho_1(g) & \\cdots & 0 \\\\\n",
"\\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
"0 & 0 & \\cdots & \\rho_k(g) \\\\\n",
"\\end{pmatrix} \\mathbf{S} = \\rho_0(g)\\oplus\\rho_1(g)\\oplus\\ldots\\oplus\\rho_k(g)\n",
"\\end{equation}\n",
"\n",
"and we could just stop the direct sum wherever we would like. The number of irreducible representations is finite for finite groups and infinite for locally compact groups. These irreducible representations are like orthonormal basis-functions or basis-vectors from Hilbert spaces. From the Peter-Weyl theorem, they specifically can be transformed to create a complete basis-set for integrable ($L^2$) functions of the group. \"Transformed\" because irreducible representations are representations of $g$ ( output a matrix), but we need them to output a scalar to be a basis-set for an integrable function. \n",
"\n",
"Where do we get these integrable functions? Recall we can use lifting to move functions of our space to our group and then these irreducible representations enable us to represent the functions as a (direct) sum of coefficients of the irreducible representations. Remember, each individual irreducible representation is itself a valid representation, but they are not all faithful and so you need some of them to uniquely represent all group elements and all of them to represent arbitrary functions over the group. One final note, these irreducible representations have been essentially mapped out for all groups and thus we look them up in table rather than try to construct them."
]
},
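{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the block picture concrete, here is a sketch (our own example) that decomposes the 2D rotation representation of $Z_6$ from above. Because $Z_6$ is abelian, a single complex change of basis diagonalizes every $\\rho(g)$ at once, leaving two 1D irreps, $e^{2\\pi i k/6}$ and $e^{-2\\pi i k/6}$. The names follow the equation above, $\\rho(g) = \\mathbf{S}^{-1}\\left(\\rho_0(g)\\oplus\\rho_1(g)\\right)\\mathbf{S}$; `V` is just a helper for the eigenvectors.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rho_z6(k):\n",
"    t = 2 * np.pi * k / 6\n",
"    return np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])\n",
"\n",
"\n",
"# columns are eigenvectors shared by every 2D rotation matrix\n",
"V = np.array([[1, 1], [-1j, 1j]]) / np.sqrt(2)\n",
"# in the notation rho(g) = S^-1 D(g) S used above, S is the inverse of V\n",
"S = np.linalg.inv(V)\n",
"\n",
"for k in range(6):\n",
"    D = S @ rho_z6(k) @ V  # direct sum of two 1x1 irreps\n",
"    print(k, np.round(np.diag(D), 3), 'diagonal:', np.allclose(D, np.diag(np.diag(D))))\n",
"```"
]
},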
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## G-Equivariant Convolutions on Compact Groups\n",
"\n",
"Now we can represent functions on groups as a direct sum (list of increasing length vectors) of coefficients on the irreps as\n",
"\n",
"$$\n",
"f(g) = f_0\\cdot\\rho_0(g)\\oplus\\vec{f}_1\\cdot\\rho_1(g)\\oplus\\ldots\\oplus\\vec{f}_k\\cdot\\rho_k(g)\n",
"$$ (fft)\n",
"\n",
"where the direct sum notation $\\oplus$ is just a shorthand that means put all this stuff into a big matrix of increasingly large blocks. The individual $\\vec{f}_i$s are called **fragments** to distinguish them from the actual irreps (which are functions).\n",
"\n",
"We could even more compactly write this as $f(g) = f_0\\oplus\\vec{f}_1\\oplus\\ldots\\oplus\\vec{f}_k$. We'd like to revise the G-Equivariant convolution layer equation:\n",
"\n",
"\\begin{equation}\n",
"\\psi(f) = (f * \\omega)(u) = \\int_G f\\uparrow^G\\left(ug^{-1}\\right)\\omega\\uparrow^G\\left(g\\right)\\,d\\mu(g)\n",
"\\end{equation}\n",
"\n",
"to use irreps now. *It turns out that the convolutional integral becomes a product of irreps.* This is just like how convolutions in Fourier space become products. {cite}`kondor2018generalization`. Our expression simplifies to:\n",
"\n",
"$$\n",
"\\psi(f) = f_0 w_0\\oplus \\vec{f}_1 w_1 \\oplus\\ldots\\oplus \\vec{f}_k w_k\n",
"$$ (compact-gequiv)\n",
"\n",
"This result says we just multiply the irreducible representations by weights, but do not mix across irreps. The weights become matrices if we start to allow multiple channels (multiple fragments). An important point then is how we actually can learn if there is no communication between irreps. That's where the nonlinearity comes in. It is discussed in more depth below, but the most common nonlinearity is to take a tensor product (all irreps times all irreps) and then reduce that by multiplying the larger rank tensor by a special tensor for the group called Clebsch-Gordan coefficients that reduces it equivariantly back down to the direct sum of irreps. This enables mixing between the irreps and is nonlinear."
]
},
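{
"cell_type": "markdown",
"metadata": {},
"source": [
"The convolution-to-product statement can be checked on the simplest compact group, the cyclic group $Z_n$. In this sketch (our own illustration, not the implementation from the cited paper), the irreps of $Z_n$ are the 1D characters $k \\mapsto e^{2\\pi i jk/n}$, so the fragments of a function on $Z_n$ are its discrete Fourier coefficients and the group convolution becomes an elementwise product of fragments. NumPy's FFT uses the complex-conjugate characters, which does not change the result.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"n = 8\n",
"rng = np.random.default_rng(1)\n",
"f = rng.normal(size=n)  # a function on Z_n\n",
"w = rng.normal(size=n)  # a filter on Z_n\n",
"\n",
"# group convolution on Z_n: (f * w)(u) = sum over g of f(u g^-1) w(g),\n",
"# where u g^-1 is (u - g) mod n in additive notation\n",
"group_conv = np.array([sum(f[(u - g) % n] * w[g] for g in range(n)) for u in range(n)])\n",
"\n",
"# multiply the fragments (Fourier coefficients) and transform back\n",
"conv_via_irreps = np.fft.ifft(np.fft.fft(f) * np.fft.fft(w)).real\n",
"\n",
"print(np.allclose(group_conv, conv_via_irreps))\n",
"```"
]
},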
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Irreducible representations on SO(3)\n",
"\n",
"There is an infinite sequence of possible irreducible representations for the SO(3) group known as the Wigner D-matrices. They must be of odd dimension, and so are traditionally written as the sequence $2l + 1$ where $l$ is an integer that serves as the irreducible representation index. The Wigner D-matrices are square with dimension $2l + 1$ and are a function of the group element (e.g., the angles of the rotation). This may be surprising, that the irreducible representations can be of greater dimension than our reducible representation $R_z(\\alpha)R_y(\\beta)R_z(\\gamma)$. Of course, remember the matrix blocks we built above --- we can keep making these representations bigger. But do they have any intuitive meaning? One way to think about irreps larger than the fundamental representation is to consider SO(3) acting on 3 dimensional $n$th degree monomials rather than points: $x^iy^jz^k$ where $l = i + j + k$. The trivial representation works on the 0th degree monomial ($l = 0$), the $l = 1$ irrep has three possible monomials ($x, y, z$), the $l=2$ irrep has 5 possible monomials (excluding a redundant term) $x^2$, $y^2$, $z^2$, $xy$, $xz. You can find a nice description of [this here](https://math.stackexchange.com/a/40141). \n",
"\n",
"What $l$ should you choose? It depends on your input and output. If you choose $l = 0$, you can only represent scalars (not points/vectors). If you choose $l = 1$, you can represent vectors. You can always pick a larger $l$, but if you pick a lower $l$ you will become invariant to the higher-order geometric structure. \n",
"\n",
"```{margin}\n",
"The reason you just plug in the values into the spherical harmonics\n",
"is you're (1) turning your data into a dirac-delta function and (2) taking its\n",
"integral over all the irreps, which turns out to be the same as plugging in the value.\n",
"```\n",
"\n",
"Our choice of Euler Angles (zyz rotation) means that the Wigner D-matrices turn into spherical harmonics. Now how do we get our input data, into the irrep for the group SO(3)? You just plug the input coordinates/features into the spherical harmonics. Have multiple scalar features (e.g., charge, atomic number)? Simply add another axis to the irreps and create multiple \"channels\". Another detail is that our weight size seems to be set by the irrep size. How do we get wider layers (more weights)? The same way: by adding more channels to each irrep. \n",
"\n",
"\n",
"### SO(3) Nonlinearity & Mixing\n",
"\n",
"There are two equations for equivariant nonlinearity in SO(3), and they are sometimes combined. The first nonlinearity is a **Clebsch-Gordan tensor product** and enables mixing between irreps. The equation is\n",
"\n",
"$$\n",
"\\vec{f}_i' = \\sum_j\\sum_k \\mathrm{CG}_{j,k,i} \\cdot \\vec{f}_j\\vec{f}_k \n",
"$$ (cg-nl)\n",
"\n",
"where $\\mathrm{CG}_{j,k,i}$ are the Clebsch-Gordan coefficients that ensure we maintain equivariance after multiplying all irreps with all irreps (and do a change of basis). This expression is sometimes written as $\\mathrm{CG}_{j,k,i}\\,\\vec{f} \\otimes \\vec{f}$ \n",
"\n",
"As before, $\\vec{f}_i$ are the fragments (coefficients) on the irreps that represent our function $f(g)$ of the group. The fragments are usually computed directly by plugging in the coordinates into spherical harmonic equations. {cite}`kondor2018clebsch` showed that this is itself nonlinear, and thus a complete layer with nonlinearity would be that equation combined with Equation {eq}`compact-gequiv`. We may also choose to skip some of the terms, since this is an exepnsive equation. \n",
"\n",
"There is another kind of nonlinearity that is equivariant called gated nonlinearities {cite}`weiler20183d`. The equation is simple; just compute the magnitude of each of the irrep fragments $\\vec{f}_i$ and put it through a traditional neural network nonlinearity (e.g., ReLU):\n",
"\n",
"$$\n",
"\\sigma_{\\textrm{gated}}(\\vec{f}_i) = \\sigma\\left(|\\vec{f}_i|\\right)\\vec{f}_i\n",
"$$\n",
"\n",
"The gated nonlinearity is sometimes used instead of Equation {eq}`cg-nl` or as an extra step after it {cite}`thomas2018tensor`.\n",
"\n",
"At the end of the network, most of the time we simply take the $f_0$ scalar (the $l=0$ irrep fragment). We may have multiple channels, so we don't have one $f_0$. But, often we are doing classification or predicting energy (and its derivative, forces) and thus want a $l = 0$ feature. The Clebsch-Gordan tensor products are essential to ensure mixing between the spatial information at the higher dimensional irrep fragments and scalar input features like atomic number. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SO(3) Equivariant Example\n",
"\n",
"Let's implement a non-differentiable version of the equations above for the SO(3) group. To begin, let's write the code to convert our points into their irreps. Our code is not differentiable, so we won't be able to train.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.special import sph_harm\n",
"\n",
"\n",
"def cart2irreps(x, l):\n",
" # convert to spherical coords to eval\n",
" N = x.shape[0]\n",
" r = np.linalg.norm(x, axis=-1)\n",
" azimuth = np.arctan2(x[:, 1], x[:, 0])\n",
" polar = np.arccos(x[:, 2], r)\n",
" f = []\n",
" for li in range(l):\n",
" fi = []\n",
" for m in range(-li, li + 1):\n",
" y = sph_harm(m, li, azimuth, polar)\n",
" # convert to real\n",
" if m < 0:\n",
" y = np.sqrt(2) * (-1) ** m * y.imag\n",
" elif m > 0:\n",
" y = np.sqrt(2) * (-1) ** m * y.real\n",
" fi.append(y.real)\n",
" fi = np.array(fi)\n",
" f.append(fi.reshape(N, 2 * li + 1))\n",
" return f\n",
"\n",
"\n",
"def print_irreps(f):\n",
" for i in range(len(f)):\n",
" if len(f[0].shape) == 3:\n",
" print(f\"irrep {i} ({f[i].shape[-1]} channels)\")\n",
" else:\n",
" print(f\"irrep {i} (no channels)\")\n",
" print(f[i])\n",
"\n",
"\n",
"points = np.random.rand(2, 3)\n",
"# make them be on unit sphere\n",
"points /= np.linalg.norm(points, axis=-1)[:, np.newaxis]\n",
"M = 3 # number of irreps\n",
"f = cart2irreps(points, M)\n",
"print_irreps(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We chose to use 3 irreps here and get a fragment vector for each particle at each irrep. This gives a scalar for $l=0$ irrep, a 3 dimensional vector for $l=1$, and a 5 dimensional vector for $l=2$ irrep. It's a choice, but usually you'll see networks encode a 3D point into the $l=1$ irrep because it's the smallest irrep that won't become invariant. This also allows a sort of separation of input features, where we can put scalar properties like mass or element into the $l = 0$ irrep and the point position into the $l=1$ irrep. Often, you'll also see multiple channels (multiple sets of fragments at a given irrep) to add expressiveness.\n",
"\n",
"Notice that $l=0$ is the same for both points - that's because the first spherical harmonic is constant. Another reason to put scalar quantities there!\n",
"\n",
"Let's now implement the linear part - Equation {eq}`compact-gequiv`. We'll have channels now, since otherwise we get a single weight."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def linear(f, W):\n",
" for l in range(len(f)):\n",
" # promote to have channels, if not yet\n",
" if len(f[l].shape) == 2:\n",
" f[l] = f[l][..., None]\n",
" f[l] = np.einsum(\"ijk,kl->ijl\", f[l], W[l])\n",
" return f\n",
"\n",
"\n",
"def init_weights(cin, cout):\n",
" return np.random.randn(M, cin, cout)\n",
"\n",
"\n",
"weights = init_weights(1, 4)\n",
"print(\"Input shapes\", \",\".join([str(f[i].shape) for i in range(M)]))\n",
"h = linear(f, weights)\n",
"print(\"Output shapes\", \",\".join([str(h[i].shape) for i in range(M)]))\n",
"print_irreps(h)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that we now have multiple channels at each irrep.\n",
"\n",
"Now we'll implement the Clebsch-Gordan nonlinearity, Equation {eq}`cg-nl`. We'll use the coefficients in sympy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sympy.physics.quantum.cg import CG\n",
"from sympy import S\n",
"from functools import lru_cache\n",
"\n",
"\n",
"# to speed-up repeated calls, put a cache around it\n",
"@lru_cache\n",
"def cg(i, j, k, l, m, n):\n",
" # to get a float, we wrap input in symbol (S), call\n",
" # doit, and evalf.\n",
" r = CG(S(i), S(j), S(k), S(l), S(m), S(n)).doit().evalf()\n",
" return float(r)\n",
"\n",
"\n",
"# As you can see, the Clebsch-Gordan nonlinearity is a lot!!\n",
"def cgnl(f):\n",
" output = [np.zeros_like(fi) for fi in f]\n",
" # m,n -> outputs\n",
" for i in range(len(f)):\n",
" for j in range(-i, i + 1):\n",
" for k in range(len(f)):\n",
" for l in range(-k, k + 1):\n",
" for m in range(len(f)):\n",
" for n in range(-m, m + 1):\n",
" output[m][:, n] += (\n",
" f[i][:, j] * f[k][:, l] * cg(i, j, k, l, m, n)\n",
" )\n",
" return output\n",
"\n",
"\n",
"print_irreps(cgnl(h))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can make our complete layer! We won't use a gated nonlinearity here, just the Clebsch-Gordan nonlinearity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def cg_net(x, W, l, num_layers):\n",
" f = cart2irreps(x, l)\n",
" for i in range(num_layers):\n",
" f = linear(f, W[i])\n",
" f = cgnl(f)\n",
" return np.squeeze(f[0])\n",
"\n",
"\n",
"num_layers = 3\n",
"L = 3\n",
"channels = 4\n",
"weights = (\n",
" [init_weights(1, channels)]\n",
" + [init_weights(channels, channels) for _ in range(num_layers - 2)]\n",
" + [init_weights(channels, 1)]\n",
")\n",
"\n",
"cg_net(points, weights, L, num_layers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have our irrep features. How do we get an output? If we're trying to output a scalar (regression/classification), we would just take the $l = 0$ irrep fragment. Remember that we have a function with irreps (Equation {eq}`fft`), so getting out a point from a function can be done a few ways. One example is to read out the point at which the function (the product of fragments and spherical harmonics) is maximized. Or you could compute its average via integration.\n",
"\n",
"Let us now check that our network is indeed invariant (we're outputting a single value per point, so invariant). We'll make a rotation and check if our output changes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# random 3x3 matrix\n",
"R = np.random.rand(3, 3)\n",
"# make it a member of SO(3)\n",
"U, _, V = np.linalg.svd(R)\n",
"R = np.dot(U, V)"
]
},
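{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative, here is a sketch (our own; `random_rotation` is a hypothetical helper, not used elsewhere in this chapter) that samples a rotation from a unit quaternion. Normalizing a 4D Gaussian vector gives a uniformly random unit quaternion, and the standard quaternion-to-matrix formula then always yields $\\det R = +1$, with rotations uniformly distributed over SO(3).\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def random_rotation(rng):\n",
"    # normalize a 4D Gaussian to get a uniformly random unit quaternion (w, x, y, z)\n",
"    q = rng.normal(size=4)\n",
"    w, x, y, z = q / np.linalg.norm(q)\n",
"    return np.array(\n",
"        [\n",
"            [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],\n",
"            [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],\n",
"            [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],\n",
"        ]\n",
"    )\n",
"\n",
"\n",
"R_q = random_rotation(np.random.default_rng(3))\n",
"print(np.allclose(R_q @ R_q.T, np.eye(3)), np.isclose(np.linalg.det(R_q), 1.0))\n",
"```"
]
},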
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(cg_net(points, weights, L, num_layers))\n",
"print(cg_net(points @ R, weights, L, num_layers))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*As you can see, something is broken and I need to fix it when I have time.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Equivariant Neural Networks with Constraints\n",
"\n",
"You do not need to use irreducible representations. It is currently in 2022 the dominant paradigm due to its good accuracy. One alternative is to work in the defining/faithful representation and put equivariant constraints on your network weights. This approach is quite nice because the implementation is independent of the group. It also works for finite groups. Let's see an example of this approach via the library released by the authors called Equivariant MLP (`emlp`){cite}`finzi2021emlp`\n",
"\n",
"We'll create an SO(3) equivariant neural network and check that it is equivariant to rotations. We begin by defining our group and its representation. I'll show a few elements too, to demonstrate that this is the faithful representation and not the irreducible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from emlp.groups import SO, S\n",
"import emlp.reps as reps\n",
"import emlp\n",
"import haiku as hk\n",
"import emlp.nn.haiku as ehk\n",
"import jax.numpy as jnp\n",
"\n",
"so3_rep = reps.V(SO(3))\n",
"# grab a random group element\n",
"sampled_g = SO(3).sample()\n",
"dense_rep = so3_rep.rho(sampled_g)\n",
"# check its a member of SO(3)\n",
"# g @ g.T = I\n",
"print(dense_rep @ dense_rep.T)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll apply our group element to a point to see it rotate the point. The norm should be unchanged, because it's a rotation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"point = np.array([0, 0, 1])\n",
"print(\"new point\", dense_rep @ point.T)\n",
"print(\"norm\", np.sqrt(np.sum((dense_rep @ point) ** 2)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's assume our input function consists of 5 points (e.g., methanol molecule) defined by features (e.g., 1D element embedding) and positions. We'll create that as a direct sum of 5 scalars and 5 vectors. Our output will be a vector (e.g., dipole). Equivariance here will then mean that if rotate the input points, our output vector should undergo the same rotation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_rep = 5 * so3_rep**0 + 5 * so3_rep**1\n",
"print(\"input rep\", input_rep)\n",
"print(\"output rep\", so3_rep)\n",
"\n",
"input_point = np.random.randn(5 + 5 * 3)\n",
"print(\"input features\", input_point[:5])\n",
"print(\"input positions\\n\", input_point[5:].reshape(5, 3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = emlp.nn.EMLP(input_rep, so3_rep, group=SO(3), num_layers=1)\n",
"output_point = model(input_point)\n",
"print(\"output\", output_point)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll transform the input points according to a random element in the group. We could convert the input into the five spatial vectors and apply the group element to them individually and put them back together. However, `emlp` has a convenience function for exactly that. We can change our group element to the input representation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trans_input_point = input_rep.rho_dense(sampled_g) @ input_point\n",
"print(\"transformed input features\", trans_input_point[:5])\n",
"print(\"transformed input positions\\n\", trans_input_point[5:].reshape(5, 3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we compare running the transformed input through the model against applying the group element to the output from the untransformed input. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model(trans_input_point), sampled_g @ output_point"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Indeed they are equivalent -- meaning this model is equivariant. The constraint approach is quite simple to use and can handle arbitrary groups. However, it may not be efficient when working with many input points (like a protein) and it may make sense to use an implementation specific to E(3) or SO(3). "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How the constraints work\n",
"\n",
"How does this magic happen? Rather than explicitly setting constraints on the dense layer weights, `emlp` always first projects the network weights into an **equivariant subspace.** This means that the cost of equivariance is paid when constructing the model when this projection matrix is found but not later during training and inference. The equivariant subspace is the space of allowed weights that respect the equivariance. Let's see what this looks like.\n",
"\n",
"Recall that a dense layer has the equation:\n",
"\n",
"\\begin{equation}\n",
"y = \\sigma\\left(Wx + b\\right)\n",
"\\end{equation}\n",
"\n",
"where $\\sigma$ is a special nonlinearity for equivariant neural networks we won't discuss here (see {cite}`weiler20183d`). To respect the equivariance, $W,b$ will need to be projected into an equivariant subspace that depends on our group and input/output representations. So our modified equation would look like:\n",
"\n",
"\\begin{equation}\n",
"y = \\sigma\\left(P_wWx + P_bb\\right)\n",
"\\end{equation}\n",
"\n",
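"Let's start by making these projectors. $P_w$ acts on the flattened weights, so it depends on both the input and output representations, while $P_b$ only needs to consider the output representation, since the bias $b$ has the same representation as the output."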
]
},
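{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before calling `emlp`, here is the condition these projectors enforce. This is a sketch of the underlying linear algebra, not a description of exactly how `emlp` computes them. A dense layer is equivariant when, for every group element $g$,\n",
"\n",
"\\begin{equation}\n",
"\\rho_{\\textrm{out}}(g)\\, W = W\\, \\rho_{\\textrm{in}}(g), \\qquad \\rho_{\\textrm{out}}(g)\\, b = b\n",
"\\end{equation}\n",
"\n",
"For a connected group such as SO(3), it is enough to impose the derivative of this condition along each Lie algebra generator $A$ (differentiate $g = e^{tA}$ at $t = 0$), where $A_{\\textrm{in}}$ and $A_{\\textrm{out}}$ denote that generator in the input and output representations:\n",
"\n",
"\\begin{equation}\n",
"A_{\\textrm{out}} W - W A_{\\textrm{in}} = 0, \\qquad A_{\\textrm{out}}\\, b = 0\n",
"\\end{equation}\n",
"\n",
"Flattening $W$ row by row into $\\textrm{vec}(W)$, as `W.flatten()` does, turns each weight condition into a matrix acting on the weights:\n",
"\n",
"\\begin{equation}\n",
"\\left(A_{\\textrm{out}} \\otimes I - I \\otimes A_{\\textrm{in}}^\\top\\right)\\textrm{vec}(W) = 0\n",
"\\end{equation}\n",
"\n",
"Stacking these conditions for all generators gives a constraint matrix $C$. If the columns of $Q$ are an orthonormal basis for the nullspace of $C$, then $P_w = QQ^\\top$ is the orthogonal projector onto the equivariant subspace. $P_b$ follows in the same way from the bias condition."
]
},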
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
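"# Projector onto equivariant weights W mapping input_rep -> so3_rep\n",
"Pw = (input_rep >> so3_rep).equivariant_projector()\n",
"# Projector onto equivariant biases, which live in the output representation\n",
"Pb = so3_rep.equivariant_projector()\n",
"\n",
"print(\"Pw shape is\", Pw.shape, \"Pb shape is\", Pb.shape)"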
]
},
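{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the construction concrete without relying on `emlp`, the sketch below builds both projectors with plain NumPy for the same shapes as above (5 scalars and 5 vectors in, one vector out), assuming the input is laid out as 5 scalars followed by 5 flattened vectors. It stacks the linearized equivariance constraint $A_{\\textrm{out}} W - W A_{\\textrm{in}} = 0$ for the three SO(3) generators on the row-major flattened weights and forms $QQ^\\top$ from an SVD nullspace basis $Q$. This illustrates the technique and is not `emlp`'s actual solver; `in_generator` and `projector` are hypothetical helper names.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Lie algebra generators of SO(3): infinitesimal rotations about x, y, z\n",
"L = np.array([\n",
"    [[0, 0, 0], [0, 0, -1], [0, 1, 0]],\n",
"    [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],\n",
"    [[0, -1, 0], [1, 0, 0], [0, 0, 0]],\n",
"], dtype=float)\n",
"\n",
"n_scalars, n_vectors = 5, 5\n",
"d_in, d_out = n_scalars + 3 * n_vectors, 3\n",
"\n",
"\n",
"def in_generator(A):\n",
"    # scalars are unchanged by rotations (zero block); each 3D vector gets A\n",
"    G = np.zeros((d_in, d_in))\n",
"    G[n_scalars:, n_scalars:] = np.kron(np.eye(n_vectors), A)\n",
"    return G\n",
"\n",
"\n",
"def projector(C, tol=1e-8):\n",
"    # orthonormal nullspace basis Q of C from the SVD, then P = Q Q^T\n",
"    _, s, Vt = np.linalg.svd(C)\n",
"    rank = int(np.sum(s > tol))\n",
"    Q = Vt[rank:].T\n",
"    return Q @ Q.T\n",
"\n",
"\n",
"# weights: (A_out kron I - I kron A_in^T) vec(W) = 0 for every generator A\n",
"C_w = np.concatenate([\n",
"    np.kron(A, np.eye(d_in)) - np.kron(np.eye(d_out), in_generator(A).T)\n",
"    for A in L\n",
"])\n",
"# bias lives in the output representation: A b = 0 for every generator A\n",
"C_b = np.concatenate(list(L))\n",
"\n",
"Pw_np = projector(C_w)\n",
"Pb_np = projector(C_b)\n",
"print('Pw shape', Pw_np.shape, 'rank', np.linalg.matrix_rank(Pw_np))\n",
"print('Pb shape', Pb_np.shape, 'rank', np.linalg.matrix_rank(Pb_np))\n",
"```\n",
"\n",
"With these shapes, the weight projector should have rank 5 (one free scale per input vector) and the bias projector should be the zero matrix, since no nonzero vector satisfies $A b = 0$ for all three generators."
]
},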
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that they are square because they leave the underlying dimension of $W$ unchanged: we are not projecting to a lower-dimensional space, but onto a subspace within the space of possible weight values. Remember too that our representations are flattened; the 60 comes from our weight matrix being $3\\times(5 + 15)$.\n",
"\n",
"Now let's show how these projectors can convert an arbitrary weight matrix into one that is equivariant. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a random weight matrix mapping 20 inputs (5 scalars + 5 vectors) to 3 outputs\n",
"W = np.random.randn(3, 5 + 5 * 3)\n",
"b = np.random.randn(3)\n",
"\n",
"print(\n",
" \"A random W alone is not equivariant\",\n",
" W @ trans_input_point.flatten(),\n",
" \"!=\",\n",
" sampled_g @ W @ input_point,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# project the flattened W, then restore its (3, 20) shape\n",
"Proj_W = (Pw @ W.flatten()).reshape(W.shape)\n",
"print(\n",
" \"Projected W is equivariant\",\n",
" Proj_W @ trans_input_point.flatten(),\n",
" \"==\",\n",
" sampled_g @ Proj_W @ input_point,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You may be wondering how much the projection affects $W$. Is enough flexibility left to learn anything? We can compare the full *random* matrix $W$ with its projection."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.title(\"Random W\")\n",
"plt.imshow(W)\n",
"plt.show()\n",
"plt.title(\"Projected W\")\n",
"plt.imshow(Proj_W)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It appears that there are only a few unique values in $W$ after projection, so our weight space is effectively much lower dimensional. This is why it's important to have multiple channels! This also demonstrates why `emlp` can be more expensive: we store and train every entry of $W$ (60 values here) even though only a few independent values survive the projection. Similarly, the projected bias is zero for our system, because no nonzero vector is left unchanged by every rotation."
]
},
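{
"cell_type": "markdown",
"metadata": {},
"source": [
"Those few values can be predicted. By Schur's lemma, the only linear maps from a 3D vector to a 3D vector that commute with every rotation are multiples of the identity, and no equivariant linear map sends an invariant scalar to a nonzero vector. So, assuming `emlp` keeps the scalars-then-vectors ordering printed earlier, the projected $W$ should have zero columns for the 5 scalar features and one scaled $3\\times 3$ identity block per input vector. The sketch below is a hypothetical direct parameterization (`equivariant_W`, `rot_z`, `rot_x` are illustrative names, not part of `emlp`) that trains only those 5 coefficients and checks equivariance for one rotation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def equivariant_W(c, n_scalars=5):\n",
"    # hypothetical direct parameterization: zero weights on the scalar\n",
"    # features, then one scaled 3x3 identity block per input vector\n",
"    blocks = [np.zeros((3, n_scalars))] + [ci * np.eye(3) for ci in c]\n",
"    return np.hstack(blocks)\n",
"\n",
"\n",
"def rot_z(t):\n",
"    return np.array([[np.cos(t), -np.sin(t), 0.0],\n",
"                     [np.sin(t), np.cos(t), 0.0],\n",
"                     [0.0, 0.0, 1.0]])\n",
"\n",
"\n",
"def rot_x(t):\n",
"    return np.array([[1.0, 0.0, 0.0],\n",
"                     [0.0, np.cos(t), -np.sin(t)],\n",
"                     [0.0, np.sin(t), np.cos(t)]])\n",
"\n",
"\n",
"c = np.random.default_rng(0).normal(size=5)  # the only 5 trainable numbers\n",
"W = equivariant_W(c)\n",
"R = rot_z(0.3) @ rot_x(1.1)\n",
"rho_in = np.eye(20)\n",
"rho_in[5:, 5:] = np.kron(np.eye(5), R)\n",
"x = np.random.default_rng(1).normal(size=20)\n",
"print(W.shape, np.allclose(W @ (rho_in @ x), R @ (W @ x)))  # (3, 20) True\n",
"```\n",
"\n",
"The trade-off is generality: this shortcut relies on knowing the equivariant basis for SO(3) in advance, whereas the projector approach finds it for any group."
]
},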
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Pb @ b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Including Permutation Groups\n",
"\n",
"In real molecules, we also need permutation equivariance with respect to atom ordering and bond ordering, which our dipole moment model above does not enforce. `emlp` also supports permutation groups, usually written as $S_n$, where $n$ is the number of elements being permuted (the group itself has $n!$ elements). We'll work on that in the next chapter."
]
},
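{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small preview that does not use `emlp`: when $S_n$ acts on $n$ scalar values by permuting them, the linear maps that commute with every permutation matrix are exactly $aI + b\\mathbf{1}\\mathbf{1}^\\top$ (for $n \\geq 2$), the form used in Deep Sets layers. The sketch checks this for one random permutation; `perm_equivariant_W` is a hypothetical helper name.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def perm_equivariant_W(a, b, n):\n",
"    # a * identity acts on each element; b * ones mixes in the sum over elements\n",
"    return a * np.eye(n) + b * np.ones((n, n))\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n = 6\n",
"W = perm_equivariant_W(rng.normal(), rng.normal(), n)\n",
"P = np.eye(n)[rng.permutation(n)]  # a random permutation matrix\n",
"x = rng.normal(size=n)\n",
"print(np.allclose(W @ (P @ x), P @ (W @ x)))\n",
"```\n",
"\n",
"Only two parameters survive here, which mirrors the SO(3) case above: the symmetry constraint shrinks a full $n \\times n$ weight matrix to a small equivariant subspace."
]
},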
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chapter Summary\n",
"\n",
"* Equivariant neural networks guarantee equivariance by construction for arbitrary groups, which removes the need to align trajectories, work in special coordinate systems, or use pairwise distances. \n",
"* Equivariance can be achieved by parameter sharing or testing/training data augmentation, but here we focused on equivariant layers that can be composed into a neural network. \n",
"* Equivariance requires defining a group and a homogeneous space. We must view our input data as functions and our models as operators that transform functions.\n",
"* Finite groups can be treated with G-equivariant layers that have an additional sum across the group elements.\n",
"* Infinite groups like SO(3) can be handled with finite-dimensional features by working with a direct sum (list of vectors) of irreducible representations. This requires converting the input data into irreducible representations, nonlinearities become more complicated, and implementations typically must be written for each group.\n",
"* Constraint-based equivariant layers are flexible, general, and quick to implement but may not scale well with the size of the group or the number of points.\n",
"* Recent work has also shown you can put irreducible representation direct sums into the edges of graph neural networks, gaining input size independence, permutation invariance, and spatial equivariance in one model. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Relevant Videos\n",
"\n",
"### Intro to Geometric Deep Learning\n",
"\n",
"\n",
"\n",
"### Equivariant Networks\n",
"\n",
"\n",
"\n",
"\n",
"### Equivariant Network Tutorial\n",
"\n",
"[watch here](https://slideslive.com/38943570/equivariant-networks)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"from myst_nb import glue\n",
"\n",
"\n",
"def hexico(rot):\n",
" cmap = plt.get_cmap(\"Set2\")\n",
" colors = [cmap(i / 6) for i in range(6)]\n",
" points = np.array(\n",
" [\n",
" (0, 1),\n",
" (0.5 * np.sqrt(3), 0.5),\n",
" (0.5 * np.sqrt(3), -0.5),\n",
" (0, -1),\n",
" (-0.5 * np.sqrt(3), -0.5),\n",
" (-0.5 * np.sqrt(3), 0.5),\n",
" ]\n",
" )\n",
" # wrap the points\n",
" index = [(i + rot) % 6 for i in range(6)]\n",
" points = points[index]\n",
" plt.figure(figsize=(1, 1))\n",
" if rot == 6:\n",
" plt.scatter(points[:, 0], points[:, 1], marker=\".\", color=\"black\", s=150)\n",
" else:\n",
" plt.scatter(points[:, 0], points[:, 1], marker=\"o\", c=colors, s=150)\n",
" plt.plot([0, points[0, 0]], [0, points[0, 1]], color=\"black\", zorder=0)\n",
"\n",
" plt.plot(points[:, 0], points[:, 1], color=\"black\", zorder=0)\n",
" plt.plot(\n",
" [points[-1, 0], points[0, 0]],\n",
" [points[-1, 1], points[0, 1]],\n",
" color=\"black\",\n",
" zorder=0,\n",
" )\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.xlim(-1.4, 1.4)\n",
" plt.ylim(-1.4, 1.4)\n",
" for s in plt.gca().spines.values():\n",
" s.set_visible(False)\n",
" glue(f\"hex-{rot}\", plt.gcf(), display=False)\n",
"\n",
"\n",
"for i in range(7):\n",
" hexico(i)\n",
"\n",
"\n",
"def quad(rot, mirror):\n",
" colors = color_cycle[:4]\n",
" points = np.array([(1, 1), (-1, 1), (-1, -1), (1, -1)])\n",
" if mirror:\n",
" points[:, 0] *= -1\n",
" # wrap the points\n",
" index = [(i + rot) % 4 for i in range(4)]\n",
" points = points[index]\n",
" plt.figure(figsize=(1, 1))\n",
" plt.scatter(points[:, 0], points[:, 1], marker=\"o\", c=colors, s=150)\n",
" plt.plot([0, points[0, 0]], [0, points[0, 1]], color=\"black\", zorder=0)\n",
"\n",
" plt.plot(points[:, 0], points[:, 1], color=\"black\", zorder=0)\n",
" plt.plot(\n",
" [points[-1, 0], points[0, 0]],\n",
" [points[-1, 1], points[0, 1]],\n",
" color=\"C0\",\n",
" zorder=0,\n",
" )\n",
" plt.plot([0, points[1, 0]], [0, points[1, 1]], linestyle=\"-\", color=\"C1\", zorder=0)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.xlim(-1.4, 1.4)\n",
" plt.ylim(-1.4, 1.4)\n",
" for s in plt.gca().spines.values():\n",
" s.set_visible(False)\n",
" print(rot, mirror)\n",
" glue(f\"quad-{rot}-{mirror}\", plt.gcf(), display=False)\n",
"\n",
"\n",
"for i in range(4):\n",
" for j in range(2):\n",
" quad(i, j)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"colors = color_cycle[:4]\n",
"points = np.array([(1, 1), (-1, 1), (-1, -1), (1, -1)])\n",
"index = range(4)\n",
"points = points[index]\n",
"plt.figure(figsize=(1, 1))\n",
"plt.scatter(points[:, 0], points[:, 1], marker=\"o\", c=colors, s=150)\n",
"plt.plot([0, points[0, 0]], [0, points[0, 1]], color=\"black\", zorder=0)\n",
"\n",
"plt.plot(points[:, 0], points[:, 1], color=\"black\", zorder=0)\n",
"plt.plot(\n",
" [points[-1, 0], points[0, 0]], [points[-1, 1], points[0, 1]], color=\"C0\", zorder=0\n",
")\n",
"plt.plot([0, points[1, 0]], [0, points[1, 1]], linestyle=\"-\", color=\"C1\", zorder=0)\n",
"plt.xticks([])\n",
"plt.yticks([])\n",
"plt.xlim(-1.4, 1.4)\n",
"plt.ylim(-1.4, 1.4)\n",
"for s in plt.gca().spines.values():\n",
" s.set_visible(False)\n",
"plt.savefig(\"quad.svg\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}