Tensor Calculus and the Tensor Chain Rule


  1. Overview
  2. Tensor Preliminaries
    1. Tensor Product
    2. Tensor Product Contraction
  3. Tensor Calculus
    1. Product Rule
    2. Chain Rule
    3. Examples

Overview

In my own deep-learning research I often find myself needing to compute derivatives of vector- or scalar-valued functions with respect to matrices. Looking on the internet, however, I find there is a ton of inconsistent and confusing notation related to tensor and matrix calculus, and people often rely on heuristic rules to compute these more complex derivatives. For example, it's common to say things like \(\nabla_{W} Wx = x^T\), which really makes no sense since the left side is a 3-tensor while the right side is a row vector, i.e. a 1-tensor. Effectively, however, \(\nabla_W W x\) behaves like \(x^T\).

In this article, I'll try to formally write out a "tensor chain rule" and show how it can be used to compute common complicated derivatives. I'll also give some of my intuition along the way. Basically, I want to write up, for future reference, the main approaches I use to tackle these complicated derivatives.


Tensor Preliminaries

A scalar \(\xi\) is a 0-tensor, a vector \(v\) is a 1-tensor, and a matrix is effectively a 2-tensor. Operations on these, such as the dot product between vectors, matrix-vector products, and matrix-matrix products, can all be defined through the unifying notion of a "tensor contraction" (below).

In general, an \(m\)-tensor over a field \(\mathbb{F}\) (such as the real or complex numbers) is an element of \(\mathbb{F}^{n_1 \times n_2 \times \cdots \times n_m}\) for numbers \(n_1, ..., n_m \geq 1\). We can denote such a tensor by \(T\). To simplify things, it's convenient to define a multi-index \(\alpha\), which is a tuple of numbers; e.g. using the multi-index \(\alpha = (n_1, ..., n_m)\), we can write $$T \in \mathbb{F}^{\alpha} = \mathbb{F}^{n_1 \times \cdots \times n_m}.$$ Indexing the tensor uses any other multi-index \(\beta = (p_1, ..., p_m)\) with \(\beta \leq \alpha\), meaning \(1 \leq p_i \leq n_i\) for all \(i=1...m\). Each entry of \(T\) is a scalar of the form $$T_\beta := T_{p_1, p_2, ..., p_m} \in \mathbb{F}.$$
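
Concretely, in NumPy terms (a small sketch of my own, with made-up dimensions), a tensor is just an array and a multi-index is just a tuple of indices (0-based in NumPy rather than 1-based as above):

```python
import numpy as np

# A 3-tensor T in F^(2 x 3 x 4), i.e. alpha = (2, 3, 4).
T = np.arange(24, dtype=float).reshape(2, 3, 4)

# A multi-index beta = (p1, p2, p3) with beta <= alpha (0-based here).
beta = (1, 2, 3)
print(T[beta])    # a single scalar entry T_beta
print(T.shape)    # (2, 3, 4), i.e. alpha
```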

Tensor Product

Now, tensors are typically built from the ground up by sums of tensor products, which we now define. The tensor product generalizes the idea of an outer product of vectors. In particular, when we have a vector \(v \in \mathbb{F}^n\) and a vector \(w \in \mathbb{F}^m\), their outer product is $$v \cdot w^T \in \mathbb{F}^{n \times m}.$$ So, the outer product takes two 1-tensors and constructs a 2-tensor. Note that $$(v \cdot w^T)_{i,j} = v_i \cdot w_j,$$ i.e. each entry of the result is the product of one entry from each factor. Let \(|\) denote concatenation of two multi-indices. In general, we define the tensor product of tensors \(T \in \mathbb{F}^\alpha, S \in \mathbb{F}^\beta\) as the tensor \(T \otimes S \in \mathbb{F}^{\alpha | \beta}\) with entries $$(T \otimes S)_{\gamma | \delta} = T_\gamma \cdot S_\delta,$$ for multi-indices \(\gamma \leq \alpha\) and \(\delta \leq \beta\). Without the multi-index notation, assuming \(\alpha\) has length \(m\) and \(\beta\) has length \(\ell\), we can write this as $$(T \otimes S)_{p_1, ..., p_m, \, q_1, ..., q_\ell} = T_{p_1, ..., p_m} \cdot S_{q_1, ..., q_\ell}.$$ A general tensor \(T\) can always be written in terms of tensor products. For example, if \(e_i\) is the \(i\)th standard unit vector, define \(e_\delta\) to be the tensor product of the \(e_i\)'s corresponding to a multi-index \(\delta\). Then, if \(T \in \mathbb{F}^\alpha\), $$T = \sum_{\delta \leq \alpha} T_\delta \cdot e_\delta,$$ i.e. we write \(T\) in standard coordinates.
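
Here is a short NumPy sketch of the tensor product (my own illustration; the shapes are arbitrary). `np.tensordot` with `axes=0` builds exactly this object, and the vector outer product is the special case of two 1-tensors:

```python
import numpy as np

v = np.array([1.0, 2.0, 3.0])          # 1-tensor in F^3
w = np.array([4.0, 5.0])               # 1-tensor in F^2

# Outer product of vectors: a 2-tensor with entries v_i * w_j.
outer = np.tensordot(v, w, axes=0)
assert np.allclose(outer, np.outer(v, w))

# General tensor product: the index lists get concatenated.
T = np.random.rand(2, 3)               # alpha = (2, 3)
S = np.random.rand(4, 5, 6)            # beta  = (4, 5, 6)
P = np.tensordot(T, S, axes=0)         # shape alpha | beta = (2, 3, 4, 5, 6)
assert P.shape == (2, 3, 4, 5, 6)
assert np.isclose(P[1, 2, 3, 0, 4], T[1, 2] * S[3, 0, 4])
```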

Tensor Product Contraction

The tensor product behaves like an outer product. Similarly, a pairwise contraction of two tensors behaves like an inner product. Given two vectors, the inner product multiplies the two vectors componentwise and then sums along that axis: $$v \cdot w = \sum_i v_i w_i.$$ We can write this as a sum along the diagonal elements of the outer (tensor) product: $$v \cdot w = \sum_i (v \otimes w)_{i,i}.$$ Suppose we have a tensor \(T \in \mathbb{F}^{\alpha | \gamma}\) and a tensor \(S \in \mathbb{F}^{\gamma | \beta}\). Then we can define the tensor product contraction as $$T \otimes_\gamma S := \sum_{\delta \leq \gamma} (T_{...,\delta} \otimes S_{\delta,...}) \in \mathbb{F}^{\alpha | \beta},$$ where \(T_{...,\delta}\) denotes the sub-tensor in \(\mathbb{F}^\alpha\) obtained by fixing the trailing (\(\gamma\)) block of indices to \(\delta\), and similarly for \(S_{\delta,...}\). In other words, we pair up the \(\gamma\) parts, multiply them elementwise, then sum them all up. The product can be read as "the tensor product of \(T\) and \(S\) contracted over \(\gamma\)." We can define other contractions on tensors which are commonly used in packages like numpy, such as summing along an axis of a single tensor, but the one above is all that's needed for the tensor chain rule.

For example, suppose \(T \in \mathbb{R}^{10 \times 4 \times 4}, S \in \mathbb{R}^{4 \times 4 \times 8 \times 7}\). Then, we can take the tensor product contraction $$(T \otimes_{(4, 4)} S)_{p,q,r} = \sum_{i=1}^4 \sum_{j=1}^4 T_{p,i,j} S_{i,j,q,r} \in \mathbb{R}.$$ The tensor \(T \otimes_{(4,4)} S\) is in \(\mathbb{R}^{10 \times 8 \times 7}\).
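
In NumPy this contraction is a single `np.tensordot` (or, equivalently, `np.einsum`) call; a quick sketch using the shapes from the example above:

```python
import numpy as np

T = np.random.rand(10, 4, 4)
S = np.random.rand(4, 4, 8, 7)

# Contract over the shared gamma = (4, 4) block:
# the last two axes of T against the first two axes of S.
C1 = np.tensordot(T, S, axes=2)
C2 = np.einsum('pij,ijqr->pqr', T, S)   # the same contraction, written out explicitly
assert C1.shape == (10, 8, 7)
assert np.allclose(C1, C2)
```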


Let's now describe three examples:


Vector inner product: For two vectors, \(v\), \(w\) in \(\mathbb{F}^n\), $$v \cdot w = (v \otimes_n w) \in \mathbb{F}.$$


Matrix-vector product: For a matrix \(W \in \mathbb{F}^{m \times n}\) and a vector \(v \in \mathbb{F}^n\), $$W \cdot v = (W \otimes_n v) \in \mathbb{F}^m.$$


Matrix-matrix product: For two matrices \(A \in \mathbb{F}^{m \times n}\) and \(B \in \mathbb{F}^{n \times k}\), $$A \cdot B = (A \otimes_n B) \in \mathbb{F}^{m \times k}.$$

(Note that the notation \(n\) above is shorthand for the length-one multi-index \((n)\).)
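
All three of these are the `axes=1` case of `np.tensordot`; here is a quick sketch (with arbitrary shapes) checking them against the usual NumPy products:

```python
import numpy as np

v = np.random.rand(5)
w = np.random.rand(5)
W = np.random.rand(3, 5)
A = np.random.rand(3, 5)
B = np.random.rand(5, 2)

# Vector inner product: contract the single shared index n.
assert np.isclose(np.tensordot(v, w, axes=1), v @ w)

# Matrix-vector product: W in F^(m x n), v in F^n.
assert np.allclose(np.tensordot(W, v, axes=1), W @ v)

# Matrix-matrix product: A in F^(m x n), B in F^(n x k).
assert np.allclose(np.tensordot(A, B, axes=1), A @ B)
```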

Tensor Calculus

I'll now describe how to do calculus with these tensors, specifically the chain rule. This has turned out to be very useful for me since it gives a more formal way of computing things such as the "derivative of a vector-valued function with respect to a matrix." I'll show how some of the intuitive facts that people typically take for granted are special cases of the tensor chain rule.

Suppose \(G\) is a tensor-valued function taking tensor inputs: $$G: \mathbb{F}^\alpha \rightarrow \mathbb{F}^\beta.$$ Then, the Jacobian, or total derivative, of \(G\) with respect to the input \(T \in \mathbb{F}^\alpha\) is a tensor \(D_T G \in \mathbb{F}^{\beta | \alpha}\) defined by $$(D_T G)_{\delta | \gamma} = \frac{\partial G(T)_\delta}{\partial T_\gamma},$$ for indices \(\gamma \leq \alpha, \delta \leq \beta\). Unfortunately, the Greek symbols make the formula above look more complex than it should be. Think of \(\alpha, \beta, \delta, \gamma\) as just \(n, m, i, j\) and pretty much everything looks the same as the Jacobian matrix.
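
To make the shape bookkeeping concrete, here is a finite-difference sketch (the helper `numerical_jacobian` is my own illustration, not a standard API) that builds the Jacobian of a map from \(\mathbb{F}^\alpha\) to \(\mathbb{F}^\beta\) as an array of shape \(\beta | \alpha\):

```python
import numpy as np

def numerical_jacobian(G, T, eps=1e-6):
    """Central-difference Jacobian of G at T, as an array of shape G(T).shape + T.shape."""
    out = G(T)
    D = np.zeros(out.shape + T.shape)
    for gamma in np.ndindex(T.shape):          # gamma ranges over the multi-indices of alpha
        dT = np.zeros_like(T)
        dT[gamma] = eps
        D[(...,) + gamma] = (G(T + dT) - G(T - dT)) / (2 * eps)
    return D

# Example: a map from 2-tensors (matrices) to 1-tensors (vectors).
x = np.random.rand(4)
G = lambda M: M @ x                            # F^(3 x 4) -> F^3
W = np.random.rand(3, 4)
D = numerical_jacobian(G, W)
print(D.shape)                                 # (3, 3, 4), i.e. beta | alpha
```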

Product Rule

A tensor product rule holds, although it's a bit awkward to state because the tensor product is not commutative. First, for motivation, consider a simple case. Suppose \(f(v), g(v)\) are vector-valued functions in \(\mathbb{R}^m\) and \(v \in \mathbb{R}^n\). Then, you can check that $$D_v \langle f(v), g(v) \rangle = ((D_v f(v))^T g(v))^T + f(v)^T (D_v g(v)),$$ where \(\langle \cdot, \cdot \rangle\) is the inner product. Note the formula requires a transpose on the first term to get the matrix-vector product to work out. We will see the same with tensors.

First, for a tensor \(T \in \mathbb{F}^{\alpha | \beta}\), define the tensor transpose \(T^{tr(\beta)}\), which moves the trailing \(\beta\) block of indices to the front: $$T^{tr(\beta)} \in \mathbb{F}^{\beta | \alpha}, \, \, T^{tr(\beta)}_{\delta | \gamma} = T_{\gamma | \delta},$$ for all \(\delta \leq \beta, \gamma \leq \alpha\).
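
In array terms, this block transpose is just a permutation of axes; a tiny sketch with illustrative shapes:

```python
import numpy as np

# T in F^(alpha | beta) with alpha = (2, 3) and beta = (4, 5).
T = np.random.rand(2, 3, 4, 5)

# tr(beta): move the beta block (the last two axes) to the front.
T_tr = np.transpose(T, (2, 3, 0, 1))
assert T_tr.shape == (4, 5, 2, 3)
assert np.isclose(T_tr[1, 2, 0, 1], T[0, 1, 1, 2])
```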

Now, we can define the product rule. In particular, suppose: $$G: \mathbb{F}^\alpha \rightarrow \mathbb{F}^{\beta | \gamma}, H: \mathbb{F}^\alpha \rightarrow \mathbb{F}^{\gamma | \delta},$$ then, $$D_T (G(T) \otimes_{\gamma} H(T)) = ((D_T G)^{tr(\alpha)} \otimes_{\gamma} H)^{tr(\beta | \delta)} + G \otimes_{\gamma} D_T H.$$ Note \(D_T H \in \mathbb{F}^{\gamma | \delta | \alpha},\) so \(G \otimes_{\gamma} D_T H \in \mathbb{F}^{\beta | \delta | \alpha}\).

Likewise, \((D_T G)^{tr(\alpha)} \in \mathbb{F}^{\alpha | \beta | \gamma}\), so \((D_T G)^{tr(\alpha)} \otimes_{\gamma} H \in \mathbb{F}^{\alpha | \beta | \delta}\), and finally the \(tr(\beta | \delta)\) makes the dimensions \(\beta | \delta | \alpha\) as with the other term. Jeez....
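
Since this index bookkeeping is easy to get wrong, here is a numerical sanity check of the product rule on random data (the shapes, functions, and the finite-difference helper are all my own illustration):

```python
import numpy as np

def numerical_jacobian(F, T, eps=1e-6):
    # Same finite-difference Jacobian helper as in the earlier sketch.
    out = F(T)
    D = np.zeros(out.shape + T.shape)
    for gamma in np.ndindex(T.shape):
        dT = np.zeros_like(T)
        dT[gamma] = eps
        D[(...,) + gamma] = (F(T + dT) - F(T - dT)) / (2 * eps)
    return D

rng = np.random.default_rng(0)
GA = rng.standard_normal((2, 4, 3))             # G(t) has shape beta | gamma = (2, 4)
HA = rng.standard_normal((4, 5, 3))             # H(t) has shape gamma | delta = (4, 5)
G = lambda t: np.tensordot(GA, t, axes=1)
H = lambda t: np.tensordot(HA, t, axes=1)
F = lambda t: np.tensordot(G(t), H(t), axes=1)  # G(t) (x)_gamma H(t), shape (2, 5)

t = rng.standard_normal(3)                      # alpha = (3,)
DG = numerical_jacobian(G, t)                   # shape beta | gamma | alpha = (2, 4, 3)
DH = numerical_jacobian(H, t)                   # shape gamma | delta | alpha = (4, 5, 3)

# First term: ((D_T G)^tr(alpha) (x)_gamma H)^tr(beta | delta).
term1 = np.tensordot(np.transpose(DG, (2, 0, 1)), H(t), axes=1)  # alpha | beta | delta = (3, 2, 5)
term1 = np.transpose(term1, (1, 2, 0))                           # beta | delta | alpha = (2, 5, 3)
# Second term: G (x)_gamma D_T H.
term2 = np.tensordot(G(t), DH, axes=1)                           # beta | delta | alpha = (2, 5, 3)

assert np.allclose(term1 + term2, numerical_jacobian(F, t), atol=1e-4)
```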

Chain Rule

Now, suppose we have $$G: \mathbb{F}^\alpha \rightarrow \mathbb{F}^\beta, H: \mathbb{F}^\beta \rightarrow \mathbb{F}^\nu,$$ and we want to find the total derivative of the composed function \(H(G(T))\). Then, the tensor chain rule is as follows $$D_T H(G(T)) = D_{G(T)} H \otimes_{\beta} D_T G.$$
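
And here is the analogous numerical check of the chain rule (again, the functions, shapes, and finite-difference helper are only illustrative):

```python
import numpy as np

def numerical_jacobian(F, T, eps=1e-6):
    # Same finite-difference Jacobian helper as above, repeated so the snippet runs standalone.
    out = F(T)
    D = np.zeros(out.shape + T.shape)
    for gamma in np.ndindex(T.shape):
        dT = np.zeros_like(T)
        dT[gamma] = eps
        D[(...,) + gamma] = (F(T + dT) - F(T - dT)) / (2 * eps)
    return D

rng = np.random.default_rng(1)
A = rng.standard_normal((5, 3, 4))                  # used by G: F^(3 x 4) -> F^5
B = rng.standard_normal((2, 6, 5))                  # used by H: F^5 -> F^(2 x 6)
G = lambda T: np.tanh(np.tensordot(A, T, axes=2))   # beta = (5,)
H = lambda y: np.tensordot(B, np.sin(y), axes=1)    # nu = (2, 6)

T = rng.standard_normal((3, 4))                     # alpha = (3, 4)
DH = numerical_jacobian(H, G(T))                    # shape nu | beta = (2, 6, 5)
DG = numerical_jacobian(G, T)                       # shape beta | alpha = (5, 3, 4)

# Chain rule: contract over the shared beta block.
chain = np.tensordot(DH, DG, axes=1)                # shape nu | alpha = (2, 6, 3, 4)
direct = numerical_jacobian(lambda T: H(G(T)), T)
assert np.allclose(chain, direct, atol=1e-4)
```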

Examples


Derivative of \(W x\) with respect to \(W\):


Suppose \(H(W, x) = Wx\) where \(W \in \mathbb{R}^{n \times n}, x \in \mathbb{R}^n\). Let \(Id : \mathbb{R}^{n \times n} \rightarrow \mathbb{R}^{n \times n}\) be the identity function. Then, $$W x = Id(W) \otimes_n x,$$ so by the tensor product rule, assuming \(D_W x = 0\), i.e. \(x\) doesn't depend on \(W\), $$D_W(W x) = ((D_W Id(W))^{tr(n,n)} \otimes_n x)^{tr(n)}.$$ The Jacobian of the identity map is the four-tensor in \(\mathbb{R}^{n \times n \times n \times n}\) with entries $$(D_W Id(W))_{(i,j) | (k,l)} = \frac{\partial W_{i,j}}{\partial W_{k,l}} = \delta_{i,k} \, \delta_{j,l},$$ which is unchanged by the block transpose \(tr(n,n)\), since it is symmetric under swapping its first pair of indices with its last pair. Contracting the last index against \(x\) gives a three-tensor with entries $$((D_W Id(W))^{tr(n,n)} \otimes_n x)_{(k,l) | i} = \sum_{j=1}^n \delta_{i,k} \, \delta_{j,l} \, x_j = \delta_{i,k} \, x_l,$$ and the final transpose \(tr(n)\) moves the last index to the front, so $$(D_W(W x))_{i | (k,l)} = \delta_{i,k} \, x_l.$$ Let \(E_\alpha\) denote the ``tensor identity'' given by $$E_\alpha = \sum_{i=1}^{\min(\alpha)} e_i \otimes e_i \otimes \cdots \otimes e_i,$$ i.e. the sum of all diagonal outer products of standard unit vectors; in particular, \(E_{n,n}\) is essentially just the \(n \times n\) identity matrix. With this notation we can write the answer compactly as $$D_W(W x) = E_{n,n} \otimes x.$$ Finally, note that given a vector \(v \in \mathbb{R}^n\), $$D_W(W x) \otimes_n v = (E_{n,n} \otimes x) \otimes_n v = \langle x, v \rangle E_{n,n},$$ i.e. the Jacobian behaves like \(x^T\)!
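
As a sanity check, here is a short NumPy verification of this example (the finite-difference helper is the same illustrative one as above):

```python
import numpy as np

def numerical_jacobian(F, T, eps=1e-6):
    # Same finite-difference Jacobian helper as above, repeated so the snippet runs standalone.
    out = F(T)
    D = np.zeros(out.shape + T.shape)
    for gamma in np.ndindex(T.shape):
        dT = np.zeros_like(T)
        dT[gamma] = eps
        D[(...,) + gamma] = (F(T + dT) - F(T - dT)) / (2 * eps)
    return D

n = 4
x = np.random.rand(n)
W = np.random.rand(n, n)

# D_W(Wx) should equal E_{n,n} (x) x, i.e. the 3-tensor with entries delta_{ik} * x_l.
D = numerical_jacobian(lambda M: M @ x, W)                # shape (n, n, n)
assert np.allclose(D, np.tensordot(np.eye(n), x, axes=0), atol=1e-6)

# Contracting against a vector v behaves like x^T: the result is <x, v> times the identity.
v = np.random.rand(n)
assert np.allclose(np.tensordot(D, v, axes=1), np.dot(x, v) * np.eye(n), atol=1e-6)
```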

