I've discovered two interesting things about prompt tuning: https://arxiv.org/abs/2104.08691
For anyone new or living under a rock, NovelAI has been using prompt tuning to create modules that let users essentially finetune their massive language model without changing its parameters. A module is basically a set of extra tokens with trainable embeddings that get prefixed to the input to steer the model's generation. You freeze all the weights of the language model and train only the module's embeddings on a dataset, the same way you would normally finetune. This can achieve the same results as model finetuning without changing any of the language model's weights, so you can train hundreds of these modules for different characters, moods or writing styles, and each one costs only a few MB rather than duplicating a 6 GB model hundreds of times.
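Here's a minimal sketch of the idea using Hugging Face's GPT-2 (names like n_prompt_tokens and the init scale are placeholders of mine; I don't know NovelAI's exact setup):
```python
import torch
from transformers import GPT2LMHeadModel

# Prompt tuning in a nutshell: freeze the LM, train only a small matrix of
# prefix embeddings (the "module").
model = GPT2LMHeadModel.from_pretrained("gpt2")
for p in model.parameters():
    p.requires_grad = False  # the language model never changes

n_prompt_tokens = 20  # module size: tens of KB instead of a 6 GB model copy
soft_prompt = torch.nn.Parameter(
    torch.randn(n_prompt_tokens, model.config.n_embd) * 0.02
)

def forward_with_module(input_ids, labels):
    # Look up the normal token embeddings, then prepend the trainable prefix.
    tok_embeds = model.transformer.wte(input_ids)              # (B, T, D)
    prefix = soft_prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
    inputs_embeds = torch.cat([prefix, tok_embeds], dim=1)     # (B, P+T, D)
    # -100 labels are ignored by Hugging Face's loss, so the prefix
    # positions contribute nothing.
    prefix_labels = torch.full(
        (input_ids.size(0), n_prompt_tokens), -100, dtype=torch.long
    )
    return model(inputs_embeds=inputs_embeds,
                 labels=torch.cat([prefix_labels, labels], dim=1))

optimizer = torch.optim.AdamW([soft_prompt], lr=3e-3)  # only the module trains
```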
It's similar to the vision encoder tokens in the paper mentioned here (it was actually motivated by prompt tuning): >>11731
So here's what I've found so far:
1) Taking inspiration from MMD-VAE transformers, you can use an autoencoding transformer like T5-v1_1-base to encode the input tokens[..., :-1] into a prefix, then set all the labels to -100 (which Hugging Face ignores when computing the loss) except the last token you're trying to predict. GPT-2's performance improves enormously: an 8 to 40 point perplexity improvement after an hour of training.
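Concretely the setup looks something like this (a sketch, not my exact code; the same text gets tokenized twice since T5 and GPT-2 have different vocabularies, and conveniently both t5-v1_1-base and gpt2 use 768-dim hidden states so no projection layer is needed):
```python
import torch
from transformers import T5EncoderModel, GPT2LMHeadModel

t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-base")  # trainable
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
for p in gpt2.parameters():
    p.requires_grad = False  # GPT-2 stays frozen throughout

def training_loss(t5_ids, gpt2_ids):
    # Encode everything except the final token into prefix embeddings.
    prefix = t5(input_ids=t5_ids[..., :-1]).last_hidden_state  # (B, T-1, 768)
    tok_embeds = gpt2.transformer.wte(gpt2_ids)                # (B, T, 768)
    inputs_embeds = torch.cat([prefix, tok_embeds], dim=1)
    # All labels -100 (ignored by the loss) except the very last token.
    labels = torch.full(inputs_embeds.shape[:2], -100, dtype=torch.long)
    labels[:, -1] = gpt2_ids[:, -1]
    return gpt2(inputs_embeds=inputs_embeds, labels=labels).loss

optimizer = torch.optim.AdamW(t5.parameters(), lr=1e-4)  # only T5 updates
```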
I have no idea yet why this is so effective. GPT-2's weights are frozen during training, and GPT-2 still generates fine with the prefix even when not sampling at the specific token position it was trained on. Vanilla GPT-2 without the prefix often gets stuck looping, but with the prefix it continues generating as well as the large GPT-2 model. Training on all the tokens also seems to work but is much slower and only gives a slight improvement, so I didn't explore it much.
I also tried testing how it did on an additional 32 tokens after the single token it was trained on, and perplexity on those still improved by 8 points without any training on them. I increased this to 256 tokens and it was still 2 points better without training, quickly improving to 5 points after a few optimizer steps, 7 after 20 steps, 10 after 35, and 11 by 56 steps. The T5 encoder never saw these additional tokens at all, so it seems the GPT-2 transformer is performing some sort of calculation with the initial tokens in the prompt but is then able to stabilize itself.*
I'm really curious what's actually going on inside the transformer that makes it forget how to generate the initial prompt (~7 points worse in perplexity) yet makes the tokens generated after it so good, staying stable and interesting without repeating itself.
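For anyone wanting to reproduce the per-position numbers: Hugging Face's built-in loss averages everything, so the individual token losses have to be pulled out manually, something like:
```python
import torch.nn.functional as F

def per_token_loss(gpt2, inputs_embeds, labels):
    # Same -100 masking convention as above; ignored positions come back as 0.
    logits = gpt2(inputs_embeds=inputs_embeds).logits
    shift_logits = logits[:, :-1, :]  # logits at i predict the token at i+1
    shift_labels = labels[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
        ignore_index=-100,
    ).view(shift_labels.shape)  # (B, T-1) loss per position

# Perplexity over the 32 (or 256) extra positions = exp of the mean loss there.
```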
2) You can do a similar thing by encoding the previous context into a prefix, using it as a compressed memory of that context. This also improves GPT-2's performance, by about 5 perplexity points when training on all tokens for a few hours, and during generation it will pull in information from the previous context. It also seems to benefit from training on only the last token. I'm planning to explore this more later.
For these experiments I used a memory length of 32 tokens and an input size of 256 tokens (not counting the memory), with a total batch size of 1024 via gradient accumulation.
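A sketch of the memory setup (how T5's 256 output states get squeezed down to 32 memory slots is a detail I'm glossing over here; taking the last 32 encoder states is just one way to illustrate it):
```python
import torch
from transformers import T5EncoderModel, GPT2LMHeadModel

MEM_LEN, INPUT_LEN = 32, 256

t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-base")
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
for p in gpt2.parameters():
    p.requires_grad = False  # GPT-2 stays frozen, as before

def training_loss(prev_t5_ids, cur_ids):
    # Compress the previous 256-token window into a 32-slot memory prefix.
    states = t5(input_ids=prev_t5_ids).last_hidden_state
    memory = states[:, -MEM_LEN:, :]                        # (B, 32, 768)
    tok_embeds = gpt2.transformer.wte(cur_ids)              # (B, 256, 768)
    inputs_embeds = torch.cat([memory, tok_embeds], dim=1)
    # Train on all current-window tokens; memory positions are ignored.
    mem_labels = torch.full((cur_ids.size(0), MEM_LEN), -100, dtype=torch.long)
    labels = torch.cat([mem_labels, cur_ids], dim=1)
    return gpt2(inputs_embeds=inputs_embeds, labels=labels).loss
```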
What if previously generated prefixes are included in the prefix generation too? This could potentially allow information to flow from tens of thousands of tokens ago.
What if a second prefix were added that compresses all the previous prefixes concatenated together? This could function like a summary of the past 32k tokens. Separately trained modules are generally incompatible with each other, but these two prefixes would be trained together.
Is it possible to add a memory controller so the transformer can read and write these memories?
What is actually going on with prompt tuning, memory prefixes and vision encoder tokens? Where do they exist in the embedding space relative to the actual vocabulary embeddings and each other?
What do the individual losses for the additional tokens and the initial prompt look like after training on only the last token for a long time? Which dimensions of the embeddings are causing the improvements? Graphing these might provide some insight into the calculations the transformer is doing.
Do these performance gains scale to larger models, such as gpt2-medium, that can still run on a consumer GPU? Could this help distilled GPT-2, which has a major problem with looping?
*: If the transformer is performing a useful calculation with the initial prompt, is it possible to create some sort of wormhole with a token that continues doing this calculation for a few tokens then returns back, replacing the real token embedding with the calculated output?
So many questions, I feel like a huge breakthrough is around the corner.