
Feature enhancement #84


Open · wants to merge 10 commits into master
56 changes: 34 additions & 22 deletions GMF.py
@@ -1,19 +1,22 @@
 '''
+Updated on Jan 29, 2025
 Created on Aug 9, 2016

 Keras Implementation of Generalized Matrix Factorization (GMF) recommender model in:
 He Xiangnan et al. Neural Collaborative Filtering. In WWW 2017.

 @author: Xiangnan He (xiangnanhe@gmail.com)
+@Updated by: Amrita Yadav

 '''
 import numpy as np
-import theano.tensor as T
 import keras
 from keras import backend as K
-from keras import initializations
-from keras.models import Sequential, Model, load_model, save_model
-from keras.layers.core import Dense, Lambda, Activation
-from keras.layers import Embedding, Input, Dense, merge, Reshape, Merge, Flatten
+# from keras import initializations
+from keras import initializers
+from keras.initializers import RandomNormal, lecun_uniform
+from keras.models import Model
+from keras.layers import Dense
+from keras.layers import Embedding, Input, Dense, Multiply, Flatten
 from keras.optimizers import Adagrad, Adam, SGD, RMSprop
 from keras.regularizers import l2
 from Dataset import Dataset
@@ -52,31 +55,31 @@ def parse_args():
     return parser.parse_args()

 def init_normal(shape, name=None):
-    return initializations.normal(shape, scale=0.01, name=name)
+    return initializers.normal(shape, scale=0.01, name=name)

 def get_model(num_users, num_items, latent_dim, regs=[0,0]):
     # Input variables
     user_input = Input(shape=(1,), dtype='int32', name = 'user_input')
     item_input = Input(shape=(1,), dtype='int32', name = 'item_input')

     MF_Embedding_User = Embedding(input_dim = num_users, output_dim = latent_dim, name = 'user_embedding',
-                                  init = init_normal, W_regularizer = l2(regs[0]), input_length=1)
+                                  embeddings_initializer=RandomNormal(mean=0.0, stddev=0.05), embeddings_regularizer = l2(regs[0]), input_length=1)
     MF_Embedding_Item = Embedding(input_dim = num_items, output_dim = latent_dim, name = 'item_embedding',
-                                  init = init_normal, W_regularizer = l2(regs[1]), input_length=1)
+                                  embeddings_initializer=RandomNormal(mean=0.0, stddev=0.05), embeddings_regularizer = l2(regs[1]), input_length=1)

     # Crucial to flatten an embedding vector!
     user_latent = Flatten()(MF_Embedding_User(user_input))
     item_latent = Flatten()(MF_Embedding_Item(item_input))

     # Element-wise product of user and item embeddings
-    predict_vector = merge([user_latent, item_latent], mode = 'mul')
+    predict_vector = Multiply()([user_latent, item_latent])

     # Final prediction layer
     #prediction = Lambda(lambda x: K.sigmoid(K.sum(x)), output_shape=(1,))(predict_vector)
-    prediction = Dense(1, activation='sigmoid', init='lecun_uniform', name = 'prediction')(predict_vector)
+    prediction = Dense(1, activation='sigmoid', kernel_initializer=lecun_uniform(), name = 'prediction')(predict_vector)

-    model = Model(input=[user_input, item_input],
-                  output=prediction)
+    model = Model(inputs=[user_input, item_input],
+                  outputs=prediction)

     return model

@@ -89,9 +92,9 @@ def get_train_instances(train, num_negatives):
         item_input.append(i)
         labels.append(1)
         # negative instances
-        for t in xrange(num_negatives):
+        for t in range(num_negatives):
             j = np.random.randint(num_items)
-            while train.has_key((u, j)):
+            while (u, j) in train:
                 j = np.random.randint(num_items)
             user_input.append(u)
             item_input.append(j)
@@ -112,6 +115,11 @@ def get_train_instances(train, num_negatives):
     topK = 10
     evaluation_threads = 1 #mp.cpu_count()
     print("GMF arguments: %s" %(args))
+
+    # # When saving only the model weights, use:
+    # model_out_file = 'Pretrain/%s_GMF_%d_%d.weights.h5' %(args.dataset, num_factors, time())
+
+    # If you want to save the full model (architecture + weights)
     model_out_file = 'Pretrain/%s_GMF_%d_%d.h5' %(args.dataset, num_factors, time())

     # Loading data
@@ -125,13 +133,13 @@ def get_train_instances(train, num_negatives):
     # Build model
     model = get_model(num_users, num_items, num_factors, regs)
     if learner.lower() == "adagrad":
-        model.compile(optimizer=Adagrad(lr=learning_rate), loss='binary_crossentropy')
+        model.compile(optimizer=Adagrad(learning_rate=learning_rate), loss='binary_crossentropy')
     elif learner.lower() == "rmsprop":
-        model.compile(optimizer=RMSprop(lr=learning_rate), loss='binary_crossentropy')
+        model.compile(optimizer=RMSprop(learning_rate=learning_rate), loss='binary_crossentropy')
     elif learner.lower() == "adam":
-        model.compile(optimizer=Adam(lr=learning_rate), loss='binary_crossentropy')
+        model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy')
     else:
-        model.compile(optimizer=SGD(lr=learning_rate), loss='binary_crossentropy')
+        model.compile(optimizer=SGD(learning_rate=learning_rate), loss='binary_crossentropy')
     #print(model.summary())

     # Init performance
@@ -144,15 +152,15 @@ def get_train_instances(train, num_negatives):

     # Train model
     best_hr, best_ndcg, best_iter = hr, ndcg, -1
-    for epoch in xrange(epochs):
+    for epoch in range(epochs):
         t1 = time()
         # Generate training instances
         user_input, item_input, labels = get_train_instances(train, num_negatives)

         # Training
         hist = model.fit([np.array(user_input), np.array(item_input)], #input
                          np.array(labels), # labels
-                         batch_size=batch_size, nb_epoch=1, verbose=0, shuffle=True)
+                         batch_size=batch_size, epochs=1, verbose=0, shuffle=True)
         t2 = time()

         # Evaluation
@@ -164,7 +172,11 @@ def get_train_instances(train, num_negatives):
         if hr > best_hr:
             best_hr, best_ndcg, best_iter = hr, ndcg, epoch
             if args.out > 0:
-                model.save_weights(model_out_file, overwrite=True)
+                # # When saving only the model weights, use:
+                # model.save_weights(model_out_file, overwrite=True)
+
+                # If you want to save the full model (architecture + weights)
+                model.save(model_out_file)

     print("End. Best Iteration %d: HR = %.4f, NDCG = %.4f. " %(best_iter, best_hr, best_ndcg))
     if args.out > 0:
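Note on the save-format change: because this PR swaps model.save_weights(...) for model.save(...), downstream code that restores the pretrained GMF model changes as well. Below is a minimal loading sketch for both variants; the timestamped filename is hypothetical, and the weights-only path assumes the script's own get_model(...):

import numpy as np
from keras.models import load_model

# Full model written by model.save(...): architecture + weights in one .h5 file.
gmf = load_model('Pretrain/ml-1m_GMF_8_1738130000.h5')  # hypothetical filename

# Weights-only variant (the commented-out .weights.h5 path): rebuild the graph
# first, then load the weights into it.
# gmf = get_model(num_users, num_items, latent_dim=8)
# gmf.load_weights('Pretrain/ml-1m_GMF_8_1738130000.weights.h5')

# Quick sanity check: one sigmoid score per (user, item) pair.
print(gmf.predict([np.array([0, 1, 2]), np.array([3, 4, 5])]).shape)  # (3, 1)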
69 changes: 40 additions & 29 deletions MLP.py
@@ -1,23 +1,22 @@
 '''
 Created on Aug 9, 2016
+Updated on Jan 29, 2025
 Keras Implementation of Multi-Layer Perceptron (MLP) recommender model in:
 He Xiangnan et al. Neural Collaborative Filtering. In WWW 2017.

 @author: Xiangnan He (xiangnanhe@gmail.com)
+@Updated by: Amrita Yadav
 '''

 import numpy as np

-import theano
-import theano.tensor as T
 import keras
 from keras import backend as K
-from keras import initializations
-from keras.regularizers import l2, activity_l2
-from keras.models import Sequential, Graph, Model
-from keras.layers.core import Dense, Lambda, Activation
-from keras.layers import Embedding, Input, Dense, merge, Reshape, Merge, Flatten, Dropout
-from keras.constraints import maxnorm
+from keras import initializers
+from keras.initializers import RandomNormal, lecun_uniform
+from keras.regularizers import l2
+from keras.models import Model
+from keras.layers import Dense
+from keras.layers import Embedding, Input, Dense, Flatten, Concatenate
 from keras.optimizers import Adagrad, Adam, SGD, RMSprop
 from evaluate import evaluate_model
 from Dataset import Dataset
@@ -54,7 +53,7 @@ def parse_args():
     return parser.parse_args()

 def init_normal(shape, name=None):
-    return initializations.normal(shape, scale=0.01, name=name)
+    return initializers.normal(shape, scale=0.01, name=name)

 def get_model(num_users, num_items, layers = [20,10], reg_layers=[0,0]):
     assert len(layers) == len(reg_layers)
@@ -63,28 +62,30 @@ def get_model(num_users, num_items, layers = [20,10], reg_layers=[0,0]):
     user_input = Input(shape=(1,), dtype='int32', name = 'user_input')
     item_input = Input(shape=(1,), dtype='int32', name = 'item_input')

-    MLP_Embedding_User = Embedding(input_dim = num_users, output_dim = layers[0]/2, name = 'user_embedding',
-                                   init = init_normal, W_regularizer = l2(reg_layers[0]), input_length=1)
-    MLP_Embedding_Item = Embedding(input_dim = num_items, output_dim = layers[0]/2, name = 'item_embedding',
-                                   init = init_normal, W_regularizer = l2(reg_layers[0]), input_length=1)
+    print("\n\n Dimension : ", int(layers[0]/2), layers[0]/2)
+
+    MLP_Embedding_User = Embedding(input_dim = num_users, output_dim = int(layers[0]/2), name = 'user_embedding',
+                                   embeddings_initializer=RandomNormal(mean=0.0, stddev=0.05), embeddings_regularizer = l2(reg_layers[0]), input_length=1)
+    MLP_Embedding_Item = Embedding(input_dim = num_items, output_dim = int(layers[0]/2), name = 'item_embedding',
+                                   embeddings_initializer=RandomNormal(mean=0.0, stddev=0.05), embeddings_regularizer = l2(reg_layers[0]), input_length=1)

     # Crucial to flatten an embedding vector!
     user_latent = Flatten()(MLP_Embedding_User(user_input))
     item_latent = Flatten()(MLP_Embedding_Item(item_input))

     # The 0-th layer is the concatenation of embedding layers
-    vector = merge([user_latent, item_latent], mode = 'concat')
+    vector = Concatenate()([user_latent, item_latent])

     # MLP layers
-    for idx in xrange(1, num_layer):
-        layer = Dense(layers[idx], W_regularizer= l2(reg_layers[idx]), activation='relu', name = 'layer%d' %idx)
+    for idx in range(1, num_layer):
+        layer = Dense(layers[idx], kernel_regularizer = l2(reg_layers[idx]), activation='relu', name = 'layer%d' %idx)
         vector = layer(vector)

     # Final prediction layer
-    prediction = Dense(1, activation='sigmoid', init='lecun_uniform', name = 'prediction')(vector)
+    prediction = Dense(1, activation='sigmoid', kernel_initializer=lecun_uniform(), name = 'prediction')(vector)

-    model = Model(input=[user_input, item_input],
-                  output=prediction)
+    model = Model(inputs=[user_input, item_input],
+                  outputs=prediction)

     return model

@@ -97,9 +98,9 @@ def get_train_instances(train, num_negatives):
         item_input.append(i)
         labels.append(1)
         # negative instances
-        for t in xrange(num_negatives):
+        for t in range(num_negatives):
             j = np.random.randint(num_items)
-            while train.has_key((u, j)):
+            while (u, j) in train:
                 j = np.random.randint(num_items)
             user_input.append(u)
             item_input.append(j)
@@ -122,6 +123,12 @@ def get_train_instances(train, num_negatives):
     topK = 10
     evaluation_threads = 1 #mp.cpu_count()
     print("MLP arguments: %s " %(args))
+
+    # # When saving only the model weights, use:
+    # model_out_file = 'Pretrain/%s_MLP_%s_%d.weights.h5' %(args.dataset, args.layers, time())
+
+
+    # If you want to save the full model (architecture + weights)
     model_out_file = 'Pretrain/%s_MLP_%s_%d.h5' %(args.dataset, args.layers, time())

     # Loading data
# Loading data
Expand All @@ -135,13 +142,13 @@ def get_train_instances(train, num_negatives):
# Build model
model = get_model(num_users, num_items, layers, reg_layers)
if learner.lower() == "adagrad":
model.compile(optimizer=Adagrad(lr=learning_rate), loss='binary_crossentropy')
model.compile(optimizer=Adagrad(learning_rate=learning_rate), loss='binary_crossentropy')
elif learner.lower() == "rmsprop":
model.compile(optimizer=RMSprop(lr=learning_rate), loss='binary_crossentropy')
model.compile(optimizer=RMSprop(learning_rate=learning_rate), loss='binary_crossentropy')
elif learner.lower() == "adam":
model.compile(optimizer=Adam(lr=learning_rate), loss='binary_crossentropy')
model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy')
else:
model.compile(optimizer=SGD(lr=learning_rate), loss='binary_crossentropy')
model.compile(optimizer=SGD(learning_rate=learning_rate), loss='binary_crossentropy')

# Check Init performance
t1 = time()
@@ -151,15 +158,15 @@ def get_train_instances(train, num_negatives):

     # Train model
     best_hr, best_ndcg, best_iter = hr, ndcg, -1
-    for epoch in xrange(epochs):
+    for epoch in range(epochs):
         t1 = time()
         # Generate training instances
         user_input, item_input, labels = get_train_instances(train, num_negatives)

         # Training
         hist = model.fit([np.array(user_input), np.array(item_input)], #input
                          np.array(labels), # labels
-                         batch_size=batch_size, nb_epoch=1, verbose=0, shuffle=True)
+                         batch_size=batch_size, epochs=1, verbose=0, shuffle=True)
         t2 = time()

         # Evaluation
@@ -171,7 +178,11 @@ def get_train_instances(train, num_negatives):
         if hr > best_hr:
             best_hr, best_ndcg, best_iter = hr, ndcg, epoch
             if args.out > 0:
-                model.save_weights(model_out_file, overwrite=True)
+                # # When saving only the model weights, use:
+                # model.save_weights(model_out_file, overwrite=True)
+
+                # If you want to save the full model (architecture + weights)
+                model.save(model_out_file)

     print("End. Best Iteration %d: HR = %.4f, NDCG = %.4f. " %(best_iter, best_hr, best_ndcg))
     if args.out > 0:
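As a quick smoke test of the migrated Keras 2 API calls in both files, the two get_model functions can be built and scored on dummy IDs. A minimal sketch, assuming GMF.py and MLP.py are importable from the working directory (their training code runs under an if __name__ == '__main__': guard upstream, and Dataset.py/evaluate.py must be on the path for the module-level imports to resolve):

import numpy as np
from GMF import get_model as build_gmf
from MLP import get_model as build_mlp

# Tiny dummy catalogue: 100 users, 50 items, three (user, item) pairs to score.
users = np.array([0, 1, 2])
items = np.array([5, 6, 7])

gmf = build_gmf(num_users=100, num_items=50, latent_dim=8)
mlp = build_mlp(num_users=100, num_items=50, layers=[64, 32, 16, 8], reg_layers=[0, 0, 0, 0])

for model in (gmf, mlp):
    scores = model.predict([users, items])
    print(scores.shape)  # (3, 1): one sigmoid score per (user, item) pair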