Show/Hide the code
1
2
3
4
from fastai.collab import *
from fastai.tabular.all import *

# Fetch and extract the MovieLens 100k archive; `path` is the extracted folder
# that the data files below (u.data, u.item) are read from.
path = untar_data(URLs.ML_100k)
Show/Hide the code
1
2
3
4
5
6
7
# u.data is tab-separated and has no header row, so the column names
# are supplied explicitly.
ratings = pd.read_csv(
    path / "u.data",
    sep="\t",
    names=["user", "movie", "rating", "timestamp"],
    header=None,
)
ratings.head()

usermovieratingtimestamp
01962423881250949
11863023891717742
2223771878887116
3244512880606923
41663461886397596
Show/Hide the code
1
2
# Reshape the long ratings table into a user x movie matrix.
# (user, movie) pairs without a rating show up as NaN.
pivot = ratings.pivot_table(index="user", columns="movie", values="rating")
pivot

movie12345678910...1673167416751676167716781679168016811682
user
15.03.04.03.03.05.04.01.05.03.0...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
24.0NaNNaNNaNNaNNaNNaNNaNNaN2.0...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
4NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
54.03.0NaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
..................................................................
939NaNNaNNaNNaNNaNNaNNaNNaN5.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
940NaNNaNNaN2.0NaNNaN4.05.03.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
9415.0NaNNaNNaNNaNNaN4.0NaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
942NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
943NaN5.0NaNNaNNaNNaNNaNNaN3.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN

943 rows × 1682 columns

Show/Hide the code
1
2
3
4
# Number of observed (non-NaN) ratings per user and per movie.
dense_rows = pivot.count(axis=1)
dense_cols = pivot.count(axis=0)

# Keep the 20 most active users and the 20 most rated movies so the
# printed crosstab is mostly filled in; remaining gaps print as blanks.
top_users = dense_rows.nlargest(20).index
top_movies = dense_cols.nlargest(20).index
selected = pivot.loc[top_users, top_movies]
print(selected.fillna(" ").to_string())
movie  50   258  100  181  294  286  288  1    300  121  174  127  56   7    98   237  117  172  222  204
user                                                                                                     
405    5.0            5.0            5.0                 5.0  5.0  4.0       4.0            5.0       5.0
655    4.0  2.0  3.0  3.0  3.0  3.0  3.0  2.0  3.0  3.0  3.0  5.0  3.0  3.0  4.0  3.0  2.0  4.0  2.0  3.0
13     5.0  4.0  5.0  5.0  2.0  3.0  1.0  3.0  1.0  5.0  4.0  5.0  5.0  2.0  4.0  5.0  3.0  5.0  3.0  5.0
450    5.0  4.0  4.0  4.0  4.0  4.0  3.0  4.0  4.0  3.0  5.0  5.0  4.0  4.0  4.0  5.0  4.0  4.0  3.0  4.0
276    5.0  5.0  5.0  5.0  4.0       4.0  5.0  4.0  4.0  5.0  5.0  5.0  5.0  5.0  5.0  4.0  5.0  4.0  5.0
416    5.0  5.0  5.0  5.0  4.0  5.0  5.0  5.0  4.0  5.0  5.0  5.0  5.0  4.0  5.0  3.0  5.0  5.0       5.0
537    4.0  4.0  4.0  2.0  1.0  3.0  2.0  2.0  1.0  1.0  3.0  5.0  5.0  4.0  3.0  3.0  2.0  3.0  2.0  3.0
303    5.0  4.0  5.0  5.0  4.0  5.0  4.0  5.0  1.0  3.0  5.0  5.0  5.0  4.0  5.0  5.0  3.0  5.0  3.0  4.0
234    4.0  2.0  4.0  3.0  3.0  3.0  3.0  3.0  3.0       3.0  4.0  3.0  2.0  4.0  3.0  2.0  3.0  3.0  2.0
393    5.0  4.0  1.0  4.0  4.0       3.0  3.0       4.0            2.0  4.0       4.0  4.0  5.0  4.0  4.0
181         3.0  3.0       2.0  1.0  4.0  3.0  3.0  4.0                 4.0       5.0  2.0       4.0     
279    3.0       4.0  3.0  2.0       3.0  3.0       4.0  4.0       4.0  5.0            5.0  2.0  1.0  3.0
429    5.0  4.0  5.0  5.0            3.0  3.0  3.0  3.0  4.0  4.0  4.0  2.0  4.0  3.0  4.0  5.0  4.0  4.0
846    5.0  3.0       5.0  3.0       4.0                 5.0  5.0  5.0       4.0            4.0       3.0
7      5.0  4.0  5.0  3.0  1.0  4.0  4.0       4.0  5.0  5.0  5.0  5.0  5.0  4.0  5.0       4.0       5.0
94     5.0  5.0  5.0  4.0       4.0  3.0  4.0       2.0  4.0  5.0  5.0  4.0  4.0            4.0  3.0  4.0
682    5.0  3.0  3.0  5.0  3.0       4.0  4.0  2.0  4.0  4.0  5.0  4.0  4.0  4.0  3.0  4.0  5.0  4.0  3.0
308    5.0       5.0  4.0  3.0       4.0  4.0       3.0  4.0  4.0  5.0  4.0  3.0  3.0  3.0  4.0       4.0
92     5.0  4.0  5.0  4.0  3.0       3.0  4.0       5.0  5.0       5.0  4.0  5.0  4.0  4.0  4.0  4.0  4.0
293    5.0  3.0  4.0  3.0  2.0  3.0  3.0  2.0  2.0  3.0  5.0  5.0  4.0  3.0  4.0  3.0  3.0  5.0  3.0  3.0
Show/Hide the code
1
2
3
4
5
6
7
8
9
# u.item is pipe-delimited and Latin-1 encoded; only its first two
# columns (movie id and title) are read.
movies = pd.read_csv(
    path / "u.item",
    sep="|",
    header=None,
    usecols=[0, 1],
    names=["movie", "title"],
    encoding="latin-1",
)
movies.head()

movietitle
01Toy Story (1995)
12GoldenEye (1995)
23Four Rooms (1995)
34Get Shorty (1995)
45Copycat (1995)
Show/Hide the code
1
2
# No `on=` is given, so pandas joins on the column the two frames share
# ("movie"), attaching the title to every rating row.
ratings = pd.merge(ratings, movies)
ratings.head()

usermovieratingtimestamptitle
01962423881250949Kolya (1996)
11863023891717742L.A. Confidential (1997)
2223771878887116Heavyweights (1994)
3244512880606923Legends of the Fall (1994)
41663461886397596Jackie Brown (1997)
Show/Hide the code
1
2
3
4
5
6
# CollabDataLoaders.from_df defaults to using the first column as the user,
# the second column as the item and the third column as the rating.
# Here `item_name="title"` overrides the item column so that items are
# identified by movie title instead of by numeric movie id.
dls = CollabDataLoaders.from_df(ratings, item_name="title", bs=64)
dls.show_batch()
usertitlerating
0597Godfather, The (1972)4
1814Evil Dead II (1987)2
2234Mother (1996)2
3176Cop Land (1997)3
4778Cool Runnings (1993)1
5521Die Hard 2 (1990)4
6904Bed of Roses (1996)5
7151Independence Day (ID4) (1996)5
8880Devil's Own, The (1997)2
915Peacemaker, The (1997)3
Show/Hide the code
1
2
3
4
5
6
n_factors = 5  # number of latent dimensions per user / per movie
n_users, n_movies = (len(dls.classes[col]) for col in ("user", "title"))

# Randomly initialised latent-factor matrices: one row per user / per movie.
user_factors = torch.randn(n_users, n_factors)
movie_factors = torch.randn(n_movies, n_factors)
Show/Hide the code
1
2
# Multiplying the factor matrix by a one-hot vector selects a single row:
# the product below equals user_factors[3]. This shows that an index lookup
# into the factor matrix can be expressed as a matrix product.
one_hot_3 = one_hot(3, n_users).float()
user_factors.T @ one_hot_3
tensor([-1.1552, -1.3241, -0.1439,  1.1268, -0.5780])
Show/Hide the code
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
class DotProduct(Module):
    """Dot-product collaborative-filtering model with per-user and per-movie bias.

    Each user and each movie gets an ``n_factors``-dimensional embedding plus
    a scalar bias. The predicted rating is the dot product of the two
    embeddings plus both biases, passed through ``sigmoid_range`` with
    ``y_range`` as its bounds.

    Args:
        n_users: number of rows in the user embedding tables.
        n_movies: number of rows in the movie embedding tables.
        n_factors: length of each latent-factor vector.
        y_range: ``(low, high)`` bounds handed to ``sigmoid_range``.
    """

    def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):
        # Layers are created in the same order as before (user factors,
        # user bias, movie factors, movie bias) so random initialisation
        # draws are unchanged.
        self.user_factors = Embedding(n_users, n_factors)
        self.user_bias = Embedding(n_users, 1)
        self.movie_factors = Embedding(n_movies, n_factors)
        self.movie_bias = Embedding(n_movies, 1)
        self.y_range = y_range

    def forward(self, x):
        """Predict ratings for a batch ``x`` whose columns are (user index, movie index)."""
        user_idx, movie_idx = x[:, 0], x[:, 1]

        # Dot product of the paired user / movie factor vectors, one value per row.
        interaction = (self.user_factors(user_idx) * self.movie_factors(movie_idx)).sum(dim=1)

        # The bias embeddings have a trailing size-1 axis; squeeze it away
        # before adding them to the per-row dot product.
        bias = self.user_bias(user_idx).squeeze() + self.movie_bias(movie_idx).squeeze()

        low, high = self.y_range
        return sigmoid_range(interaction + bias, low, high)
Show/Hide the code
1
2
# Pull a single batch to inspect its shapes: x carries the two index columns
# that DotProduct.forward reads (x[:, 0] and x[:, 1]), y carries the ratings.
x, y = dls.one_batch()
x.shape, y.shape
(torch.Size([64, 2]), torch.Size([64, 1]))
Show/Hide the code
1
2
3
# Train the hand-written model with 50 latent factors and an MSE loss:
# 5 epochs of one-cycle training, peak learning rate 5e-3, weight decay 0.1.
model = DotProduct(n_users, n_movies, n_factors=50)
learn = Learner(dls, model, loss_func=MSELossFlat())
learn.fit_one_cycle(n_epoch=5, lr_max=5e-3, wd=0.1)
epochtrain_lossvalid_losstime
00.9078770.96935200:07
10.6734560.90703100:05
20.5239830.87779500:07
30.4605340.86338800:05
40.4219440.85883100:05
Show/Hide the code
1
2
def create_params(size):
    """Return a trainable parameter of shape ``size`` drawn from N(0, 0.01).

    ``size`` is an iterable of dimension lengths, e.g. ``(n_users, n_factors)``.
    """
    weights = torch.zeros(*size)
    weights.normal_(0, 0.01)  # in-place fill: mean 0, std 0.01
    return nn.Parameter(weights)
Show/Hide the code
1
2
3
# List the five movies with the largest learned bias term. A large bias
# raises a movie's predicted rating for every user, independently of the
# user/movie factor match.
movie_bias = learn.model.movie_bias.weight.squeeze()
# Indices of the biases sorted from largest to smallest; keep the first five.
idx = movie_bias.argsort(descending=True)[:5]
[dls.classes["title"][i] for i in idx]
['Shawshank Redemption, The (1994)',
 "Schindler's List (1993)",
 'Star Wars (1977)',
 'L.A. Confidential (1997)',
 'Titanic (1997)']
Show/Hide the code
1
2
# Build a learner with fastai's collab_learner helper instead of the
# hand-written model, reporting RMSE as an extra metric. Note that this
# rebinds `learn`, replacing the learner trained earlier.
learn = collab_learner(dls, n_factors=50, y_range=(0, 5.5), metrics=rmse)
learn.fit_one_cycle(10, 5e-3, wd=0.1)
epochtrain_lossvalid_loss_rmsetime
01.0134641.0196711.00978800:07
10.7690470.9080770.95293100:03
20.6025420.8985110.94789800:06
30.4970940.8939380.94548300:06
40.4334810.8940630.94554900:03
50.3749460.8896140.94319400:05
60.3260300.8855890.94105800:05
70.3010460.8825710.93945300:03
80.2782230.8802050.93819200:04
90.2797070.8798050.93797900:04
Show/Hide the code
1
2
3
4
5
# Find the movies whose learned factor vectors are most similar in direction
# to that of "Silence of the Lambs".
movie_factors = learn.model.i_weight.weight
idx = dls.classes["title"].o2i["Silence of the Lambs, The (1991)"]
# NOTE: despite its name, `distances` holds cosine *similarities* (higher means
# more alike), which is why the sort below is descending. `[None]` adds a
# leading axis so the single query vector is compared against every row.
distances = nn.CosineSimilarity(dim=1)(movie_factors, movie_factors[idx][None])
# Top 20 matches; the query movie itself is expected to rank first.
idxs = distances.argsort(descending=True)[0:20]
[dls.classes["title"][i] for i in idxs]
['Silence of the Lambs, The (1991)',
 'Manchurian Candidate, The (1962)',
 'Farewell to Arms, A (1932)',
 'Meet John Doe (1941)',
 'Wedding Gift, The (1994)',
 'Fugitive, The (1993)',
 'To Catch a Thief (1955)',
 'Ben-Hur (1959)',
 'Lost Horizon (1937)',
 "It's a Wonderful Life (1946)",
 'Shawshank Redemption, The (1994)',
 'Arsenic and Old Lace (1944)',
 'Dial M for Murder (1954)',
 'Pather Panchali (1955)',
 'Third Man, The (1949)',
 'Gaslight (1944)',
 'Mr. Smith Goes to Washington (1939)',
 'Guantanamera (1994)',
 'Great Escape, The (1963)',
 'Once Were Warriors (1994)']
使用 Hugo 构建
主题 StackJimmy 设计