This lab is optional and is intended to expose you to a few basic implementations of recurrent neural networks in PyTorch.
The sequential data used throughout the lab are surnames (last names) sampled from individuals of six different nationalities.
To begin, you'll need to import the following libraries:
import torch
import torchvision
import string
Next, you'll need to download and extract the data in the zipped folder below:
This folder contains 6 different text files, one per nationality, that each contain the unique surnames observed among the members of that nationality (with one name per line).
You should be able to modify the root path and use the code given below to load these text files into Python lists:
root = 'C:/Users/millerry/OneDrive - Grinnell College/Documents/surnames/'
Chinese = open(root+'Chinese.txt', encoding='utf-8').read().strip().split('\n')
Japanese = open(root+'Japanese.txt', encoding='utf-8').read().strip().split('\n')
Korean = open(root+'Korean.txt', encoding='utf-8').read().strip().split('\n')
English = open(root+'English.txt', encoding='utf-8').read().strip().split('\n')
Irish = open(root+'Irish.txt', encoding='utf-8').read().strip().split('\n')
Russian = open(root+'Russian.txt', encoding='utf-8').read().strip().split('\n')
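As an optional sanity check, you can confirm that the data loaded correctly by printing the number of surnames in each list along with a few examples; the exact counts will depend on the files you downloaded:
## Optional check: how many surnames were loaded for each nationality?
for label, names in [('Chinese', Chinese), ('Japanese', Japanese), ('Korean', Korean),
                     ('English', English), ('Irish', Irish), ('Russian', Russian)]:
    print(label, len(names), names[:3])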
Recurrent neural networks are designed to work with sequential data, and our models throughout the lab will treat each character in a name as a sequential observation. This framework requires us to represent the individual characters in a name using one-hot vectors.
To facilitate this process, we'll start by defining a helper function that converts a single line of text (ie: a name) into a tensor composed of one-hot vectors representing each character.
## We'll consider all ascii letters plus basic punctuation
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
## Function to iterate through a line of text and encode each letter as a 1 x 57 vector in an nchar x 1 x 57 tensor
def nameToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor
The code below demonstrates the behavior of this function on a simple example, the line "Aa". Notice how the output is a tensor with dimensions [2, 1, 57].
## Demonstration of the test name "Aa", notice the "A" is encoded as the 27th position, and "a" is the 1st position
example = nameToTensor('Aa')
print(example)
## Notice the dimensions of the output and where 1's appear
print(example.size())
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]) torch.Size([2, 1, 57])
Because data preparation is a crucial step in using models like RNNs, the question below will check your understanding of the tensors produced by nameToTensor():
Question 1:
Part A: What is represented by the first dimension of the tensor returned by the nameToTensor() function? Will the size/length of this dimension change if a different input text is used?
Part B: What is represented by the last dimension of the tensor returned by the nameToTensor() function? Will the size/length of this dimension change if a different input text is provided?
Next, we'll define a simple network architecture to model our sequential data. You should notice that this network architecture is flexible enough to handle inputs of different sizes (since each surname is a different length).
from torch import nn
class my_rnn(nn.Module):
    ## Constructor commands
    def __init__(self, input_size, hidden_size, output_size):
        super(my_rnn, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    ## Function to generate predictions
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
To understand this architecture, it's easiest to look at the forward method:
- The current input and the current hidden state are joined into a single tensor, combined, by concatenating them along dim = 1.
- This combined tensor is passed through the linear layer i2h, which produces the next hidden state.
- The new hidden state is passed through the linear layer i2o, which produces an output (converted to log-probabilities by the LogSoftmax layer).
In this simple architecture the only thing we might consider manipulating is the size of the hidden state (since the input size is out of our control). Increasing the hidden size will provide the model with more flexibility to learn sequential patterns that exist within the training sequences (names), but too much flexibility could lead to overfitting.
To explore how this model works, let's initialize it with randomly generated weights and see what it outputs for an example name:
## Initialize model with random weights
h_size = 100
rnn = my_rnn(n_letters, h_size, 6)
## Format an example input name (Albert)
test_input = nameToTensor('Albert')
## Provide an initial hidden state (all zeros as an example)
hidden = torch.zeros(1, h_size)
## Generate output from the RNN
output, next_hidden = rnn(test_input[0], hidden)
print(output)
print(next_hidden)
tensor([[-1.7671, -1.7803, -1.8638, -1.7610, -1.7642, -1.8184]], grad_fn=<LogSoftmaxBackward0>) tensor([[ 0.0279, 0.0997, 0.0669, -0.0740, 0.0734, -0.0247, 0.0918, -0.0393, -0.0519, -0.0655, 0.0390, 0.0533, 0.0768, -0.0368, 0.1451, -0.1040, 0.0086, -0.0298, -0.0467, 0.0262, -0.0342, 0.1191, 0.0169, -0.0956, -0.0450, -0.0689, -0.1303, -0.0004, -0.0771, 0.0377, -0.0166, 0.0728, 0.1355, -0.0284, 0.0032, 0.0716, -0.0305, -0.0089, 0.0056, 0.0286, 0.1051, -0.0363, 0.0083, 0.0122, -0.0552, 0.0127, -0.1028, 0.0086, -0.0755, -0.0582, -0.0622, -0.0276, 0.0176, 0.0867, -0.0195, 0.0708, 0.1206, -0.0892, 0.1336, -0.0294, 0.1311, -0.0191, -0.1166, 0.0508, -0.0424, 0.1239, 0.0770, -0.0385, -0.0518, 0.0393, -0.0180, -0.0700, 0.0412, -0.0844, 0.0238, 0.0569, 0.0029, -0.0004, 0.0342, 0.0089, -0.0638, 0.0079, 0.0153, 0.1120, 0.0331, -0.0466, 0.1128, 0.1140, -0.1246, 0.0625, -0.0129, 0.0414, 0.1342, -0.0621, -0.1112, -0.0609, -0.0932, 0.0053, -0.0452, 0.0416]], grad_fn=<AddmmBackward0>)
For illustration purposes, we can look at the predicted class of this output:
## Print the top category (predicted class)
output.topk(1)
torch.return_types.topk( values=tensor([[-1.7610]], grad_fn=<TopkBackward0>), indices=tensor([[3]]))
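Note that the call above only processed the first letter of "Albert" (test_input[0]). As a quick sketch of how the recurrence works, we can carry the hidden state across every letter of the name; the output after the final letter is the untrained model's prediction for the entire name (your exact values will differ since the weights are random):
## Carry the hidden state across every letter of "Albert"
h = torch.zeros(1, h_size)
for i in range(test_input.size()[0]):
    out_full, h = rnn(test_input[i], h)
print(out_full)   ## log-probabilities over the 6 categories after seeing the whole name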
Question 2:
Part A: The my_rnn network was initialized with a value of 6 for output_size. Where did this value come from? Is it something you can change when tuning the network's architecture? Briefly explain.
Part B: The my_rnn network was initialized with a value of 100 for hidden_size. Where did this value come from? Is it something you can change when tuning the network's architecture? Briefly explain.
To facilitate the training of our model, we'll define the category labels and a dictionary that links our lists of names to each label:
## List of categories
category_labels = ['Chinese', 'Japanese', 'Korean', 'English', 'Irish', 'Russian']
## Dictionary of categories and names
category_lines = {'Chinese': Chinese,
'Japanese': Japanese,
'Korean': Korean,
'English': English,
'Irish': Irish,
'Russian': Russian}
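With these labels defined, the index returned by output.topk(1) above can be mapped to a readable nationality. For example, using the untrained output from earlier (where the top index was 3):
## Map the top predicted index to its category label
top_value, top_index = output.topk(1)
print(category_labels[top_index.item()])   ## index 3 corresponds to 'English' (untrained, so essentially arbitrary)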
Next, we'll train our network by feeding it randomly selected names and then updating the network's weights and biases using back-propagation.
The function defined below will facilitate the selection of randomly chosen input names during model training:
## Function to randomly sample a single example
import random
def randomTrainingExample():
    ## Randomly choose a category (ie: Chinese, etc.)
    category = category_labels[random.randint(0, len(category_labels)-1)]
    ## Randomly choose a name in that category
    name = category_lines[category][random.randint(0, len(category_lines[category])-1)]
    ## Convert the chosen example to tensors
    category_tensor = torch.tensor([category_labels.index(category)], dtype=torch.long)
    line_tensor = nameToTensor(name)
    return category, name, category_tensor, line_tensor
## Try it out
randomTrainingExample()
('Russian', 'Zhvanetsky', tensor([5]), tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]))
Now we'll set up a function, train()
, that updates the network's parameters after being given a randomly selected training example:
## Set learning rate
learning_rate = 0.005
## Define cost func
cost_fn = nn.CrossEntropyLoss()
## Training function for a single input (name category, name)
def train(category_tensor, line_tensor):
    ## Initialize the hidden state
    hidden = rnn.initHidden()
    ## Set the gradients to zero
    rnn.zero_grad()
    ## Loop through the letters in the input, getting a prediction and new hidden state each time
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    ## Calculate the cost and gradients
    cost = cost_fn(output, category_tensor)
    cost.backward()
    ## Update parameters
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha = -learning_rate) ## This subtracts the learning rate times the gradient from each parameter
    ## Return the output and cost
    return output, cost.item()
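Before launching a full training loop, you may want to confirm that train() runs on a single example; a quick sketch is given below (your cost will differ since the example is chosen at random):
## Run one training step on a single randomly chosen example
category, name, category_tensor, line_tensor = randomTrainingExample()
output, cost = train(category_tensor, line_tensor)
print(category, name, round(cost, 4))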
Question 3: The train() function defined above involves a for loop that iterates through the first dimension of line_tensor (which is the tensor storing the input name). Notice that hidden is initialized (ie: reset) every time train() is called on a new training example. Why is it important to do this when training the model?
We are now ready to train the model. To do so, we'll use 10,000 randomly selected names and record the model's average cost every 25 iterations.
## Initializations
n_iters = 10000
cost_every_n = 25
current_cost = 0
track_cost = []
### Iteratively update model from randomly chosen example
for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, cost = train(category_tensor, line_tensor)
    current_cost += cost
    # Save cost every 25 iterations
    if iter % cost_every_n == 0:
        track_cost.append(current_cost/cost_every_n)
        current_cost = 0
Next, we'll graph the costs throughout the training process to see if our model has learned anything from our training examples:
import matplotlib.pyplot as plt
plt.plot(track_cost)
plt.show()
We can see that the model appears to have learned some patterns contained in surnames, but there's quite a bit of variability in the cost from one set of 25 names to the next.
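One way to see the overall trend more clearly is to smooth the tracked costs, for example with a simple moving average; the sketch below assumes NumPy is available and uses a window of 10 recorded values:
## Smooth the recorded costs with a 10-point moving average to reduce noise between batches of 25 names
import numpy as np
window = 10
smoothed = np.convolve(track_cost, np.ones(window)/window, mode='valid')
plt.plot(smoothed)
plt.show()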
The RNN we built and trained is designed to predict the labels of input sequences of characters. This means that we can give the trained model any valid sequence of characters and it will predict the nationality it believes that name belongs to.
To see this in action, we'll create a predict function that prints the top N predicted labels (and their associated outputs) for a given input name:
def predict(input_line, n_predictions=4):
    print('\n> %s' % input_line)
    ## Don't update gradients with any of these examples
    with torch.no_grad():
        ## Initialize new hidden state
        hidden = rnn.initHidden()
        ## Convert input str to tensor
        input_t = nameToTensor(input_line)
        ## Pass each character into `rnn`
        for i in range(input_t.size()[0]):
            output, hidden = rnn(input_t[i], hidden)
        ## Get top N categories from output
        topv, topi = output.topk(n_predictions, 1, True)
        predictions = []
        ## Go through the category predictions and save info for printing
        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, category_labels[category_index]))
            predictions.append([value, category_labels[category_index]])
## Try it out on a few examples:
predict('Dovesky')
predict('Miller')
predict('Satoshi')
predict('ABCDEFGHIJKLMNOP')
> Dovesky
(-0.21) Russian
(-2.11) English
(-3.35) Japanese
(-3.49) Irish

> Miller
(-0.45) English
(-1.70) Irish
(-1.78) Russian
(-4.28) Japanese

> Satoshi
(-0.04) Japanese
(-3.78) Russian
(-4.80) Irish
(-5.52) English

> ABCDEFGHIJKLMNOP
(-0.66) Russian
(-1.26) English
(-2.11) Japanese
(-3.40) Irish
Question 4:
Part A: Use the predict function on 1 or 2 names of your choosing. Include your code and output, and write 1-2 sentences reflecting upon whether you are satisfied or surprised by the results.
Part B: Modify the hidden_size of the network and retrain it using more than 10,000 training iterations. Try out the names you considered in Part A and briefly reflect upon whether the results appear to be better, worse, or roughly the same.
This section provides a brief illustration of a simple generative RNN. The network will be trained using the same surnames data that we've been working with, and it will be set up to generate a predicted name when given an initial string of characters.
For our previous model, we prepared our data using one-hot encoding to represent each unique letter. This time, we'll add an extra position that does not correspond to any letter; it functions as a "stop character" that signals the model to stop generating new characters:
## Set up number of category labels and number of letters (plus 1 for a stop char)
n_categories = len(category_labels)
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1
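A quick check of these constants shows where the stop character lives: it occupies the final index (n_letters - 1), which does not correspond to any entry in all_letters:
## The stop character is the extra final position in the encoding
print(len(all_letters))   ## 58 real characters (52 letters plus 6 punctuation symbols)
print(n_letters)          ## 59 positions once the stop character is included
print(n_letters - 1)      ## 58, the index reserved for the stop character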
Next we'll define the model's architecture, which is somewhat more complicated than our previous example. This added complexity should allow it to learn more patterns from the training data.
from torch import nn
class my_gen_rnn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(my_gen_rnn, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
This model involves three key components that are explained below:
- i2h takes a combined input tensor containing the category, the current input character, and the current hidden state, and outputs a new hidden state.
- i2o takes the same combined input as i2h, but it produces an intermediate output that will ultimately contribute to a new predicted character.
- o2o is an extra layer that takes the combined outputs of i2h and i2o to generate a predicted character.
The recurrent structure of the network can be more easily understood using the diagram below:
from IPython.display import HTML
HTML('<img src="https://i.imgur.com/jzVrf7f.png">')
For each training example we'll need a set of input letters (the complete surname), a set of output letters (the surname offset by 1), and the category label (nationality).
For example, if the name is "Kasparov", the input letters would be a one-hot representation of the letters in "Kasparov", and the output letters would be "asparov" followed by the stop character (stored as letter indices rather than one-hot vectors).
The functions defined below will create the input letters, output letters, and category tensor for a given name:
def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor

def outputTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1)
    return torch.LongTensor(letter_indexes)

def categoryTensor(category):
    li = category_labels.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor
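To make the offset-by-one structure concrete, here's a small sketch applying these helpers to the hypothetical input name "Kim" in the 'Korean' category; the target tensor should hold the indices of "im" followed by the stop character's index:
## Example: the input is "Kim" and the target is "im" plus the stop character
print(categoryTensor('Korean'))     ## one-hot vector over the 6 categories ('Korean' is index 2)
print(inputTensor('Kim').size())    ## torch.Size([3, 1, 59]), one one-hot vector per letter
print(outputTensor('Kim'))          ## tensor([ 8, 12, 58]): indices of 'i', 'm', and the stop character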
Similar to before, we will also define a couple of functions to help us select random examples during training:
# Random item from a list
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

# Get a random category and random line from that category
def randomTrainingPair():
    category = randomChoice(category_labels)
    line = randomChoice(category_lines[category])
    return category, line

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = outputTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor
To get a basic understanding of these functions, we can consider a randomly chosen training sample:
## Try it out
randomTrainingExample()
(tensor([[1., 0., 0., 0., 0., 0.]]), tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), tensor([ 0, 13, 6, 58]))
In order, these three tensor outputs are:
- the category tensor (a one-hot encoding of the name's nationality),
- the input tensor (a one-hot encoding of each letter in the name), and
- the target tensor (the index of each letter in the name, offset by one position, followed by the index of the stop character).
Next we'll define a function that will generate an output name from the network when given an initial character. We'll arbitrarily limit the generated name to a maximum length of 15 characters, and we'll use the network architecture we previously defined (with randomly initialized weights) to generate a name:
max_length = 15
gen_rnn = my_gen_rnn(n_letters, 128, n_letters)
# Sample using a given category and starting letter
def sample(category, start_letter):
    ## We are just sampling, so we don't want to store info used in gradient calculations
    with torch.no_grad():
        category_tensor = categoryTensor(category)  ## create a category tensor for the input category
        input = inputTensor(start_letter)           ## initialize the input tensor as an encoding of the start letter
        hidden = gen_rnn.initHidden()               ## reset the initial hidden state
        output_name = start_letter                  ## use the start letter as the first piece of the output name
        ## Loop until reaching the max length or the stop character
        for i in range(max_length):
            output, hidden = gen_rnn(category_tensor, input[0], hidden)  ## get the next output and hidden state
            topv, topi = output.topk(1)             ## identify the top predicted character's value and index position
            topi = topi[0][0]                       ## extract the integer id of the predicted character
            if topi == n_letters - 1:               ## stop if it's the stop character's id
                break
            else:
                letter = all_letters[topi]          ## convert the integer id to the character
                output_name += letter               ## add this character to the output
                input = inputTensor(letter)         ## prep this letter as the next input
        return output_name
We can see this function in action by providing a valid category label and initial character:
sample('English', 'B')
'BqqqqqqqqqqSqqSq'
Question 5: Why does the sample() function stop generating new characters when topi == n_letters - 1? Briefly explain.
Similar to our previous example, we'll create a function that we can use to help train our network:
cost_fn = nn.CrossEntropyLoss()
gen_rnn = my_gen_rnn(n_letters, 256, n_letters)
learning_rate = 0.001
def train(category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor.unsqueeze_(-1)
    hidden = gen_rnn.initHidden()
    gen_rnn.zero_grad()
    cost = 0
    ## Accumulate the cost across every character position in the name
    for i in range(input_line_tensor.size(0)):
        output, hidden = gen_rnn(category_tensor, input_line_tensor[i], hidden)
        l = cost_fn(output, target_line_tensor[i])
        cost += l
    cost.backward()
    ## Update parameters
    for p in gen_rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)
    ## Return the final output and the average cost per character
    return output, cost.item() / input_line_tensor.size(0)
We'll again train the network, this time using 20,000 randomly selected training examples:
n_iters = 20000
cost_every_n = 25
current_cost = 0
track_cost = []
for iter in range(1, n_iters + 1):
    cat, il, ol = randomTrainingExample()
    if -1 in ol: ### If an example happens to contain an unusual character we'll skip it
        continue
    output, cost = train(cat, il, ol)
    current_cost += cost
    # Save the cost every 25 iterations
    if iter % cost_every_n == 0:
        track_cost.append(current_cost/cost_every_n)
        current_cost = 0
As shown below, we can see that the network's parameters have reached a point where the cost is no longer improving:
import matplotlib.pyplot as plt
plt.plot(track_cost)
plt.show()
At this point, we can use the sample
function we created earlier to explore some of the names we can generate. The code below provides a template for looking at various names that are generated from some seed characters:
test_letter = 'Br'
print('Korean:',sample('Korean', test_letter),
'\nJapanese:', sample('Japanese', test_letter),
'\nChinese:', sample('Chinese', test_letter),
'\nEnglish:', sample('English', test_letter),
'\nIrish:', sample('Irish', test_letter),
'\nRussian:', sample('Russian', test_letter))
Korean: Bron
Japanese: Braka
Chinese: Bran
English: Brare
Irish: Branghan
Russian: Braranov
You can re-run the same commands several times and see slightly different results, since the dropout layer used to create the network's output remains active when we sample.
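If you'd prefer repeatable output for a given seed string, one option is to switch the network into evaluation mode, which disables the dropout layer (remember to switch back if you plan to train further):
## Disable dropout to get repeatable samples for the same category and seed string
gen_rnn.eval()
print(sample('Russian', 'Br'))
print(sample('Russian', 'Br'))   ## identical to the line above now that dropout is off
gen_rnn.train()                  ## return to training mode before any further weight updates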
Question 6: To verify that you've explored training this model and using it to generate output, modify the print command given above to make test_letter a seed string of your choice. Include your printed results, and write 1-2 sentences commenting upon how you view the effectiveness of this model.
Acknowledgements: The contents of this lab were adapted from the following tutorials: