
Can we Recover a Neural Network Exactly?

Given oracle access to a blackbox neural network, is it possible to recover its weights (up to equivalence of computing the same function) exactly? We do not know the answer, but note down some observations regarding the problem.

The Problem

Given oracle access to a function \(f_\theta\) computed by a feedforward neural network of known architecture with non-linearity \(\sigma=ReLU\) and parameters \(\theta\), can we use an algorithm to find a \(\theta^*\) such that \(f_{\theta^*} = f_\theta\)?

Oracle access means that we can propose queries \(x_1, x_2, \dots\) and get back answers \(y_1 = f(x_1), y_2 = f(x_2), \dots\).

Observation 1: Solving this with SGD gives us a pretty good approximation

We can initialize \(\theta^*\) randomly, feed both networks random inputs, and use the loss \(\mathrm{Distance}(f_{\theta^*}(x), f_\theta(x))\) as a training signal to update \(\theta^*\) with SGD. This will give us a pretty good approximation.
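As a concrete illustration, here is a minimal sketch of this procedure in PyTorch. The architecture (a 2-16-1 ReLU network), the hidden teacher weights, the squared-error distance, and the hyperparameters are all placeholder assumptions for the sketch, not part of the problem statement.

```python
# A minimal sketch of Observation 1: distill the black-box f_theta into a student
# network of the same (assumed known) architecture using SGD on random queries.
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the oracle: a fixed, hidden 2-16-1 ReLU network (hypothetical).
teacher = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
for p in teacher.parameters():
    p.requires_grad_(False)

# Randomly initialized theta*, same architecture as the teacher.
student = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
opt = torch.optim.SGD(student.parameters(), lr=1e-2)

for step in range(10_000):
    x = torch.randn(64, 2)                 # random queries to the oracle
    y = teacher(x)                         # oracle answers
    loss = ((student(x) - y) ** 2).mean()  # Distance(f_theta*(x), f_theta(x))
    opt.zero_grad()
    loss.backward()
    opt.step()

# The student is typically a good approximation, but nothing here certifies
# that f_theta* equals f_theta exactly.
```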

But our objective in this case is to find \(\theta^*\) such that we recover the original function exactly, i.e. where we can prove that indeed \(f_{\theta^*} = f_\theta\).

Observation 2: It is trivial to solve this for \(n\)-ary variables

Suppose the network's inputs and outputs are binary variables; then every \(f_\theta\) represents a boolean function. We can simply enumerate all possible combinations of inputs to the boolean function as a truth table, and find a \(\theta^*\) that satisfies the same truth table.

The same is true if the network computes \(n\)-ary variables for any \(n\). This cannot be done for real variables though, because it is impossible to enumerate the entire real-valued input space.
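A tiny sketch of the truth-table idea, with a hypothetical two-input boolean oracle; the final step of finding a network that matches the table is left abstract:

```python
# Sketch of Observation 2: with boolean inputs, the entire input space can be
# enumerated, so exact recovery reduces to matching a finite truth table.
from itertools import product

def truth_table(oracle, num_inputs):
    # One oracle query per input combination: 2**num_inputs queries in total.
    return {bits: oracle(bits) for bits in product((0, 1), repeat=num_inputs)}

# Hypothetical black-box boolean function of two inputs.
oracle = lambda bits: int(bits[0] and not bits[1])
table = truth_table(oracle, 2)
# Any theta* whose network reproduces `table` computes exactly the same function.
print(table)
```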

Observation 3: This cannot be done in the limit

The Universal Approximation Theorem says that a feedforward neural network with a single hidden layer that is sufficiently wide can approximate any continuous function on a compact subset of \(\mathbb{R}^n\), given a nonconstant, bounded, and continuous activation function (conditions which ReLU does not quite satisfy, since it is unbounded).

In general, if we are trying to recover an arbitrary real function exactly, we will need an infinite number of queries.

Attempted Solution

When my advisor first posed this problem, my instinctual reaction was that it cannot be solved. After all, there is no general way of solving a non-linear equation analytically. Furthermore, if this can be done, why would we not simply use this method instead of SGD to optimize neural networks?

But upon thinking it through, it isn’t immediately clear that we cannot solve this, since this is not a general non-linear equation, but a very specific one with ReLUs. An algorithm for solving this also wouldn’t necessarily help in optimizing neural networks, since in general, we only have access to a set of labeled data points and not oracle access to the “optimal” function.

If we wanted to answer this problem in the negative, we would have to explain how the ability to solve it would make some known difficult problem easy. This seems like a complexity/learning/security theory question, which is beyond my current range of expertise. Theoreticians also typically stick to boolean circuits rather than real-valued functions, so I wasn't able to google for anything relevant to this problem, even though it's highly likely that someone else has already posed and answered it.

I wasn’t able to solve the problem in general, but I note down specific cases below where it is possible to find solutions.

We assume no biases, and start with the linear case.

Case 1: \(f(x) = Wx\)

We write \(X\) for the matrix containing the queries \(x_i\) as column vectors, choosing the queries to be linearly independent (e.g. the standard basis vectors) so that \(X\) is invertible. Then \(Y = WX\) is the matrix containing the answers \(y_i\) as column vectors, and we can recover the weights directly as \(W = YX^{-1}\).

(If there were biases in the linear case, we can set \(x=0\) to recover them directly.)
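A small numpy sketch of Case 1, using the standard basis vectors as queries so that \(X\) is invertible; the dimensions and the hidden weights below are made up for illustration.

```python
# Sketch of Case 1: recover W (and the bias b) of a purely linear oracle.
import numpy as np

rng = np.random.default_rng(0)
W_true = rng.standard_normal((3, 4))   # hidden weights of the oracle
b_true = rng.standard_normal(3)        # hidden bias

def oracle(x):
    return W_true @ x + b_true

b_rec = oracle(np.zeros(4))            # query x = 0 to recover the bias
X = np.eye(4)                          # queries: standard basis vectors, so X is invertible
Y = np.stack([oracle(X[:, i]) for i in range(4)], axis=1)
W_rec = (Y - b_rec[:, None]) @ np.linalg.inv(X)   # W = (Y - b 1^T) X^{-1}

assert np.allclose(W_rec, W_true) and np.allclose(b_rec, b_true)
```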

Case 2: \(f(x) = \sigma(Wx)\)

\(\sigma = ReLU\) implies that the negative entries of \(WX\) are zeroed out in \(\sigma(WX)\). Querying with \(-X\) gives \(\sigma(-WX)\), which recovers (the negations of) the previously zeroed out entries while zeroing out the previously positive ones.

We write \(Y^{+} = \sigma(WX)\) and \(Y^{-} = \sigma(W(-X))\). Since \(\sigma(a) - \sigma(-a) = a\) entrywise, we have \(Y^{+} - Y^{-} = WX\).

Then, \(W = (Y^{+} - Y^{-})X^{-1}\), by reduction to Case 1.
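The same numpy setup sketches Case 2; the only change from the Case 1 sketch is querying with both \(X\) and \(-X\) and taking the difference.

```python
# Sketch of Case 2: undo the ReLU by querying with both X and -X.
import numpy as np

rng = np.random.default_rng(1)
W_true = rng.standard_normal((3, 4))    # hidden weights of the oracle

def oracle(x):
    return np.maximum(W_true @ x, 0.0)  # sigma(Wx), no bias

X = np.eye(4)                           # invertible query matrix
Y_plus = np.stack([oracle(X[:, i]) for i in range(4)], axis=1)    # sigma(WX)
Y_minus = np.stack([oracle(-X[:, i]) for i in range(4)], axis=1)  # sigma(-WX)
W_rec = (Y_plus - Y_minus) @ np.linalg.inv(X)  # sigma(a) - sigma(-a) = a entrywise

assert np.allclose(W_rec, W_true)
```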

Case 3: All the weights \(W_i\) are scalar.

Let’s start with the example of a two layer network: \(f(x) = \sigma(w_1 \sigma(w_2 x))\).

If \(w_1\) is non-positive, then \(f(x)\) is just the zero function, since \(\sigma(w_2 x) \geq 0\). Otherwise, we can determine the sign of \(w_2\) by trying each possible sign assignment (\(-, +, 0\)) and checking which one is consistent with the signs of the observed queries \(x\) and answers \(y\).

Knowing the sign of \(w_2\) helps us remove the inner \(\sigma\), and thus figure out the unique product \(w_1 w_2\). Then, our \(\theta^*\) can be any values for \(w_1, w_2\) that satisfy this unique product and the above sign assignment.

In general, \(f(x) = \sigma(w_1 \sigma( \dots \sigma(w_n x) \dots ))\). As above, if any of the \(w_i\) for \(i \in [1, n-1]\) is non-positive, then \(f(x)\) is just the zero function. Otherwise, we can determine the sign of \(w_n\) by trying each possible sign assignment (\(-, +, 0\)) and checking which one is consistent with the signs of the observed queries \(x\) and answers \(y\).

Knowing the sign of \(w_n\) helps us remove the innermost \(\sigma\), and thus figure out the unique product \(\prod_{i=1}^n w_i\). Then, our \(\theta^*\) will be any value assignment to \(\{w_i\}_{i=1}^n\) that satisfies this unique product and the above sign assignment.
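Here is a minimal sketch of this two-step recovery for the scalar chain, assuming \(w_1, \dots, w_{n-1}\) are all positive (otherwise \(f\) is the zero function). Two queries, \(x = 1\) and \(x = -1\), pin down the sign of \(w_n\) and the product of all the weights; the three-layer oracle weights at the bottom are made up for illustration.

```python
# Sketch of Case 3: recover sign(w_n) and the product of the scalar weights
# from the two queries x = 1 and x = -1.
def recover_scalar_chain(oracle):
    y_pos, y_neg = oracle(1.0), oracle(-1.0)
    if y_pos == 0.0 and y_neg == 0.0:
        return None                                   # f is the zero function
    if y_pos > 0.0:
        return {"sign_w_n": +1, "product": y_pos}     # f(1) = prod(w_i)
    return {"sign_w_n": -1, "product": -y_neg}        # f(-1) = -prod(w_i)

# Hypothetical 3-layer oracle with hidden weights 2, 1.5, -3 (product = -9).
relu = lambda a: max(a, 0.0)
oracle = lambda x: relu(2.0 * relu(1.5 * relu(-3.0 * x)))
print(recover_scalar_chain(oracle))   # {'sign_w_n': -1, 'product': -9.0}
```

Any assignment of positive \(w_1, \dots, w_{n-1}\) and a \(w_n\) of the recovered sign whose product equals the recovered value reproduces \(f\) exactly.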

This general two-step process of figuring out the sign assignments and then the value assignments of the weights cannot be applied in general to matrix-shaped weights \(W_i\). This is because we can end up with cases where it is possible to establish that certain entries \(w_i, w_j\) in a weight matrix have opposite signs, but not be able to tell exactly which entry is positive. This under-determination does not imply that we cannot solve the problem in general, since the weights we are trying to find are under-determined in the first place, i.e. there can be many \(\theta^*\) such that \(f_{\theta^*} = f_\theta\).

Case 4: \(f(x) = W_1 \sigma(W_2 x)\)

This is where I got stuck and am not able to make further progress for now.

In practice, it seems like one ought to be able to write a program that enumerates the various possible cases, say for dimension \(2\). But then, one also has to prove that this program recovers a \(\theta^*\) such that \(f_{\theta^*} = f_\theta\) exactly.

Update (Sep 2019)

It seems like there are theoretical results saying that two-layer neural networks are learnable, but that the general \(n\)-layer case is not.