Inspirative Text Prediction

Published

December 12, 2023

Introduction

Machine learning, and more specifically, deep learning, is shaping how we write. From academic papers to class materials, emails to text messages, we are constantly using technologies powered by deep learning to compose our texts. Moreover, studies have shown that predictive text influences what we write [1,2,4]. Currently, most text prediction technology uses a model that looks at the previously typed words and the surrounding text to generate a list of likely next words or phrases. It ranks each of them based on their probabilities and presents the most likely ones to users as suggestions. However, not only may those suggestions be biased, but they may also affect how users write and what they write, thereby taking away their authorship and autonomy. Could text prediction models instead serve as a source of inspiration for users, encouraging their writing process instead of suggesting what to write?

In this blog post, I will explore the possibility of using text prediction to inspire users to write more original texts. I will define what it means to be inspirational and then present a preliminary approach to collecting example data and evaluating the current large language models (LLMs) to determine their likelihood of predicting subordinating conjunctions. I will then discuss the challenges and opportunities of using text prediction to inspire users to write more original texts.

Exploratory Data Analysis

import plotly.io as pio
import plotly.express as px
import pandas as pd
import requests
import spacy

pio.templates.default = "plotly_white"

spacy.prefer_gpu()
nlp = spacy.load("en_core_web_sm")

Collection

Let’s start by defining a function to download a book from Project Gutenberg. To accomplish this, we will use Gutendex to retrieve the book’s metadata and then download the book using the URL to the plain text version of the book provided in the metadata. For the purpose of this blog post, we will only download books in English.

#| code-fold: true

def download_book(book_id: int) -> tuple[str, str]:
    """Download a book from Project Gutenberg

    Arg:
        book_id: The Project Gutenberg ID of the book to download

    Returns:
        A tuple containing the book title and the book text
    """

    gutendex_url = f"https://gutendex.com/books/{book_id}/"

    try:
        response = requests.get(gutendex_url)
        response.raise_for_status()
        data = response.json()

        book_language = data["languages"]

        # Only download books in English
        if "en" in book_language:
            book_title = data["title"]

            # Only download books in plain text
            mime_types = ["text/plain", "text/plain; charset=us-ascii"]

            # Default to None so the check below works when no format matches
            book_url = None

            for mime_type in mime_types:
                if mime_type in data["formats"]:
                    book_url = data["formats"][mime_type]
                    break

            if book_url is None:
                raise Exception("The book is not available in plain text.")

            response = requests.get(book_url)
            response.raise_for_status()

            return book_title, response.text
        else:
            raise Exception("The book is not in English.")
    except requests.exceptions.HTTPError as err:
        raise Exception(err)
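
The lookup above matches two exact MIME keys. As a rough sketch of a more tolerant variant, the helper below accepts any key whose media type is `text/plain`, whatever charset suffix follows. The function name `pick_plain_text_url`, the sample dictionary, and its URLs are made up for illustration; the assumption is only that the `formats` field maps MIME strings to URLs, as the code above already relies on.

```python
from typing import Optional


def pick_plain_text_url(formats: dict) -> Optional[str]:
    """Return the first URL whose MIME key has the media type text/plain."""
    for mime_type, url in formats.items():
        # Keys may carry a charset suffix, e.g. "text/plain; charset=us-ascii"
        if mime_type.split(";")[0].strip() == "text/plain":
            return url
    return None


# Hypothetical metadata, shaped like the "formats" field used above
sample_formats = {
    "text/html": "https://example.org/book.html",
    "text/plain; charset=utf-8": "https://example.org/book.txt",
}

print(pick_plain_text_url(sample_formats))
```

Returning `None` instead of raising keeps the decision about how to handle a missing plain-text version with the caller.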

For this EDA, we will download The Strange Case of Dr. Jekyll and Mr. Hyde by Robert Louis Stevenson.

# Book ID for The Strange Case of Dr. Jekyll and Mr. Hyde
book_id = 43

# Download the book and store it in a DataFrame
book_data = [download_book(book_id)]
book_data = pd.DataFrame(book_data, columns=["title", "text"])

Wrangling

Let’s take a look at the downloaded text:

#| echo: false

# Print the first 256 characters of the book
print(book_data["text"].iloc[0][:256].strip() + "\n\n...\n")

# Print the last 256 characters of the book
print(book_data["text"].iloc[0][-256:].strip(), end="")

It looks like the text contains some extra information which we do not wish to include in our analysis. Let’s remove the extra information and save the cleaned text in a new column.

Specifically, we will use the markers provided by Project Gutenberg to remove the extra information. These markers appear as follows:

*** START OF THE PROJECT GUTENBERG EBOOK …

*** END OF THE PROJECT GUTENBERG EBOOK …

#| code-fold: true

def sanitize_text(text: str) -> str:
    """Remove extra information from the text

    Arg:
        text: The text to sanitize

    Returns:
        The sanitized text
    """

    start_marker = "***"
    end_marker = "*** END OF THE PROJECT GUTENBERG EBOOK"

    # Index of the second occurrence of the start marker
    start_index = text.find(start_marker, text.find(start_marker) + 1)

    # Index of the first occurrence of the end marker
    end_index = text.find(end_marker)

    # Remove the extra information based on the marker indices
    if start_index != -1 and end_index != -1:
        text = text[start_index + len(start_marker) : end_index].strip()

    return text

# Sanitize the text and store it in a new column
book_data["clean_text"] = book_data["text"].apply(sanitize_text)
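
To see the marker-based trimming in isolation, here is a small sketch on a made-up miniature "book". It is an alternative formulation, not the function above: it searches for the full START line with a regular expression rather than for the second `***`. The names `strip_gutenberg_wrapper`, `START_RE`, and `END_RE` are hypothetical, and the sample text is invented.

```python
import re

# Made-up miniature "book" that mimics the marker layout shown above
sample = (
    "Header and license notes\n"
    "*** START OF THE PROJECT GUTENBERG EBOOK SAMPLE ***\n"
    "Chapter 1. It was a dark night.\n"
    "*** END OF THE PROJECT GUTENBERG EBOOK SAMPLE ***\n"
    "Footer and license notes\n"
)

# The START pattern consumes the rest of the marker line up to the closing ***
START_RE = re.compile(r"\*\*\* START OF THE PROJECT GUTENBERG EBOOK.*?\*\*\*")
END_RE = re.compile(r"\*\*\* END OF THE PROJECT GUTENBERG EBOOK")


def strip_gutenberg_wrapper(text: str) -> str:
    """Keep only the text between the START line and the END line."""
    start = START_RE.search(text)
    end = END_RE.search(text)

    # Markers missing or out of order: leave the text untouched
    if start is None or end is None or end.start() < start.end():
        return text

    return text[start.end() : end.start()].strip()


body = strip_gutenberg_wrapper(sample)
print(body)
```

Matching the whole START line avoids depending on how many `***` runs appear before the body, at the cost of assuming the marker wording stays the same across books.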

Let’s take a look at the cleaned text:

#| echo: false

# Print the first 256 characters of the book
print(book_data["clean_text"].iloc[0][:256].strip() + "\n\n...\n")

# Print the last 256 characters of the book
print(book_data["clean_text"].iloc[0][-256:].strip())

This looks much better! Our next step is to split the text into sentences to analyze it at the sentence level. We will use spaCy to do this:

#| code-fold: true

def sentence_spliter(text: str) -> list[str]:
    """Split the text into sentences

    Arg:
        text: The text to split

    Returns:
        A list of sentences
    """

    pipe_disable = ["ner", "lemmatizer", "textcat"]

    # Remove line breaks and split the text into sentences
    docs = nlp.pipe([text.replace("\r\n", " ")], disable=pipe_disable)

    # Return a list of sentences without leading and trailing whitespace
    return [sent.text.strip() for doc in docs for sent in doc.sents]

# Split the text into sentences and store them in a DataFrame
sentences = sentence_spliter(book_data["clean_text"].iloc[0])
sentences = pd.DataFrame(sentences, columns=["sentence"])

sentences.tail()

How many sentences are there in the book?

#| echo: false

print(f"There are {len(sentences)} sentences in the book.")

How many sentences use subordinating conjunctions? To answer this question, we will use spaCy’s part-of-speech tagger to identify sentences that contain subordinating conjunctions:

#| code-fold: true

def doc_pipe(sentence: str):
    pipe_disable = ["ner", "lemmatizer", "textcat"]
    return list(nlp.pipe([sentence], disable=pipe_disable))


def has_sconj(sentence: str):
    """Check if a sentence contains a subordinating conjunction

    Arg:
        sentence: The sentence to check

    Returns:
        A Pandas Series containing a boolean value indicating whether the sentence contains a subordinating conjunction and the subordinating conjunction if it exists
    """

    doc = doc_pipe(sentence)

    # Check if the sentence contains a subordinating conjunction
    for token in doc[0]:
        if token.pos_ == "SCONJ":
            return pd.Series([True, token.text])

    return pd.Series([False, None])

#| code-overflow: wrap

# Check if the sentence contains a subordinating conjunction and store the result in a new column
sentences[["has_sconj", "sconj"]] = sentences["sentence"].apply(has_sconj)

# Sanity check
assert sentences["has_sconj"].value_counts().sum() == len(sentences)

sentences.tail()

How many of the sentences contain subordinating conjunctions? How many of the sentences do not contain subordinating conjunctions?

#| echo: false

print(
    f"There are {len(sentences[sentences['has_sconj']])} sentences with a subordinating conjunction,\nand {len(sentences[~sentences['has_sconj']])} sentences without a subordinating conjunction."
)

Visualization

Let’s try visualizing one of the sentences that contains a subordinating conjunction:

Figure 1. Visualization of a Sentence That Contains a Subordinating Conjunction

#| code-fold: true

# Grab a sentence that contains a subordinating conjunction
sentence_id = 1149
doc = nlp(sentences["sentence"].iloc[sentence_id])

# Visualize the sentence using displaCy
spacy.displacy.render(doc, style="dep", jupyter=True, options={"distance": 110})

What about the distribution of subordinating conjunctions in the book?

#| code-fold: true

# Lower case the subordinating conjunctions and count them
sent_sconj = sentences["sconj"].str.lower().value_counts().reset_index()

# Plot the distribution of subordinating conjunctions
fig = px.bar(
    sent_sconj,
    x="sconj",
    y="count",
    title="<b>Figure 2.</b> Distribution of Subordinating Conjunctions",
    labels={"sconj": "Subordinating Conjunction", "count": "Count"},
    color_discrete_sequence=px.colors.qualitative.Safe
)

fig.show()
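
For readers who want to see the counting step without pandas or Plotly, the sketch below does the same lower-casing and tallying with the standard library. The token list is made up and only stands in for the `sconj` column; `None` plays the role of sentences without a subordinating conjunction.

```python
from collections import Counter

# Made-up conjunction tokens standing in for the "sconj" column
sconj_tokens = ["that", "That", "if", "because", "IF", "that", None]

# Lower-case the tokens and drop missing values before counting
counts = Counter(token.lower() for token in sconj_tokens if token is not None)

# Most frequent conjunctions first
for word, count in counts.most_common():
    print(f"{word}: {count}")
```

Lower-casing first matters because a sentence-initial "That" and a mid-sentence "that" would otherwise be counted as two different conjunctions.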

Analysis

This result is somewhat surprising to me. I did not expect “that” to be the most common subordinating conjunction in the book; I had expected “because” to be more common, since I use it frequently in my own writing. This suggests that the distribution of subordinating conjunctions may depend on the writing context. Furthermore, this result does not tell us which subordinating conjunctions are more useful than others, particularly in the context of text prediction. Our next step is to evaluate the current large language models (LLMs) to determine their likelihood of predicting subordinating conjunctions.

Preliminary Modeling

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn.functional import softmax, cross_entropy
from datasets import load_dataset
import pandas as pd
import numpy as np
import random
import torch
import spacy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

spacy.prefer_gpu()
nlp = spacy.load("en_core_web_sm")

Load the Model

We will use the Llama-2-7b-chat-hf model to evaluate an LLM’s likelihood of predicting subordinating conjunctions. Unfortunately, running the model is computationally expensive on most machines. Therefore, we used AutoAWQ to quantize the model into 4-bit precision. This reduces the amount of computational resources required to run inference on the model while still maintaining a high level of accuracy. We have provided our code for quantizing the model in the Appendix. In the meantime, you can access our quantized model here: CalvinU/Llama-2-7b-chat-hf-awq.

model_name = "CalvinU/Llama-2-7b-chat-hf-awq"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

Load the Data

We have also described in the Appendix a scalable approach to collecting and processing data from Project Gutenberg. In the meantime, you can access our dataset here: CalvinU/project-gutenberg.

dataset_name = "CalvinU/project-gutenberg"

dataset = load_dataset(dataset_name, split="train")
dataset_df = pd.DataFrame(dataset)

The dataset contains 10 random books downloaded from Project Gutenberg. These books have already been sanitized and split into sentences based on their book_id and title. Therefore, each row in the dataset represents an ordered sentence from one of the books. Let’s take a look at the dataset:

dataset_df.tail()

Wrangling

Since we already have an ordered list of sentences, we can apply the same approach we used in the EDA section to identify sentences that contain subordinating conjunctions:

#| code-fold: true

def doc_pipe(sentence: str):
    pipe_disable = ["ner", "lemmatizer", "textcat"]
    return list(nlp.pipe([sentence], disable=pipe_disable))


def has_sconj(sentence: str):
    """Check if a sentence contains a subordinating conjunction

    Arg:
        sentence: The sentence to check

    Returns:
        A Pandas Series containing a boolean value indicating whether the sentence contains a subordinating conjunction and the subordinating conjunction if it exists
    """

    doc = doc_pipe(sentence)

    # Check if the sentence contains a subordinating conjunction
    for token in doc[0]:
        if token.pos_ == "SCONJ":
            return pd.Series([True, token.text])

    return pd.Series([False, None])

# Check if the sentence contains a subordinating conjunction and store the result in a new column
dataset_df[["has_sconj", "sconj"]] = dataset_df["sentence"].apply(has_sconj)

# Sanity check
assert dataset_df["has_sconj"].value_counts().sum() == len(dataset_df)

dataset_df.tail()

Number of sentences and number of sentences that contain subordinating conjunctions for each book:

summary = dataset_df.groupby(["book_id", "title"], as_index=False).agg(
    num_sents=("sentence", "count"),
    num_sconj=("has_sconj", "sum"),
)

summary.head(10)

Analysis

Suppose the general structure of a sentence with a subordinating conjunction is:

<sentence-with-SCONJ> ::= <subordinate-clause> <independent-clause> | 
                          <independent-clause> <subordinate-clause>

Note that a <subordinate-clause> is a dependent clause that contains a subordinating conjunction and cannot stand alone as a sentence, while an <independent-clause> is a main clause that can stand alone as a sentence.
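
The two orderings in the grammar above can be told apart with a simple heuristic, sketched below. This is an illustration under stated assumptions, not the approach used later in this post: the function name `split_clauses` is hypothetical, a sentence-initial conjunction is taken to mean the subordinate clause comes first, and that clause is assumed to end at the first comma. The plain substring search can also match inside a longer word, so a real implementation would work on tokens.

```python
def split_clauses(sentence: str, sconj: str) -> dict:
    """Split a sentence into clauses around its subordinating conjunction.

    Heuristic only: a sentence-initial conjunction means the subordinate
    clause comes first and is assumed to end at the first comma.
    """
    position = sentence.lower().find(sconj.lower())
    if position == -1:
        raise ValueError(f"{sconj!r} not found in sentence")

    if position == 0:
        # <subordinate-clause> <independent-clause>
        subordinate, _, independent = sentence.partition(",")
        order = "subordinate-first"
    else:
        # <independent-clause> <subordinate-clause>
        independent, subordinate = sentence[:position], sentence[position:]
        order = "independent-first"

    return {
        "order": order,
        "independent": independent.strip(),
        "subordinate": subordinate.strip(),
    }


# Made-up example sentences, one for each ordering
first = split_clauses("Because it rained, we stayed home.", "because")
second = split_clauses("We stayed home because it rained.", "because")

print(first)
print(second)
```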

In order to evaluate the likelihood of an LLM predicting subordinating conjunctions, we will investigate the following behaviors:

How do the cross-entropy and perplexity change when we provide the context exactly as it appears in the book, versus when we randomly shuffle the context?

And for each case, what is the probability spectrum at the subordinating conjunction? What are the cross-entropy and perplexity of the text after the subordinating conjunction?
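
As a quick refresher on how the two metrics relate, here is a small worked example with made-up probabilities. Per-token cross-entropy is the negative log-probability the model gives to the token that actually comes next, and perplexity is the exponential of the mean cross-entropy. The numbers are invented and only chosen to make the arithmetic easy to follow.

```python
import math

# Made-up probabilities that a model assigned to each actual next token
token_probabilities = [0.5, 0.25, 0.125, 0.5]

# Per-token cross-entropy is the negative log-probability of the true token
per_token_cross_entropy = [-math.log(p) for p in token_probabilities]

# Sequence-level cross-entropy is the mean over tokens (in nats)
mean_cross_entropy = sum(per_token_cross_entropy) / len(per_token_cross_entropy)

# Perplexity is the exponential of the mean cross-entropy
perplexity = math.exp(mean_cross_entropy)

print(f"mean cross-entropy: {mean_cross_entropy:.4f}")
print(f"perplexity:         {perplexity:.4f}")
```

Lower values of either metric mean the model found the text more predictable.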

When Context Is Provided Exactly as It Appears in the Book

To get started, let’s select one of the books from the dataset, _The Adventures of a Dog, and a Good Dog Too_ by Alfred Elwes:

#| code-fold: true

# Book ID for The Adventures of a Dog, and a Good Dog Too
book_id = 20741

selected_book = dataset_df[dataset_df["book_id"] == book_id].reset_index(drop=True)
selected_book.tail()

How many sentences are there in the book? How many of the sentences contain subordinating conjunctions?

#| echo: false

print(
    f"There are {len(selected_book)} sentences in the book. There are {len(selected_book[selected_book['has_sconj']])} sentences with a subordinating conjunction"
)

It appears that this book uses a significant number of subordinating conjunctions! Let’s choose one of the last sentences that includes a subordinating conjunction and select a maximum of 100 sentences preceding it as the context:

last_sconj_index = selected_book[selected_book["has_sconj"]].index[-3]

context = selected_book.iloc[max(last_sconj_index - 100, 0) : last_sconj_index][
    "sentence"
].tolist()

context = " ".join(context)

sentence = selected_book.iloc[last_sconj_index]["sentence"]

Let’s take a small peek at the context:

#| echo: false

# Print the first 100 characters of the context
print(context[:100].strip() + " ... ", end="")

# Print the last 100 characters of the context
print(context[-100:].strip(), end="\n\n")

Let’s take a look at the sentence with the subordinating conjunction:

#| echo: false

# Print the sentence
print(sentence)

Let’s tokenize the context and the sentence, and then feed them into the model to get the predicted logits:

#| code-fold: true

# Tokenize the context (to be used later, not as an input sequence)
context_tokenized = tokenizer(context, return_tensors="pt").to(device)
context_input_ids = context_tokenized.input_ids

# Tokenize the sentence (to be used later, not as an input sequence)
sentence_tokenized = tokenizer(sentence, return_tensors="pt").to(device)
sentence_input_ids = sentence_tokenized.input_ids

# Tokenize the context and the sentence as an input sequence
prompt_tokenized = tokenizer(context + sentence, return_tensors="pt").to(device)
prompt_input_ids = prompt_tokenized.input_ids

# Get the predicted logits for the input sequence (no gradients are needed for inference)
with torch.no_grad():
    model_logits = model(prompt_input_ids).logits

#| code-fold: true

# Get the subordinating conjunction from the book
sconj = selected_book.iloc[last_sconj_index]["sconj"]

# Decode the context as a string, excluding the first token
context_decoded = [tokenizer.decode(token) for token in context_input_ids[0]][1:]

# Decode the sentence as a string, excluding the first token
sentence_decoded = [tokenizer.decode(token) for token in sentence_input_ids[0]][1:]

# Index of the subordinating conjunction token in the input sequence
sconj_token_index = len(context_decoded) + sentence_decoded.index(sconj)

# Index of the subordinating conjunction in the sentence (not the input sequence, but from the book)
sconj_index = selected_book.iloc[last_sconj_index]["sentence"].find(sconj)

# Figure out which type of clause comes first
if sconj not in sentence[:sconj_index]:
    independent_clause = sentence[:sconj_index]
    subordinate_clause = sentence[sconj_index:]
else:
    independent_clause = sentence[sconj_index:]
    subordinate_clause = sentence[:sconj_index]

# Index of the independent clause in the sentence (not the input sequence, but from the book)
independent_clause_index = selected_book.iloc[last_sconj_index]["sentence"].find(
    independent_clause
)

# Index of the subordinate clause in the sentence (not the input sequence, but from the book)
subordinate_clause_index = selected_book.iloc[last_sconj_index]["sentence"].find(
    subordinate_clause
)

# Tokenize the independent clause
independent_clause_tokenized = tokenizer(
    independent_clause, return_tensors="pt"
).to(device)

# Tokenize the subordinate clause
subordinate_clause_tokenized = tokenizer(
    subordinate_clause, return_tensors="pt"
).to(device)
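
The index arithmetic above is easier to follow on a toy example. The sketch below uses made-up word-level tokens, with `"<s>"` standing for the beginning-of-sequence token. It assumes that tokenizing the context and the sentence together yields the same tokens as tokenizing them separately, which a subword tokenizer does not guarantee at the boundary.

```python
# Made-up token strings; "<s>" stands for the beginning-of-sequence token
context_tokens = ["<s>", "The", "dog", "barked", "."]
sentence_tokens = ["<s>", "He", "ran", "because", "he", "was", "late", "."]

# The combined prompt carries a single "<s>", so drop the sentence's copy
prompt_tokens = context_tokens + sentence_tokens[1:]

# In a causal LM, the logits at position t score the token at position t + 1,
# so the targets are the prompt shifted left by one ("<s>" is never a target)
targets = prompt_tokens[1:]

sconj = "because"

# Number of context tokens, not counting "<s>"
context_length = len(context_tokens[1:])

# Position of the conjunction inside the sentence, not counting "<s>"
offset_in_sentence = sentence_tokens[1:].index(sconj)

# Index of the conjunction in the shifted targets (and in the shifted logits)
sconj_target_index = context_length + offset_in_sentence

print(sconj_target_index, targets[sconj_target_index])
```

Because both the logits and the targets are shifted by one later on, the same index picks out the prediction made just before the conjunction and the conjunction itself.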

Given that we have fed in the context exactly as it appears in the book, let’s take a look at the top k probability spectrum at the subordinating conjunction:

#| code-fold: true

def probability_spectrum_at(logits, input_ids, i, k=6):
    """Given an input sequence, get the top k probability spectrum at the given index

    Args:
        logits: predicted logits for the input sequence
        input_ids: input sequence token IDs
        i: index to get the probability spectrum at
        k: top k, default is 6

    Returns:
        A Pandas DataFrame containing the top k probability spectrum at the given index
    """

    # Predicted logits for an input sequence, excluding the last element
    adjusted_logits = logits[0, :-1]

    # Input sequence, starting from the second element
    adjusted_input_ids = input_ids[0, 1:]

    # Get the probability distribution predicted by the model
    probability_distribution = softmax(adjusted_logits[i], dim=0)

    # Get the top k probabilities and their respective indices, default k=6
    top_probability_distribution, top_indices = probability_distribution.topk(k)

    # Get the top k probability spectrum as a DataFrame
    probability_spectrum = pd.DataFrame(
        {
            "token": [tokenizer.decode(token) for token in top_indices.tolist()],
            "probability": top_probability_distribution.tolist(),
        }
    )

    # Decode the input sequence as a string
    matching_token = tokenizer.decode(adjusted_input_ids[i])

    # Highlight the matching string in the probability spectrum
    def highlight_prompt_at(x):
        if x["token"] == matching_token:
            return ["background-color: #6495ED"] * len(x)
        else:
            return [""] * len(x)

    return probability_spectrum.style.apply(highlight_prompt_at, axis=1)


def cross_entropy_at(logits, input_ids, i):
    """Given an input sequence, get the cross entropy at the given index

    Args:
        logits: predicted logits for the input sequence
        input_ids: input sequence token IDs
        i: index to get the cross entropy at

    Returns:
        The cross entropy at the given index
    """

    # Predicted logits for an input sequence, excluding the last element
    adjusted_logits = logits[0, :-1]

    # Input sequence, starting from the second element
    adjusted_input_ids = input_ids[0, 1:]

    # Get the cross entropy per input sequence
    cross_entropy_seq = cross_entropy(
        adjusted_logits, adjusted_input_ids, reduction="none"
    )

    return cross_entropy_seq[i].item()


def cross_entropy_per_token(logits, input_ids, matching_sequence_tokenized):
    """Given a matching sequence, get the cross entropy for each token in the matching sequence

    Args:
        logits: predicted logits for the input sequence
        input_ids: input sequence token IDs
        matching_sequence_tokenized: tokenized matching sequence

    Returns:
        A Pandas DataFrame containing the cross entropy for each token in the matching sequence
    """

    # Predicted logits for an input sequence, excluding the last element
    adjusted_logits = logits[0, :-1]

    # Input sequence, starting from the second element
    adjusted_input_ids = input_ids[0, 1:]

    # Get the cross entropy per input sequence
    cross_entropy_seq = cross_entropy(
        adjusted_logits, adjusted_input_ids, reduction="none"
    )

    # Decode the tokenized matching sequence as a string
    matching_sequence_token = [
        tokenizer.decode(token) for token in matching_sequence_tokenized.input_ids[0]
    ]

    # Decoded matching sequence token, starting from the second element
    adjusted_matching_sequence_token = matching_sequence_token[1:]

    return pd.DataFrame(
        {
            "token": adjusted_matching_sequence_token,
            "cross_entropy": cross_entropy_seq[
                -len(adjusted_matching_sequence_token) :
            ].tolist(),
        }
    )

#| code-fold: true

prob_spectrum = probability_spectrum_at(
    model_logits, prompt_input_ids, sconj_token_index
)

prob_spectrum
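
To make the idea of a top-k probability spectrum concrete, here is a dependency-free sketch of the same two steps, a softmax over the logits followed by keeping the k most probable tokens. The five-word vocabulary and the logits are made up; a real model produces one logit per entry of a vocabulary with tens of thousands of tokens.

```python
import math


def softmax(logits: list) -> list:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def top_k(vocab: list, logits: list, k: int = 3) -> list:
    """Return the k most probable (token, probability) pairs."""
    probabilities = softmax(logits)
    ranked = sorted(
        zip(vocab, probabilities), key=lambda pair: pair[1], reverse=True
    )
    return ranked[:k]


# Made-up vocabulary and logits for a single position
vocab = ["that", "because", "if", "and", "dog"]
logits = [2.0, 1.0, 0.5, 0.0, -1.0]

spectrum = top_k(vocab, logits, k=3)

for token, probability in spectrum:
    print(f"{token:>8}: {probability:.3f}")
```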

Let’s also look at the cross-entropy and perplexity of the sentence with the subordinating conjunction:

#| code-fold: true

# Cross-entropy of the sentence with the subordinating conjunction (entire input sequence)
sentence_per_token_cross_entropy = cross_entropy_per_token(
    model_logits, prompt_input_ids, sentence_tokenized
)

# Mean cross-entropy of the sentence with the subordinating conjunction (entire input sequence)
mean_sentence_cross_entropy = sentence_per_token_cross_entropy["cross_entropy"].mean()

# Perplexity of the sentence with the subordinating conjunction (entire input sequence)
sentence_perplexity = np.exp(mean_sentence_cross_entropy)

# Cross-entropy of the independent clause
independent_clause_per_token_cross_entropy = cross_entropy_per_token(
    model_logits, prompt_input_ids, independent_clause_tokenized
)

# Mean cross-entropy of the independent clause
mean_independent_clause_cross_entropy = independent_clause_per_token_cross_entropy[
    "cross_entropy"
].mean()

# Perplexity of the independent clause
independent_clause_perplexity = np.exp(mean_independent_clause_cross_entropy)

# Cross-entropy of the subordinate clause
subordinate_clause_per_token_cross_entropy = cross_entropy_per_token(
    model_logits, prompt_input_ids, subordinate_clause_tokenized
)

# Mean cross-entropy of the subordinate clause
mean_subordinate_clause_cross_entropy = subordinate_clause_per_token_cross_entropy[
    "cross_entropy"
].mean()

# Perplexity of the subordinate clause
subordinate_clause_perplexity = np.exp(mean_subordinate_clause_cross_entropy)

# Cross-entropy of the subordinating conjunction
subordinating_conjunction_cross_entropy = cross_entropy_at(
    model_logits, prompt_input_ids, sconj_token_index
)

# Perplexity of the subordinating conjunction
subordinating_conjunction_perplexity = np.exp(subordinating_conjunction_cross_entropy)

#| echo: false

# Print in the order of the sentence structure
if independent_clause_index > subordinate_clause_index:
    print("Structure:")
    print("\t<sentence-with-SCONJ> ::= <subordinate-clause> <independent-clause>")
    print("\nMetrics:")
    print(f"\t<sentence-with-SCONJ> cross-entropy: {mean_sentence_cross_entropy}")
    print(f"\t<sentence-with-SCONJ> perplexity:    {sentence_perplexity}\n")
    print(f"\t<subordinate-clause> cross-entropy:  {mean_subordinate_clause_cross_entropy}")
    print(f"\t<subordinate-clause> perplexity:     {subordinate_clause_perplexity}")
    print(f"\t<SCONJ> cross-entropy:               {subordinating_conjunction_cross_entropy}")
    print(f"\t<SCONJ> perplexity:                  {subordinating_conjunction_perplexity}")
    print(f"\t<independent-clause> cross-entropy:  {mean_independent_clause_cross_entropy}")
    print(f"\t<independent-clause> perplexity:     {independent_clause_perplexity}")
else:
    print("Structure:")
    print("\t<sentence-with-SCONJ> ::= <independent-clause> <subordinate-clause>")
    print("\nMetrics:")
    print(f"\t<sentence-with-SCONJ> cross-entropy: {mean_sentence_cross_entropy}")
    print(f"\t<sentence-with-SCONJ> perplexity:    {sentence_perplexity}\n")
    print(f"\t<independent-clause> cross-entropy:  {mean_independent_clause_cross_entropy}")
    print(f"\t<independent-clause> perplexity:     {independent_clause_perplexity}")
    print(f"\t<SCONJ> cross-entropy:               {subordinating_conjunction_cross_entropy}")
    print(f"\t<SCONJ> perplexity:                  {subordinating_conjunction_perplexity}")
    print(f"\t<subordinate-clause> cross-entropy:  {mean_subordinate_clause_cross_entropy}")
    print(f"\t<subordinate-clause> perplexity:     {subordinate_clause_perplexity}")
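
The clause-level numbers depend on which token positions are averaged. The `cross_entropy_per_token` helper reads the trailing positions of the sequence, which lines up with whichever clause ends the sentence. For a clause that comes first, an explicit offset is needed. The sketch below shows one way to make the span explicit. The name `span_metrics` is hypothetical and the per-token losses are made up.

```python
import math


def span_metrics(per_token_cross_entropy: list, start: int, length: int) -> tuple:
    """Return (mean cross-entropy, perplexity) over an explicit token span."""
    if start < 0 or length <= 0:
        raise ValueError("start must be >= 0 and length must be > 0")

    span = per_token_cross_entropy[start : start + length]
    if len(span) != length:
        raise ValueError("span runs past the end of the sequence")

    mean = sum(span) / length
    return mean, math.exp(mean)


# Made-up losses for a 7-token sentence: the first 3 tokens form the
# independent clause and the last 4 tokens form the subordinate clause
losses = [2.0, 1.0, 3.0, 0.5, 1.5, 1.0, 1.0]

independent = span_metrics(losses, start=0, length=3)
subordinate = span_metrics(losses, start=3, length=4)

print(f"independent clause: cross-entropy={independent[0]:.2f}")
print(f"subordinate clause: cross-entropy={subordinate[0]:.2f}")
```

Passing the start offset explicitly keeps the two clauses from being scored on overlapping positions.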

When Context Is Randomly Shuffled

Let’s shuffle the context and feed it into the model to get the predicted logits:

last_sconj_index = selected_book[selected_book["has_sconj"]].index[-3]

context = selected_book.iloc[max(last_sconj_index - 100, 0) : last_sconj_index][
    "sentence"
].tolist()

random.Random(42).shuffle(context)

context = " ".join(context)

sentence = selected_book.iloc[last_sconj_index]["sentence"]
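
Seeding a dedicated `random.Random` instance, as above, makes the shuffle reproducible without touching the global random state. The sketch below illustrates this on a made-up list of sentences. The helper name `shuffled_copy` is hypothetical, and it shuffles a copy so that the original order stays available for comparison.

```python
import random

# Made-up sentences standing in for the context
sentences_in_order = ["First.", "Second.", "Third.", "Fourth.", "Fifth."]


def shuffled_copy(items: list, seed: int = 42) -> list:
    """Return a shuffled copy of items, reproducible for a given seed."""
    copy = list(items)
    random.Random(seed).shuffle(copy)
    return copy


run_a = shuffled_copy(sentences_in_order)
run_b = shuffled_copy(sentences_in_order)

# Same seed, same order; the original list is left untouched
print(run_a == run_b)
print(sentences_in_order)
```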

Let’s take a small peek at the context:

#| echo: false

# Print the first 100 characters of the context
print(context[:100].strip() + " ... ", end="")

# Print the last 100 characters of the context
print(context[-100:].strip(), end="\n\n")

Let’s take a look at the sentence with the subordinating conjunction:

#| echo: false

# Print the sentence
print(sentence)

Same steps as before:

#| code-fold: true

# Tokenize the context (to be used later, not as an input sequence)
context_tokenized = tokenizer(context, return_tensors="pt").to(device)
context_input_ids = context_tokenized.input_ids

# Tokenize the sentence (to be used later, not as an input sequence)
sentence_tokenized = tokenizer(sentence, return_tensors="pt").to(device)
sentence_input_ids = sentence_tokenized.input_ids

# Tokenize the context and the sentence as an input sequence
prompt_tokenized = tokenizer(context + sentence, return_tensors="pt").to(device)
prompt_input_ids = prompt_tokenized.input_ids

# Get the predicted logits for the input sequence (no gradients are needed for inference)
with torch.no_grad():
    model_logits = model(prompt_input_ids).logits

#| code-fold: true

# Get the subordinating conjunction from the book
sconj = selected_book.iloc[last_sconj_index]["sconj"]

# Decode the context as a string, excluding the first token
context_decoded = [tokenizer.decode(token) for token in context_input_ids[0]][1:]

# Decode the sentence as a string, excluding the first token
sentence_decoded = [tokenizer.decode(token) for token in sentence_input_ids[0]][1:]

# Index of the subordinating conjunction token in the input sequence
sconj_token_index = len(context_decoded) + sentence_decoded.index(sconj)

# Index of the subordinating conjunction in the sentence (not the input sequence, but from the book)
sconj_index = selected_book.iloc[last_sconj_index]["sentence"].find(sconj)

# Figure out which type of clause comes first
if sconj not in sentence[:sconj_index]:
    independent_clause = sentence[:sconj_index]
    subordinate_clause = sentence[sconj_index:]
else:
    independent_clause = sentence[sconj_index:]
    subordinate_clause = sentence[:sconj_index]

# Index of the independent clause in the sentence (not the input sequence, but from the book)
independent_clause_index = selected_book.iloc[last_sconj_index]["sentence"].find(
    independent_clause
)

# Index of the subordinate clause in the sentence (not the input sequence, but from the book)
subordinate_clause_index = selected_book.iloc[last_sconj_index]["sentence"].find(
    subordinate_clause
)

# Tokenize the independent clause
independent_clause_tokenized = tokenizer(
    independent_clause, return_tensors="pt"
).to(device)

# Tokenize the subordinate clause
subordinate_clause_tokenized = tokenizer(
    subordinate_clause, return_tensors="pt"
).to(device)

Given that we have fed in a randomly shuffled context, let’s take a look at the top k probability spectrum at the subordinating conjunction:

prob_spectrum = probability_spectrum_at(
    model_logits, prompt_input_ids, sconj_token_index
)

prob_spectrum

Let’s also look at the cross-entropy and perplexity of the sentence with the subordinating conjunction:

#| code-fold: true

# Cross-entropy of the sentence with the subordinating conjunction (entire input sequence)
sentence_per_token_cross_entropy = cross_entropy_per_token(
    model_logits, prompt_input_ids, sentence_tokenized
)

# Mean cross-entropy of the sentence with the subordinating conjunction (entire input sequence)
mean_sentence_cross_entropy = sentence_per_token_cross_entropy["cross_entropy"].mean()

# Perplexity of the sentence with the subordinating conjunction (entire input sequence)
sentence_perplexity = np.exp(mean_sentence_cross_entropy)

# Cross-entropy of the independent clause
independent_clause_per_token_cross_entropy = cross_entropy_per_token(
    model_logits, prompt_input_ids, independent_clause_tokenized
)

# Mean cross-entropy of the independent clause
mean_independent_clause_cross_entropy = independent_clause_per_token_cross_entropy[
    "cross_entropy"
].mean()

# Perplexity of the independent clause
independent_clause_perplexity = np.exp(mean_independent_clause_cross_entropy)

# Cross-entropy of the subordinate clause
subordinate_clause_per_token_cross_entropy = cross_entropy_per_token(
    model_logits, prompt_input_ids, subordinate_clause_tokenized
)

# Mean cross-entropy of the subordinate clause
mean_subordinate_clause_cross_entropy = subordinate_clause_per_token_cross_entropy[
    "cross_entropy"
].mean()

# Perplexity of the subordinate clause
subordinate_clause_perplexity = np.exp(mean_subordinate_clause_cross_entropy)

# Cross-entropy of the subordinating conjunction
subordinating_conjunction_cross_entropy = cross_entropy_at(
    model_logits, prompt_input_ids, sconj_token_index
)

# Perplexity of the subordinating conjunction
subordinating_conjunction_perplexity = np.exp(subordinating_conjunction_cross_entropy)

#| echo: false

# Print in the order of the sentence structure
if independent_clause_index > subordinate_clause_index:
    print("Structure:")
    print("\t<sentence-with-SCONJ> ::= <subordinate-clause> <independent-clause>")
    print("\nMetrics:")
    print(f"\t<sentence-with-SCONJ> cross-entropy: {mean_sentence_cross_entropy}")
    print(f"\t<sentence-with-SCONJ> perplexity:    {sentence_perplexity}\n")
    print(f"\t<subordinate-clause> cross-entropy:  {mean_subordinate_clause_cross_entropy}")
    print(f"\t<subordinate-clause> perplexity:     {subordinate_clause_perplexity}")
    print(f"\t<SCONJ> cross-entropy:               {subordinating_conjunction_cross_entropy}")
    print(f"\t<SCONJ> perplexity:                  {subordinating_conjunction_perplexity}")
    print(f"\t<independent-clause> cross-entropy:  {mean_independent_clause_cross_entropy}")
    print(f"\t<independent-clause> perplexity:     {independent_clause_perplexity}")
else:
    print("Structure:")
    print("\t<sentence-with-SCONJ> ::= <independent-clause> <subordinate-clause>")
    print("\nMetrics:")
    print(f"\t<sentence-with-SCONJ> cross-entropy: {mean_sentence_cross_entropy}")
    print(f"\t<sentence-with-SCONJ> perplexity:    {sentence_perplexity}\n")
    print(f"\t<independent-clause> cross-entropy:  {mean_independent_clause_cross_entropy}")
    print(f"\t<independent-clause> perplexity:     {independent_clause_perplexity}")
    print(f"\t<SCONJ> cross-entropy:               {subordinating_conjunction_cross_entropy}")
    print(f"\t<SCONJ> perplexity:                  {subordinating_conjunction_perplexity}")
    print(f"\t<subordinate-clause> cross-entropy:  {mean_subordinate_clause_cross_entropy}")
    print(f"\t<subordinate-clause> perplexity:     {subordinate_clause_perplexity}")
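
The relationship between the two metrics reported above can be checked on a toy example. The following is a minimal sketch, not the code used in this post: it assumes cross-entropy is measured in nats (as the use of np.exp implies), and the token probabilities are made-up values rather than model outputs.

```python
import math


def mean_cross_entropy(token_probs):
    """Mean negative log-probability (in nats) of the observed tokens."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def perplexity(token_probs):
    """Perplexity is the exponential of the mean cross-entropy."""
    return math.exp(mean_cross_entropy(token_probs))


# Hypothetical probabilities a model assigned to each observed token
uncertain_clause = [0.25, 0.25, 0.25, 0.25]
confident_clause = [0.9, 0.8, 0.95]

print(f"uncertain clause perplexity: {perplexity(uncertain_clause):.3f}")
print(f"confident clause perplexity: {perplexity(confident_clause):.3f}")
```

A clause whose tokens each receive probability 0.25 has a perplexity of about 4, which can be read as the model choosing among roughly four equally likely tokens at each step. Higher token probabilities push the perplexity toward 1.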

Results and Conclusion

Our analysis shows that the cross-entropy and perplexity of a sentence containing a subordinating conjunction change with the context provided to the model, and so does the probability the model assigns to the subordinating conjunction itself. In other words, context influences how likely the model is to predict a subordinating conjunction. We also observed that, when the context changed, the cross-entropy and perplexity of the subordinate clause changed less than those of the independent clause. Although this warrants a more thorough investigation, it suggests that certain subordinating conjunctions may be more useful than others, even to the LLM, since the model remained likely to construct the same subordinate clause despite the contextual change.

Limitations

Our work is not without limitations. First, we analyzed the LLM on sentences from only one book, which is not representative of the range of writing contexts. Second, our approach can currently parse only subordinate clauses whose subordinating conjunction sits in the middle of the sentence, so there are edge cases related to the position of the subordinating conjunction that we have not considered.
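
To make the positional edge case concrete, here is a small illustrative sketch. It is not the parser used in this post: the function name sconj_position and the hand-tagged sentences are hypothetical, and the only assumption carried over is the Universal POS tag SCONJ that spaCy assigns to subordinating conjunctions.

```python
def sconj_position(tagged_tokens):
    """Classify where the first SCONJ falls: 'initial', 'medial', or None."""
    for index, (_, pos) in enumerate(tagged_tokens):
        if pos == "SCONJ":
            return "initial" if index == 0 else "medial"
    return None


# Hand-tagged examples (hypothetical; a real pipeline would obtain tags from spaCy)
medial_sentence = [
    ("I", "PRON"), ("stayed", "VERB"), ("because", "SCONJ"),
    ("it", "PRON"), ("rained", "VERB"),
]
initial_sentence = [
    ("Because", "SCONJ"), ("it", "PRON"), ("rained", "VERB"),
    (",", "PUNCT"), ("I", "PRON"), ("stayed", "VERB"),
]

print(sconj_position(medial_sentence))
print(sconj_position(initial_sentence))
```

A sentence-initial conjunction has no preceding clause to condition on, which is one reason it may need separate handling from the medial case.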

Future Work

A natural extension of this work is to evaluate the LLM’s likelihood of predicting subordinating conjunctions on a more diverse and representative sample of data, including other writing contexts such as academic writing rather than books. We could also evaluate the LLM’s likelihood of predicting other kinds of conjunctions, such as coordinating conjunctions. Another direction is to study the relationship between the LLM’s hyperparameters and its likelihood of predicting subordinating conjunctions.

Appendix

AutoAWQ Quantization

This section documents how we quantized the Llama-2-7b-chat-hf model to 4-bit precision using AutoAWQ. Quantization reduces the computational resources required to run inference on the model while still maintaining a high level of accuracy.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, AwqConfig
model_name = "meta-llama/Llama-2-7b-chat-hf"
quantized_model_path = "Llama-2-7b-chat-hf-awq"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoAWQForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
# Set up the AutoAWQ quantization configuration
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
}

# Quantize the model
model.quantize(tokenizer, quant_config=quant_config)
# Set up a Transformers-compatible quantization configuration
quantization_config = AwqConfig(
    bits=quant_config["w_bit"],
    group_size=quant_config["q_group_size"],
    zero_point=quant_config["zero_point"],
    version=quant_config["version"].lower(),
).to_dict()

# Pass the new quantization configuration to the model
model.model.config.quantization_config = quantization_config

# Save the quantized model weights
tokenizer.save_pretrained(quantized_model_path)
model.save_quantized(quantized_model_path)

To promote reproducibility of this work, we have uploaded our quantized model to the Hugging Face Hub. You can access it here: CalvinU/Llama-2-7b-chat-hf-awq.
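
For readers who want to reuse the uploaded weights, the following is a minimal sketch of loading them with Transformers. It assumes a CUDA-capable GPU and that the transformers, autoawq, and accelerate packages are installed; the helper name load_quantized_model is hypothetical and is not part of any library.

```python
def load_quantized_model(repo_id: str = "CalvinU/Llama-2-7b-chat-hf-awq"):
    """Load the AWQ-quantized model and its tokenizer from the Hugging Face Hub."""
    # Imported lazily so that defining this helper does not require the libraries
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # device_map="auto" (assumes accelerate) places the weights on the available GPU(s)
    model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto")
    return tokenizer, model
```

Calling tokenizer, model = load_quantized_model() downloads the weights on first use. Transformers reads the quantization settings from the saved model configuration, so no separate AwqConfig should be needed at load time.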

Scalable Data Collection

In the EDA, we looked at only one book. A language modeling task, however, would likely need a sample of data that is diverse and representative of different kinds of writing. This section documents a scalable approach to collecting and processing data from Project Gutenberg.

import pandas as pd
import requests
import random
import spacy

nlp = spacy.load("en_core_web_sm")

Here are the functions we have used. All of them were already defined in the EDA section, except for download_books, which is a wrapper function for download_book that downloads multiple books instead of just one:

def download_book(book_id: int) -> tuple[str, str]:
    """Download a book from Project Gutenberg

    Arg:
        book_id: The Project Gutenberg ID of the book to download

    Returns:
        A tuple containing the book title and the book text
    """

    gutendex_url = f"https://gutendex.com/books/{book_id}/"

    try:
        response = requests.get(gutendex_url)
        response.raise_for_status()
        data = response.json()

        book_language = data["languages"]

        # Only download books in English
        if "en" in book_language:
            book_title = data["title"]

            # Only download books in plain text
            mime_types = ["text/plain", "text/plain; charset=us-ascii"]

            # Default to None so the check below works when no format matches
            book_url = None

            for mime_type in mime_types:
                if mime_type in data["formats"]:
                    book_url = data["formats"][mime_type]
                    break

            if book_url is None:
                raise Exception("The book is not available in plain text.")

            response = requests.get(book_url)
            response.raise_for_status()

            return book_title, response.text
        else:
            raise Exception("The book is not in English.")
    except requests.exceptions.HTTPError as err:
        raise Exception(err)


def download_books(n: int) -> list[tuple[int, str, str]]:
    """Download n books from Project Gutenberg

    Arg:
        n: The number of books to download

    Returns:
        A list of downloaded books
    """

    max_book_count = requests.get("https://gutendex.com/books/").json()["count"]

    books = []

    i = 0
    while i < n:
        book_id = random.randint(1, max_book_count)

        try:
            book_title, book_text = download_book(book_id)
            books.append((book_id, book_title, book_text))
            i += 1
        except Exception:
            # Skip IDs that are missing, not in English, or not in plain text
            continue

    return books


def sanitize_text(text: str) -> str:
    """Remove extra information from the text

    Arg:
        text: The text to sanitize

    Returns:
        The sanitized text
    """

    start_marker = "***"
    end_marker = "*** END OF THE PROJECT GUTENBERG EBOOK"

    # Index of the second occurrence of the start marker
    start_index = text.find(start_marker, text.find(start_marker) + 1)

    # Index of the first occurrence of the end marker
    end_index = text.find(end_marker)

    # Remove the extra information based on the marker indices
    if start_index != -1 and end_index != -1:
        text = text[start_index + len(start_marker) : end_index].strip()

    return text


def sentence_spliter(text: str) -> list[str]:
    """Split the text into sentences

    Arg:
        text: The text to split

    Returns:
        A list of sentences
    """

    nlp.max_length = len(text)

    pipe_disable = ["ner", "lemmatizer", "textcat"]

    # Remove line breaks and split the text into sentences
    docs = nlp.pipe([text.replace("\r\n", " ")], disable=pipe_disable)

    # Return a list of sentences without leading and trailing whitespace
    return [sent.text.strip() for doc in docs for sent in doc.sents]

Download 10 random books from Project Gutenberg:

n_books = 10

books10 = pd.DataFrame(
    download_books(n_books), 
    columns=["book_id", "title", "text"]
)

assert len(books10) == n_books
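
Because download_books keeps drawing random IDs until n downloads succeed, it can loop for a long time when many IDs are unavailable, and it may return the same book twice. The sketch below shows one possible variant that caps the number of attempts and skips repeated IDs. The name download_books_bounded, its parameters, and the stand-in fake_download are all hypothetical and are not part of the original pipeline; the downloader is passed in so the example runs without network access.

```python
import random
from typing import Callable, Optional


def download_books_bounded(
    n: int,
    download: Callable[[int], tuple[str, str]],
    max_book_id: int,
    max_attempts: int = 100,
    seed: Optional[int] = None,
) -> list[tuple[int, str, str]]:
    """Collect up to n distinct books, giving up after max_attempts draws."""
    rng = random.Random(seed)
    books = []
    seen_ids = set()
    attempts = 0

    while len(books) < n and attempts < max_attempts:
        attempts += 1
        book_id = rng.randint(1, max_book_id)

        # Skip IDs that were already tried
        if book_id in seen_ids:
            continue
        seen_ids.add(book_id)

        try:
            title, text = download(book_id)
        except Exception:
            # Unavailable book: move on to the next draw
            continue

        books.append((book_id, title, text))

    return books


def fake_download(book_id: int) -> tuple[str, str]:
    """Stand-in for download_book that pretends odd IDs are unavailable."""
    if book_id % 2:
        raise Exception("unavailable")
    return f"Book {book_id}", "text"


books = download_books_bounded(3, fake_download, max_book_id=50, seed=0)
print([book_id for book_id, _, _ in books])
```

With the real downloader, the call would be download_books_bounded(n_books, download_book, max_book_count). Because the function may return fewer than n books, the caller should check the length of the result.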

Clean the texts:

books10["clean_text"] = books10["text"].apply(sanitize_text)
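
The marker logic in sanitize_text can be traced on a tiny synthetic text. This sketch repeats the same find calls on a made-up sample: the second occurrence of "***" closes the START line, so slicing from just after it up to the END marker leaves only the body. The sample text is invented for illustration and assumes the header and footer follow the usual Project Gutenberg layout.

```python
sample = (
    "Header text\n"
    "*** START OF THE PROJECT GUTENBERG EBOOK EXAMPLE ***\n"
    "Body of the book.\n"
    "*** END OF THE PROJECT GUTENBERG EBOOK EXAMPLE ***\n"
    "License text\n"
)

start_marker = "***"
end_marker = "*** END OF THE PROJECT GUTENBERG EBOOK"

# The first "***" opens the START line; the second one closes it
first_index = sample.find(start_marker)
second_index = sample.find(start_marker, first_index + 1)

# The END marker is where the footer begins
end_index = sample.find(end_marker)

body = sample[second_index + len(start_marker) : end_index].strip()
print(body)
```

If either marker is missing, find returns -1, which is why sanitize_text leaves the text unchanged in that case.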

Split the texts into sentences:

books10_sentences = []

# For each book, split the text into sentences
for i in range(0, len(books10)):
    books10_sentences.append(
        (
            books10["book_id"].iloc[i],
            books10["title"].iloc[i],
            sentence_spliter(books10["clean_text"].iloc[i]),
        )
    )

Create a new DataFrame with the sentences:

# For each sentence in each book, create a new row
books10_sentences = [
    (book_id, title, sent)
    for book_id, title, sents in books10_sentences
    for sent in sents
]

books10_sentences = pd.DataFrame(
    books10_sentences, columns=["book_id", "title", "sentence"]
)

To promote reproducibility of this work, we have saved the data collected and processed with this approach as a Parquet file. You can view and access our dataset here: CalvinU/project-gutenberg.


References

[1]
Kenneth C. Arnold, Krysta Chauncey, and Krzysztof Z. Gajos. 2018. Sentiment bias in predictive text recommendations results in biased writing. In Proceedings of the 44th graphics interface conference (GI ’18), 2018. Association for Computing Machinery, Toronto, Ontario, Canada, 42–49. Retrieved from https://doi.org/10.20380/GI2018.07
[2]
Kenneth C. Arnold, Krysta Chauncey, and Krzysztof Z. Gajos. 2020. Predictive text encourages predictable writing. In Proceedings of the 25th international conference on intelligent user interfaces (IUI ’20), 2020. Association for Computing Machinery, New York, NY, USA, 128–138. Retrieved from https://doi.org/10.1145/3377325.3377523
[3]
Linda Flower and John R. Hayes. 1981. A cognitive process theory of writing. College Composition and Communication 32, 4 (1981), 365–387.
[4]
Maurice Jakesch, Advait Bhat, Daniel Buschek, Lior Zalmanson, and Mor Naaman. 2023. Co-writing with opinionated language models affects users’ views. In Proceedings of the 2023 CHI conference on human factors in computing systems (CHI ’23), 2023. Association for Computing Machinery, New York, NY, USA, 1–15. Retrieved from https://doi.org/10.1145/3544548.3581196
[5]
Hyechan Jun, Ha-Ram Koo, and Advait Scaria. 2021. DAISI: The deep artificial intelligence system for interviews. 2021. Retrieved from https://haramkoo.github.io/InterviewAI/

Footnotes

  1. https://www.merriam-webster.com/dictionary/inspire

  2. Chosen for its simplicity; other quantization methods, such as llama.cpp, will likely also work.