mirror of
https://github.com/lemeow125/DRF_Template.git
synced 2024-11-17 04:09:25 +08:00
Clean up docker-compose and run Black formatter over entire codebase
This commit is contained in:
parent
6c232b3e89
commit
069aba80b1
60 changed files with 1946 additions and 1485 deletions
1
Pipfile
1
Pipfile
|
@ -35,6 +35,7 @@ gunicorn = "*"
|
||||||
django-silk = "*"
|
django-silk = "*"
|
||||||
django-redis = "*"
|
django-redis = "*"
|
||||||
granian = "*"
|
granian = "*"
|
||||||
|
black = "*"
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
|
|
||||||
|
|
1492
Pipfile.lock
generated
1492
Pipfile.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -6,11 +6,13 @@ from .models import CustomUser
|
||||||
|
|
||||||
class CustomUserAdmin(UserAdmin):
|
class CustomUserAdmin(UserAdmin):
|
||||||
model = CustomUser
|
model = CustomUser
|
||||||
list_display = ('id', 'is_active', 'user_group',) + UserAdmin.list_display
|
list_display = (
|
||||||
|
"id",
|
||||||
|
"is_active",
|
||||||
|
"user_group",
|
||||||
|
) + UserAdmin.list_display
|
||||||
# Editable fields per instance
|
# Editable fields per instance
|
||||||
fieldsets = UserAdmin.fieldsets + (
|
fieldsets = UserAdmin.fieldsets + ((None, {"fields": ("avatar",)}),)
|
||||||
(None, {'fields': ('avatar',)}),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
admin.site.register(CustomUser, CustomUserAdmin)
|
admin.site.register(CustomUser, CustomUserAdmin)
|
||||||
|
|
|
@ -2,8 +2,8 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class AccountsConfig(AppConfig):
|
class AccountsConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'accounts'
|
name = "accounts"
|
||||||
|
|
||||||
def ready(self):
|
def ready(self):
|
||||||
import accounts.signals
|
import accounts.signals
|
||||||
|
|
|
@ -13,38 +13,145 @@ class Migration(migrations.Migration):
|
||||||
initial = True
|
initial = True
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
('auth', '0012_alter_user_first_name_max_length'),
|
("auth", "0012_alter_user_first_name_max_length"),
|
||||||
('user_groups', '0001_initial'),
|
("user_groups", "0001_initial"),
|
||||||
]
|
]
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='CustomUser',
|
name="CustomUser",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('password', models.CharField(max_length=128, verbose_name='password')),
|
"id",
|
||||||
('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')),
|
models.BigAutoField(
|
||||||
('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')),
|
auto_created=True,
|
||||||
('username', models.CharField(error_messages={'unique': 'A user with that username already exists.'}, help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', max_length=150, unique=True, validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name='username')),
|
primary_key=True,
|
||||||
('first_name', models.CharField(blank=True, max_length=150, verbose_name='first name')),
|
serialize=False,
|
||||||
('last_name', models.CharField(blank=True, max_length=150, verbose_name='last name')),
|
verbose_name="ID",
|
||||||
('email', models.EmailField(blank=True, max_length=254, verbose_name='email address')),
|
),
|
||||||
('is_staff', models.BooleanField(default=False, help_text='Designates whether the user can log into this admin site.', verbose_name='staff status')),
|
),
|
||||||
('is_active', models.BooleanField(default=True, help_text='Designates whether this user should be treated as active. Unselect this instead of deleting accounts.', verbose_name='active')),
|
("password", models.CharField(max_length=128, verbose_name="password")),
|
||||||
('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')),
|
(
|
||||||
('avatar', django_resized.forms.ResizedImageField(crop=None, force_format='WEBP', keep_meta=True, null=True, quality=100, scale=None, size=[1920, 1080], upload_to='avatars/')),
|
"last_login",
|
||||||
('onboarding', models.BooleanField(default=True)),
|
models.DateTimeField(
|
||||||
('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.group', verbose_name='groups')),
|
blank=True, null=True, verbose_name="last login"
|
||||||
('user_group', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='user_groups.usergroup')),
|
),
|
||||||
('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.permission', verbose_name='user permissions')),
|
),
|
||||||
|
(
|
||||||
|
"is_superuser",
|
||||||
|
models.BooleanField(
|
||||||
|
default=False,
|
||||||
|
help_text="Designates that this user has all permissions without explicitly assigning them.",
|
||||||
|
verbose_name="superuser status",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"username",
|
||||||
|
models.CharField(
|
||||||
|
error_messages={
|
||||||
|
"unique": "A user with that username already exists."
|
||||||
|
},
|
||||||
|
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
|
||||||
|
max_length=150,
|
||||||
|
unique=True,
|
||||||
|
validators=[
|
||||||
|
django.contrib.auth.validators.UnicodeUsernameValidator()
|
||||||
|
],
|
||||||
|
verbose_name="username",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"first_name",
|
||||||
|
models.CharField(
|
||||||
|
blank=True, max_length=150, verbose_name="first name"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"last_name",
|
||||||
|
models.CharField(
|
||||||
|
blank=True, max_length=150, verbose_name="last name"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"email",
|
||||||
|
models.EmailField(
|
||||||
|
blank=True, max_length=254, verbose_name="email address"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"is_staff",
|
||||||
|
models.BooleanField(
|
||||||
|
default=False,
|
||||||
|
help_text="Designates whether the user can log into this admin site.",
|
||||||
|
verbose_name="staff status",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"is_active",
|
||||||
|
models.BooleanField(
|
||||||
|
default=True,
|
||||||
|
help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
|
||||||
|
verbose_name="active",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"date_joined",
|
||||||
|
models.DateTimeField(
|
||||||
|
default=django.utils.timezone.now, verbose_name="date joined"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"avatar",
|
||||||
|
django_resized.forms.ResizedImageField(
|
||||||
|
crop=None,
|
||||||
|
force_format="WEBP",
|
||||||
|
keep_meta=True,
|
||||||
|
null=True,
|
||||||
|
quality=100,
|
||||||
|
scale=None,
|
||||||
|
size=[1920, 1080],
|
||||||
|
upload_to="avatars/",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("onboarding", models.BooleanField(default=True)),
|
||||||
|
(
|
||||||
|
"groups",
|
||||||
|
models.ManyToManyField(
|
||||||
|
blank=True,
|
||||||
|
help_text="The groups this user belongs to. A user will get all permissions granted to each of their groups.",
|
||||||
|
related_name="user_set",
|
||||||
|
related_query_name="user",
|
||||||
|
to="auth.group",
|
||||||
|
verbose_name="groups",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user_group",
|
||||||
|
models.ForeignKey(
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
to="user_groups.usergroup",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user_permissions",
|
||||||
|
models.ManyToManyField(
|
||||||
|
blank=True,
|
||||||
|
help_text="Specific permissions for this user.",
|
||||||
|
related_name="user_set",
|
||||||
|
related_query_name="user",
|
||||||
|
to="auth.permission",
|
||||||
|
verbose_name="user permissions",
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
options={
|
options={
|
||||||
'verbose_name': 'user',
|
"verbose_name": "user",
|
||||||
'verbose_name_plural': 'users',
|
"verbose_name_plural": "users",
|
||||||
'abstract': False,
|
"abstract": False,
|
||||||
},
|
},
|
||||||
managers=[
|
managers=[
|
||||||
('objects', django.contrib.auth.models.UserManager()),
|
("objects", django.contrib.auth.models.UserManager()),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -15,14 +15,16 @@ class CustomUser(AbstractUser):
|
||||||
# is_admin inherited from base user class
|
# is_admin inherited from base user class
|
||||||
|
|
||||||
avatar = ResizedImageField(
|
avatar = ResizedImageField(
|
||||||
null=True, force_format="WEBP", quality=100, upload_to='avatars/')
|
null=True, force_format="WEBP", quality=100, upload_to="avatars/"
|
||||||
|
)
|
||||||
|
|
||||||
# Used for onboarding processes
|
# Used for onboarding processes
|
||||||
# Set this to False later on once the user makes actions
|
# Set this to False later on once the user makes actions
|
||||||
onboarding = models.BooleanField(default=True)
|
onboarding = models.BooleanField(default=True)
|
||||||
|
|
||||||
user_group = models.ForeignKey(
|
user_group = models.ForeignKey(
|
||||||
'user_groups.UserGroup', on_delete=models.SET_NULL, null=True)
|
"user_groups.UserGroup", on_delete=models.SET_NULL, null=True
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def group_member(self):
|
def group_member(self):
|
||||||
|
@ -57,4 +59,4 @@ class CustomUser(AbstractUser):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def admin_url(self):
|
def admin_url(self):
|
||||||
return reverse('admin:users_customuser_change', args=(self.pk,))
|
return reverse("admin:users_customuser_change", args=(self.pk,))
|
||||||
|
|
|
@ -8,13 +8,14 @@ from django.core.cache import cache
|
||||||
from django.core import exceptions as django_exceptions
|
from django.core import exceptions as django_exceptions
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from django.contrib.auth.password_validation import validate_password
|
from django.contrib.auth.password_validation import validate_password
|
||||||
|
|
||||||
# There can be multiple subject instances with the same name, only differing in course, year level, and semester. We filter them here
|
# There can be multiple subject instances with the same name, only differing in course, year level, and semester. We filter them here
|
||||||
|
|
||||||
|
|
||||||
class SimpleCustomUserSerializer(ModelSerializer):
|
class SimpleCustomUserSerializer(ModelSerializer):
|
||||||
class Meta(BaseUserSerializer.Meta):
|
class Meta(BaseUserSerializer.Meta):
|
||||||
model = CustomUser
|
model = CustomUser
|
||||||
fields = ('id', 'username', 'email', 'full_name')
|
fields = ("id", "username", "email", "full_name")
|
||||||
|
|
||||||
|
|
||||||
class CustomUserSerializer(BaseUserSerializer):
|
class CustomUserSerializer(BaseUserSerializer):
|
||||||
|
@ -22,19 +23,36 @@ class CustomUserSerializer(BaseUserSerializer):
|
||||||
|
|
||||||
class Meta(BaseUserSerializer.Meta):
|
class Meta(BaseUserSerializer.Meta):
|
||||||
model = CustomUser
|
model = CustomUser
|
||||||
fields = ('id', 'username', 'email', 'avatar', 'first_name',
|
fields = (
|
||||||
'last_name', 'user_group', 'group_member', 'group_owner')
|
"id",
|
||||||
read_only_fields = ('id', 'username', 'email', 'user_group',
|
"username",
|
||||||
'group_member', 'group_owner')
|
"email",
|
||||||
|
"avatar",
|
||||||
|
"first_name",
|
||||||
|
"is_new",
|
||||||
|
"last_name",
|
||||||
|
"user_group",
|
||||||
|
"group_member",
|
||||||
|
"group_owner",
|
||||||
|
)
|
||||||
|
read_only_fields = (
|
||||||
|
"id",
|
||||||
|
"username",
|
||||||
|
"email",
|
||||||
|
"user_group",
|
||||||
|
"group_member",
|
||||||
|
"group_owner",
|
||||||
|
)
|
||||||
|
|
||||||
def to_representation(self, instance):
|
def to_representation(self, instance):
|
||||||
representation = super().to_representation(instance)
|
representation = super().to_representation(instance)
|
||||||
representation['user_group'] = SimpleUserGroupSerializer(
|
representation["user_group"] = SimpleUserGroupSerializer(
|
||||||
instance.user_group, many=False).data
|
instance.user_group, many=False
|
||||||
|
).data
|
||||||
return representation
|
return representation
|
||||||
|
|
||||||
def update(self, instance, validated_data):
|
def update(self, instance, validated_data):
|
||||||
cache.delete(f'user:{instance.id}')
|
cache.delete(f"user:{instance.id}")
|
||||||
return super().update(instance, validated_data)
|
return super().update(instance, validated_data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,16 +60,18 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
|
||||||
email = serializers.EmailField(required=True)
|
email = serializers.EmailField(required=True)
|
||||||
username = serializers.CharField(required=True)
|
username = serializers.CharField(required=True)
|
||||||
password = serializers.CharField(
|
password = serializers.CharField(
|
||||||
write_only=True, style={'input_type': 'password', 'placeholder': 'Password'})
|
write_only=True, style={"input_type": "password", "placeholder": "Password"}
|
||||||
|
)
|
||||||
first_name = serializers.CharField(
|
first_name = serializers.CharField(
|
||||||
required=True, allow_blank=False, allow_null=False)
|
required=True, allow_blank=False, allow_null=False
|
||||||
|
)
|
||||||
last_name = serializers.CharField(
|
last_name = serializers.CharField(
|
||||||
required=True, allow_blank=False, allow_null=False)
|
required=True, allow_blank=False, allow_null=False
|
||||||
|
)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = CustomUser
|
model = CustomUser
|
||||||
fields = ['email', 'username', 'password',
|
fields = ["email", "username", "password", "first_name", "last_name"]
|
||||||
'first_name', 'last_name']
|
|
||||||
|
|
||||||
def validate(self, attrs):
|
def validate(self, attrs):
|
||||||
user_attrs = attrs.copy()
|
user_attrs = attrs.copy()
|
||||||
|
@ -69,14 +89,15 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
|
||||||
raise serializers.ValidationError({"password": errors})
|
raise serializers.ValidationError({"password": errors})
|
||||||
if self.Meta.model.objects.filter(username=attrs.get("username")).exists():
|
if self.Meta.model.objects.filter(username=attrs.get("username")).exists():
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
"A user with that username already exists.")
|
"A user with that username already exists."
|
||||||
|
)
|
||||||
return super().validate(attrs)
|
return super().validate(attrs)
|
||||||
|
|
||||||
def create(self, validated_data):
|
def create(self, validated_data):
|
||||||
user = self.Meta.model(**validated_data)
|
user = self.Meta.model(**validated_data)
|
||||||
user.username = validated_data['username']
|
user.username = validated_data["username"]
|
||||||
user.is_active = False
|
user.is_active = False
|
||||||
user.set_password(validated_data['password'])
|
user.set_password(validated_data["password"])
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
|
@ -12,38 +12,37 @@ import json
|
||||||
@receiver(post_migrate)
|
@receiver(post_migrate)
|
||||||
def create_users(sender, **kwargs):
|
def create_users(sender, **kwargs):
|
||||||
if sender.name == "accounts":
|
if sender.name == "accounts":
|
||||||
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
|
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
|
||||||
seed_data = json.loads(f.read())
|
seed_data = json.loads(f.read())
|
||||||
for user in seed_data['users']:
|
for user in seed_data["users"]:
|
||||||
USER = CustomUser.objects.filter(
|
USER = CustomUser.objects.filter(email=user["email"]).first()
|
||||||
email=user['email']).first()
|
|
||||||
if not USER:
|
if not USER:
|
||||||
if user['password'] == 'USE_REGULAR':
|
if user["password"] == "USE_REGULAR":
|
||||||
password = get_secret('SEED_DATA_PASSWORD')
|
password = get_secret("SEED_DATA_PASSWORD")
|
||||||
elif user['password'] == 'USE_ADMIN':
|
elif user["password"] == "USE_ADMIN":
|
||||||
password = get_secret('SEED_DATA_ADMIN_PASSWORD')
|
password = get_secret("SEED_DATA_ADMIN_PASSWORD")
|
||||||
else:
|
else:
|
||||||
password = user['password']
|
password = user["password"]
|
||||||
if (user['is_superuser'] == True):
|
if user["is_superuser"] == True:
|
||||||
# Admin users are created regardless of SEED_DATA value
|
# Admin users are created regardless of SEED_DATA value
|
||||||
USER = CustomUser.objects.create_superuser(
|
USER = CustomUser.objects.create_superuser(
|
||||||
username=user['username'],
|
username=user["username"],
|
||||||
email=user['email'],
|
email=user["email"],
|
||||||
password=password,
|
password=password,
|
||||||
)
|
)
|
||||||
print('Created Superuser:', user['email'])
|
print("Created Superuser:", user["email"])
|
||||||
else:
|
else:
|
||||||
# Only create non-admin users if SEED_DATA=True
|
# Only create non-admin users if SEED_DATA=True
|
||||||
if SEED_DATA:
|
if SEED_DATA:
|
||||||
USER = CustomUser.objects.create_user(
|
USER = CustomUser.objects.create_user(
|
||||||
username=user['email'],
|
username=user["email"],
|
||||||
email=user['email'],
|
email=user["email"],
|
||||||
password=password,
|
password=password,
|
||||||
)
|
)
|
||||||
print('Created User:', user['email'])
|
print("Created User:", user["email"])
|
||||||
|
|
||||||
USER.first_name = user['first_name']
|
USER.first_name = user["first_name"]
|
||||||
USER.last_name = user['last_name']
|
USER.last_name = user["last_name"]
|
||||||
USER.is_active = True
|
USER.is_active = True
|
||||||
USER.save()
|
USER.save()
|
||||||
|
|
||||||
|
@ -51,53 +50,57 @@ def create_users(sender, **kwargs):
|
||||||
@receiver(post_migrate)
|
@receiver(post_migrate)
|
||||||
def create_celery_beat_schedules(sender, **kwargs):
|
def create_celery_beat_schedules(sender, **kwargs):
|
||||||
if sender.name == "django_celery_beat":
|
if sender.name == "django_celery_beat":
|
||||||
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
|
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
|
||||||
seed_data = json.loads(f.read())
|
seed_data = json.loads(f.read())
|
||||||
# Creating Schedules
|
# Creating Schedules
|
||||||
for schedule in seed_data['schedules']:
|
for schedule in seed_data["schedules"]:
|
||||||
if schedule['type'] == 'crontab':
|
if schedule["type"] == "crontab":
|
||||||
# Check if Schedule already exists
|
# Check if Schedule already exists
|
||||||
SCHEDULE = CrontabSchedule.objects.filter(minute=schedule['minute'],
|
SCHEDULE = CrontabSchedule.objects.filter(
|
||||||
hour=schedule['hour'],
|
minute=schedule["minute"],
|
||||||
day_of_week=schedule['day_of_week'],
|
hour=schedule["hour"],
|
||||||
day_of_month=schedule['day_of_month'],
|
day_of_week=schedule["day_of_week"],
|
||||||
month_of_year=schedule['month_of_year'],
|
day_of_month=schedule["day_of_month"],
|
||||||
timezone=schedule['timezone']
|
month_of_year=schedule["month_of_year"],
|
||||||
).first()
|
timezone=schedule["timezone"],
|
||||||
|
).first()
|
||||||
# If it does not exist, create a new Schedule
|
# If it does not exist, create a new Schedule
|
||||||
if not SCHEDULE:
|
if not SCHEDULE:
|
||||||
SCHEDULE = CrontabSchedule.objects.create(
|
SCHEDULE = CrontabSchedule.objects.create(
|
||||||
minute=schedule['minute'],
|
minute=schedule["minute"],
|
||||||
hour=schedule['hour'],
|
hour=schedule["hour"],
|
||||||
day_of_week=schedule['day_of_week'],
|
day_of_week=schedule["day_of_week"],
|
||||||
day_of_month=schedule['day_of_month'],
|
day_of_month=schedule["day_of_month"],
|
||||||
month_of_year=schedule['month_of_year'],
|
month_of_year=schedule["month_of_year"],
|
||||||
timezone=schedule['timezone']
|
timezone=schedule["timezone"],
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f'Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}')
|
f"Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f'Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists')
|
f"Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists"
|
||||||
for task in seed_data['scheduled_tasks']:
|
)
|
||||||
TASK = PeriodicTask.objects.filter(name=task['name']).first()
|
for task in seed_data["scheduled_tasks"]:
|
||||||
|
TASK = PeriodicTask.objects.filter(name=task["name"]).first()
|
||||||
if not TASK:
|
if not TASK:
|
||||||
if task['schedule']['type'] == 'crontab':
|
if task["schedule"]["type"] == "crontab":
|
||||||
SCHEDULE = CrontabSchedule.objects.filter(minute=task['schedule']['minute'],
|
SCHEDULE = CrontabSchedule.objects.filter(
|
||||||
hour=task['schedule']['hour'],
|
minute=task["schedule"]["minute"],
|
||||||
day_of_week=task['schedule']['day_of_week'],
|
hour=task["schedule"]["hour"],
|
||||||
day_of_month=task['schedule']['day_of_month'],
|
day_of_week=task["schedule"]["day_of_week"],
|
||||||
month_of_year=task['schedule']['month_of_year'],
|
day_of_month=task["schedule"]["day_of_month"],
|
||||||
timezone=task['schedule']['timezone']
|
month_of_year=task["schedule"]["month_of_year"],
|
||||||
).first()
|
timezone=task["schedule"]["timezone"],
|
||||||
|
).first()
|
||||||
TASK = PeriodicTask.objects.create(
|
TASK = PeriodicTask.objects.create(
|
||||||
crontab=SCHEDULE,
|
crontab=SCHEDULE,
|
||||||
name=task['name'],
|
name=task["name"],
|
||||||
task=task['task'],
|
task=task["task"],
|
||||||
enabled=task['enabled']
|
enabled=task["enabled"],
|
||||||
)
|
)
|
||||||
print(f'Created Periodic Task: {TASK.name}')
|
print(f"Created Periodic Task: {TASK.name}")
|
||||||
else:
|
else:
|
||||||
raise Exception('Schedule for Periodic Task not found')
|
raise Exception("Schedule for Periodic Task not found")
|
||||||
else:
|
else:
|
||||||
print(f'Periodic Task: {TASK.name} already exists')
|
print(f"Periodic Task: {TASK.name} already exists")
|
||||||
|
|
|
@ -4,20 +4,26 @@ from celery import shared_task
|
||||||
@shared_task
|
@shared_task
|
||||||
def get_paying_users():
|
def get_paying_users():
|
||||||
from subscriptions.models import UserSubscription
|
from subscriptions.models import UserSubscription
|
||||||
|
|
||||||
# Get a list of user subscriptions
|
# Get a list of user subscriptions
|
||||||
active_subscriptions = UserSubscription.objects.filter(
|
active_subscriptions = UserSubscription.objects.filter(valid=True).distinct("user")
|
||||||
valid=True).distinct('user')
|
|
||||||
|
|
||||||
# Get paying users
|
# Get paying users
|
||||||
active_users = []
|
active_users = []
|
||||||
|
|
||||||
# Paying regular users
|
# Paying regular users
|
||||||
active_users += [
|
active_users += [
|
||||||
subscription.user.id for subscription in active_subscriptions if subscription.user is not None and subscription.user.user_group is None]
|
subscription.user.id
|
||||||
|
for subscription in active_subscriptions
|
||||||
|
if subscription.user is not None and subscription.user.user_group is None
|
||||||
|
]
|
||||||
|
|
||||||
# Paying users within groups
|
# Paying users within groups
|
||||||
active_users += [
|
active_users += [
|
||||||
subscription.user_group.members for subscription in active_subscriptions if subscription.user_group is not None and subscription.user is None]
|
subscription.user_group.members
|
||||||
|
for subscription in active_subscriptions
|
||||||
|
if subscription.user_group is not None and subscription.user is None
|
||||||
|
]
|
||||||
|
|
||||||
# Return paying users
|
# Return paying users
|
||||||
return active_users
|
return active_users
|
||||||
|
|
|
@ -3,10 +3,10 @@ from rest_framework.routers import DefaultRouter
|
||||||
from accounts import views
|
from accounts import views
|
||||||
|
|
||||||
router = DefaultRouter()
|
router = DefaultRouter()
|
||||||
router.register(r'users', views.CustomUserViewSet, basename='users')
|
router.register(r"users", views.CustomUserViewSet, basename="users")
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('', include(router.urls)),
|
path("", include(router.urls)),
|
||||||
path('', include('djoser.urls')),
|
path("", include("djoser.urls")),
|
||||||
path('', include('djoser.urls.jwt')),
|
path("", include("djoser.urls.jwt")),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
import re
|
import re
|
||||||
|
@ -6,9 +5,10 @@ import re
|
||||||
|
|
||||||
class UppercaseValidator(object):
|
class UppercaseValidator(object):
|
||||||
def validate(self, password, user=None):
|
def validate(self, password, user=None):
|
||||||
if not re.findall('[A-Z]', password):
|
if not re.findall("[A-Z]", password):
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
_("The password must contain at least 1 uppercase letter (A-Z)."))
|
_("The password must contain at least 1 uppercase letter (A-Z).")
|
||||||
|
)
|
||||||
|
|
||||||
def get_help_text(self):
|
def get_help_text(self):
|
||||||
return _("Your password must contain at least 1 uppercase letter (A-Z).")
|
return _("Your password must contain at least 1 uppercase letter (A-Z).")
|
||||||
|
@ -16,9 +16,10 @@ class UppercaseValidator(object):
|
||||||
|
|
||||||
class LowercaseValidator(object):
|
class LowercaseValidator(object):
|
||||||
def validate(self, password, user=None):
|
def validate(self, password, user=None):
|
||||||
if not re.findall('[a-z]', password):
|
if not re.findall("[a-z]", password):
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
_("The password must contain at least 1 lowercase letter (a-z)."))
|
_("The password must contain at least 1 lowercase letter (a-z).")
|
||||||
|
)
|
||||||
|
|
||||||
def get_help_text(self):
|
def get_help_text(self):
|
||||||
return _("Your password must contain at least 1 lowercase letter (a-z).")
|
return _("Your password must contain at least 1 lowercase letter (a-z).")
|
||||||
|
@ -26,19 +27,25 @@ class LowercaseValidator(object):
|
||||||
|
|
||||||
class SpecialCharacterValidator(object):
|
class SpecialCharacterValidator(object):
|
||||||
def validate(self, password, user=None):
|
def validate(self, password, user=None):
|
||||||
if not re.findall('[@#$%^&*()_+/\<>;:!?]', password):
|
if not re.findall("[@#$%^&*()_+/\<>;:!?]", password):
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
_("The password must contain at least 1 special character (@, #, $, etc.)."))
|
_(
|
||||||
|
"The password must contain at least 1 special character (@, #, $, etc.)."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def get_help_text(self):
|
def get_help_text(self):
|
||||||
return _("Your password must contain at least 1 special character (@, #, $, etc.).")
|
return _(
|
||||||
|
"Your password must contain at least 1 special character (@, #, $, etc.)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NumberValidator(object):
|
class NumberValidator(object):
|
||||||
def validate(self, password, user=None):
|
def validate(self, password, user=None):
|
||||||
if not any(char.isdigit() for char in password):
|
if not any(char.isdigit() for char in password):
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
_("The password must contain at least one numerical digit (0-9)."))
|
_("The password must contain at least one numerical digit (0-9).")
|
||||||
|
)
|
||||||
|
|
||||||
def get_help_text(self):
|
def get_help_text(self):
|
||||||
return _("Your password must contain at least numerical digit (0-9).")
|
return _("Your password must contain at least numerical digit (0-9).")
|
||||||
|
|
|
@ -22,28 +22,27 @@ class CustomUserViewSet(DjoserUserViewSet):
|
||||||
user = self.request.user
|
user = self.request.user
|
||||||
# If user is admin, show all active users
|
# If user is admin, show all active users
|
||||||
if user.is_superuser:
|
if user.is_superuser:
|
||||||
key = 'users'
|
key = "users"
|
||||||
# Get cache
|
# Get cache
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
# Set cache if stale or does not exist
|
# Set cache if stale or does not exist
|
||||||
if not queryset:
|
if not queryset:
|
||||||
queryset = CustomUser.objects.filter(is_active=True)
|
queryset = CustomUser.objects.filter(is_active=True)
|
||||||
cache.set(key, queryset, 60*60)
|
cache.set(key, queryset, 60 * 60)
|
||||||
return queryset
|
return queryset
|
||||||
elif not user.user_group:
|
elif not user.user_group:
|
||||||
key = f'user:{user.id}'
|
key = f"user:{user.id}"
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
if not queryset:
|
if not queryset:
|
||||||
queryset = CustomUser.objects.filter(is_active=True)
|
queryset = CustomUser.objects.filter(is_active=True)
|
||||||
cache.set(key, queryset, 60*60)
|
cache.set(key, queryset, 60 * 60)
|
||||||
return queryset
|
return queryset
|
||||||
elif user.user_group:
|
elif user.user_group:
|
||||||
key = f'usergroup_users:{user.user_group.id}'
|
key = f"usergroup_users:{user.user_group.id}"
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
if not queryset:
|
if not queryset:
|
||||||
queryset = CustomUser.objects.filter(
|
queryset = CustomUser.objects.filter(user_group=user.user_group)
|
||||||
user_group=user.user_group)
|
cache.set(key, queryset, 60 * 60)
|
||||||
cache.set(key, queryset, 60*60)
|
|
||||||
return queryset
|
return queryset
|
||||||
else:
|
else:
|
||||||
return CustomUser.objects.none()
|
return CustomUser.objects.none()
|
||||||
|
@ -52,10 +51,10 @@ class CustomUserViewSet(DjoserUserViewSet):
|
||||||
user = self.request.user
|
user = self.request.user
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'users')
|
cache.delete(f"users")
|
||||||
cache.delete(f'user:{user.id}')
|
cache.delete(f"user:{user.id}")
|
||||||
if user.user_group:
|
if user.user_group:
|
||||||
cache.delete(f'usergroup_users:{user.user_group.id}')
|
cache.delete(f"usergroup_users:{user.user_group.id}")
|
||||||
|
|
||||||
super().perform_update(serializer, *args, **kwargs)
|
super().perform_update(serializer, *args, **kwargs)
|
||||||
user = serializer.instance
|
user = serializer.instance
|
||||||
|
@ -84,16 +83,18 @@ class CustomUserViewSet(DjoserUserViewSet):
|
||||||
settings.EMAIL.confirmation(self.request, context).send(to)
|
settings.EMAIL.confirmation(self.request, context).send(to)
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete('users')
|
cache.delete("users")
|
||||||
cache.delete(f'user:{user.id}')
|
cache.delete(f"user:{user.id}")
|
||||||
if user.user_group:
|
if user.user_group:
|
||||||
cache.delete(f'usergroup_users:{user.user_group.id}')
|
cache.delete(f"usergroup_users:{user.user_group.id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('Warning: Unable to send email')
|
print("Warning: Unable to send email")
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
@action(methods=['post'], detail=False, url_path='activation', url_name='activation')
|
@action(
|
||||||
|
methods=["post"], detail=False, url_path="activation", url_name="activation"
|
||||||
|
)
|
||||||
def activation(self, request, *args, **kwargs):
|
def activation(self, request, *args, **kwargs):
|
||||||
serializer = self.get_serializer(data=request.data)
|
serializer = self.get_serializer(data=request.data)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
|
@ -103,16 +104,16 @@ class CustomUserViewSet(DjoserUserViewSet):
|
||||||
|
|
||||||
# Construct a response with user's first name, last name, and email
|
# Construct a response with user's first name, last name, and email
|
||||||
user_data = {
|
user_data = {
|
||||||
'first_name': user.first_name,
|
"first_name": user.first_name,
|
||||||
'last_name': user.last_name,
|
"last_name": user.last_name,
|
||||||
'email': user.email,
|
"email": user.email,
|
||||||
'username': user.username
|
"username": user.username,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete('users')
|
cache.delete("users")
|
||||||
cache.delete(f'user:{user.id}')
|
cache.delete(f"user:{user.id}")
|
||||||
if user.user_group:
|
if user.user_group:
|
||||||
cache.delete(f'usergroup_users:{user.user_group.id}')
|
cache.delete(f"usergroup_users:{user.user_group.id}")
|
||||||
|
|
||||||
return Response(user_data, status=status.HTTP_200_OK)
|
return Response(user_data, status=status.HTTP_200_OK)
|
||||||
|
|
|
@ -1,27 +1,31 @@
|
||||||
from django.conf.urls.static import static
|
from django.conf.urls.static import static
|
||||||
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
|
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
|
||||||
from django.urls import path, include
|
from django.urls import path, include
|
||||||
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
|
from drf_spectacular.views import (
|
||||||
|
SpectacularAPIView,
|
||||||
|
SpectacularRedocView,
|
||||||
|
SpectacularSwaggerView,
|
||||||
|
)
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from config.settings import DEBUG, SERVE_MEDIA, MEDIA_ROOT
|
from config.settings import DEBUG, SERVE_MEDIA, MEDIA_ROOT
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('accounts/', include('accounts.urls')),
|
path("accounts/", include("accounts.urls")),
|
||||||
path('subscriptions/', include('subscriptions.urls')),
|
path("subscriptions/", include("subscriptions.urls")),
|
||||||
path('notifications/', include('notifications.urls')),
|
path("notifications/", include("notifications.urls")),
|
||||||
path('billing/', include('billing.urls')),
|
path("billing/", include("billing.urls")),
|
||||||
path('stripe/', include('payments.urls')),
|
path("stripe/", include("payments.urls")),
|
||||||
path('admin/', admin.site.urls),
|
path("admin/", admin.site.urls),
|
||||||
path('schema/', SpectacularAPIView.as_view(), name='schema'),
|
path("schema/", SpectacularAPIView.as_view(), name="schema"),
|
||||||
path('swagger/',
|
path(
|
||||||
SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'),
|
"swagger/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"
|
||||||
path('redoc/',
|
),
|
||||||
SpectacularRedocView.as_view(url_name='schema'), name='redoc'),
|
path("redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# URLs for local development
|
# URLs for local development
|
||||||
if DEBUG and SERVE_MEDIA:
|
if DEBUG and SERVE_MEDIA:
|
||||||
urlpatterns += staticfiles_urlpatterns()
|
urlpatterns += staticfiles_urlpatterns()
|
||||||
urlpatterns += static(
|
urlpatterns += static("media/", document_root=MEDIA_ROOT)
|
||||||
'media/', document_root=MEDIA_ROOT)
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
urlpatterns += [path('silk/', include('silk.urls', namespace='silk'))]
|
urlpatterns += [path("silk/", include("silk.urls", namespace="silk"))]
|
||||||
|
|
|
@ -2,6 +2,5 @@ from django.urls import path
|
||||||
from billing import views
|
from billing import views
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('',
|
path("", views.BillingHistoryView.as_view()),
|
||||||
views.BillingHistoryView.as_view()),
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -24,7 +24,7 @@ class BillingHistoryView(APIView):
|
||||||
email = requesting_user.email
|
email = requesting_user.email
|
||||||
|
|
||||||
# Check cache
|
# Check cache
|
||||||
key = f'billing_user:{requesting_user.id}'
|
key = f"billing_user:{requesting_user.id}"
|
||||||
billing_history = cache.get(key)
|
billing_history = cache.get(key)
|
||||||
|
|
||||||
if not billing_history:
|
if not billing_history:
|
||||||
|
@ -39,23 +39,25 @@ class BillingHistoryView(APIView):
|
||||||
|
|
||||||
if len(customers.data) > 0:
|
if len(customers.data) > 0:
|
||||||
# Retrieve the customer's charges (billing history)
|
# Retrieve the customer's charges (billing history)
|
||||||
charges = stripe.Charge.list(
|
charges = stripe.Charge.list(limit=10, customer=customer.id)
|
||||||
limit=10, customer=customer.id)
|
|
||||||
|
|
||||||
# Prepare the response
|
# Prepare the response
|
||||||
billing_history = [
|
billing_history = [
|
||||||
{
|
{
|
||||||
'email': charge['billing_details']['email'],
|
"email": charge["billing_details"]["email"],
|
||||||
'amount_charged': int(charge['amount']/100),
|
"amount_charged": int(charge["amount"] / 100),
|
||||||
'paid': charge['paid'],
|
"paid": charge["paid"],
|
||||||
'refunded': int(charge['amount_refunded']/100) > 0,
|
"refunded": int(charge["amount_refunded"] / 100) > 0,
|
||||||
'amount_refunded': int(charge['amount_refunded']/100),
|
"amount_refunded": int(charge["amount_refunded"] / 100),
|
||||||
'last_4': charge['payment_method_details']['card']['last4'],
|
"last_4": charge["payment_method_details"]["card"]["last4"],
|
||||||
'receipt_link': charge['receipt_url'],
|
"receipt_link": charge["receipt_url"],
|
||||||
'timestamp': datetime.fromtimestamp(charge['created']).strftime("%m-%d-%Y %I:%M %p"),
|
"timestamp": datetime.fromtimestamp(
|
||||||
} for charge in charges.auto_paging_iter()
|
charge["created"]
|
||||||
|
).strftime("%m-%d-%Y %I:%M %p"),
|
||||||
|
}
|
||||||
|
for charge in charges.auto_paging_iter()
|
||||||
]
|
]
|
||||||
|
|
||||||
cache.set(key, billing_history, 60*60)
|
cache.set(key, billing_history, 60 * 60)
|
||||||
|
|
||||||
return Response(billing_history, status=status.HTTP_200_OK)
|
return Response(billing_history, status=status.HTTP_200_OK)
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
from .celery import app as celery_app
|
from .celery import app as celery_app
|
||||||
|
|
||||||
__all__ = ('celery_app',)
|
__all__ = ("celery_app",)
|
||||||
|
|
|
@ -11,6 +11,6 @@ import os
|
||||||
|
|
||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
|
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||||
|
|
||||||
application = get_asgi_application()
|
application = get_asgi_application()
|
||||||
|
|
|
@ -3,15 +3,15 @@ import os
|
||||||
|
|
||||||
|
|
||||||
# Set the default Django settings module for the 'celery' program.
|
# Set the default Django settings module for the 'celery' program.
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||||
|
|
||||||
app = Celery('config')
|
app = Celery("config")
|
||||||
|
|
||||||
# Using a string here means the worker doesn't have to serialize
|
# Using a string here means the worker doesn't have to serialize
|
||||||
# the configuration object to child processes.
|
# the configuration object to child processes.
|
||||||
# - namespace='CELERY' means all celery-related configuration keys
|
# - namespace='CELERY' means all celery-related configuration keys
|
||||||
# should have a `CELERY_` prefix.
|
# should have a `CELERY_` prefix.
|
||||||
app.config_from_object('django.conf:settings', namespace='CELERY')
|
app.config_from_object("django.conf:settings", namespace="CELERY")
|
||||||
|
|
||||||
# Load task modules from all registered Django apps.
|
# Load task modules from all registered Django apps.
|
||||||
app.autodiscover_tasks()
|
app.autodiscover_tasks()
|
||||||
|
|
|
@ -10,9 +10,9 @@ BASE_DIR = Path(__file__).resolve().parent.parent
|
||||||
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
|
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
|
||||||
|
|
||||||
# If you're hosting this with a secret provider, have this set to True
|
# If you're hosting this with a secret provider, have this set to True
|
||||||
USE_VAULT = bool(os.getenv('USE_VAULT', False) == 'True')
|
USE_VAULT = bool(os.getenv("USE_VAULT", False) == "True")
|
||||||
# Have this set to True to serve media and static contents directly via Django
|
# Have this set to True to serve media and static contents directly via Django
|
||||||
SERVE_MEDIA = bool(os.getenv('SERVE_MEDIA', False) == 'True')
|
SERVE_MEDIA = bool(os.getenv("SERVE_MEDIA", False) == "True")
|
||||||
|
|
||||||
load_dotenv(find_dotenv())
|
load_dotenv(find_dotenv())
|
||||||
|
|
||||||
|
@ -35,98 +35,97 @@ def get_secret(secret_name):
|
||||||
|
|
||||||
|
|
||||||
# URL Prefixes
|
# URL Prefixes
|
||||||
URL_SCHEME = 'https' if (get_secret('USE_HTTPS') == 'True') else 'http'
|
URL_SCHEME = "https" if (get_secret("USE_HTTPS") == "True") else "http"
|
||||||
# Backend
|
# Backend
|
||||||
BACKEND_ADDRESS = get_secret('BACKEND_ADDRESS')
|
BACKEND_ADDRESS = get_secret("BACKEND_ADDRESS")
|
||||||
BACKEND_PORT = get_secret('BACKEND_PORT')
|
BACKEND_PORT = get_secret("BACKEND_PORT")
|
||||||
# Frontend
|
# Frontend
|
||||||
FRONTEND_ADDRESS = get_secret('FRONTEND_ADDRESS')
|
FRONTEND_ADDRESS = get_secret("FRONTEND_ADDRESS")
|
||||||
FRONTEND_PORT = get_secret('FRONTEND_PORT')
|
FRONTEND_PORT = get_secret("FRONTEND_PORT")
|
||||||
|
|
||||||
ALLOWED_HOSTS = ['*']
|
ALLOWED_HOSTS = ["*"]
|
||||||
CSRF_TRUSTED_ORIGINS = [
|
CSRF_TRUSTED_ORIGINS = [
|
||||||
# Frontend
|
# Frontend
|
||||||
f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}',
|
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}",
|
||||||
f'{URL_SCHEME}://{FRONTEND_ADDRESS}', # For external domains
|
f"{URL_SCHEME}://{FRONTEND_ADDRESS}", # For external domains
|
||||||
# Backend
|
# Backend
|
||||||
f'{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}',
|
f"{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}",
|
||||||
f'{URL_SCHEME}://{BACKEND_ADDRESS}' # For external domains
|
f"{URL_SCHEME}://{BACKEND_ADDRESS}", # For external domains
|
||||||
# You can also set up https://*.name.xyz for wildcards here
|
# You can also set up https://*.name.xyz for wildcards here
|
||||||
]
|
]
|
||||||
|
|
||||||
# SECURITY WARNING: don't run with debug turned on in production!
|
# SECURITY WARNING: don't run with debug turned on in production!
|
||||||
DEBUG = (get_secret('BACKEND_DEBUG') == 'True')
|
DEBUG = get_secret("BACKEND_DEBUG") == "True"
|
||||||
# Determines whether or not to insert test data within tables
|
# Determines whether or not to insert test data within tables
|
||||||
SEED_DATA = (get_secret('SEED_DATA') == 'True')
|
SEED_DATA = get_secret("SEED_DATA") == "True"
|
||||||
# SECURITY WARNING: keep the secret key used in production secret!
|
# SECURITY WARNING: keep the secret key used in production secret!
|
||||||
SECRET_KEY = get_secret('SECRET_KEY')
|
SECRET_KEY = get_secret("SECRET_KEY")
|
||||||
# Selenium Config
|
# Selenium Config
|
||||||
# Initiate CAPTCHA solver in test mode
|
# Initiate CAPTCHA solver in test mode
|
||||||
CAPTCHA_TESTING = (get_secret('CAPTCHA_TESTING') == 'True')
|
CAPTCHA_TESTING = get_secret("CAPTCHA_TESTING") == "True"
|
||||||
# If using Selenium and/or the provided CAPTCHA solver, determines whether or not to use proxies
|
# If using Selenium and/or the provided CAPTCHA solver, determines whether or not to use proxies
|
||||||
USE_PROXY = (get_secret('USE_PROXY') == 'True')
|
USE_PROXY = get_secret("USE_PROXY") == "True"
|
||||||
|
|
||||||
# Stripe (For payments)
|
# Stripe (For payments)
|
||||||
STRIPE_SECRET_KEY = get_secret(
|
STRIPE_SECRET_KEY = get_secret("STRIPE_SECRET_KEY")
|
||||||
"STRIPE_SECRET_KEY")
|
STRIPE_SECRET_WEBHOOK = get_secret("STRIPE_SECRET_WEBHOOK")
|
||||||
STRIPE_SECRET_WEBHOOK = get_secret('STRIPE_SECRET_WEBHOOK')
|
STRIPE_CHECKOUT_URL = f""
|
||||||
STRIPE_CHECKOUT_URL = f''
|
|
||||||
|
|
||||||
# Email credentials
|
# Email credentials
|
||||||
EMAIL_HOST = get_secret('EMAIL_HOST')
|
EMAIL_HOST = get_secret("EMAIL_HOST")
|
||||||
EMAIL_HOST_USER = get_secret('EMAIL_HOST_USER')
|
EMAIL_HOST_USER = get_secret("EMAIL_HOST_USER")
|
||||||
EMAIL_HOST_PASSWORD = get_secret('EMAIL_HOST_PASSWORD')
|
EMAIL_HOST_PASSWORD = get_secret("EMAIL_HOST_PASSWORD")
|
||||||
EMAIL_PORT = get_secret('EMAIL_PORT')
|
EMAIL_PORT = get_secret("EMAIL_PORT")
|
||||||
EMAIL_USE_TLS = (get_secret('EMAIL_USE_TLS') == 'True')
|
EMAIL_USE_TLS = get_secret("EMAIL_USE_TLS") == "True"
|
||||||
EMAIL_ADDRESS = (get_secret('EMAIL_ADDRESS') == 'True')
|
EMAIL_ADDRESS = get_secret("EMAIL_ADDRESS") == "True"
|
||||||
|
|
||||||
# Application definition
|
# Application definition
|
||||||
|
|
||||||
INSTALLED_APPS = [
|
INSTALLED_APPS = [
|
||||||
'config',
|
"config",
|
||||||
'unfold',
|
"unfold",
|
||||||
'unfold.contrib.filters',
|
"unfold.contrib.filters",
|
||||||
'unfold.contrib.simple_history',
|
"unfold.contrib.simple_history",
|
||||||
'django.contrib.admin',
|
"django.contrib.admin",
|
||||||
'django.contrib.auth',
|
"django.contrib.auth",
|
||||||
'django.contrib.contenttypes',
|
"django.contrib.contenttypes",
|
||||||
'django.contrib.sessions',
|
"django.contrib.sessions",
|
||||||
'django.contrib.messages',
|
"django.contrib.messages",
|
||||||
'django.contrib.staticfiles',
|
"django.contrib.staticfiles",
|
||||||
'storages',
|
"storages",
|
||||||
'django_extensions',
|
"django_extensions",
|
||||||
'rest_framework',
|
"rest_framework",
|
||||||
'rest_framework_simplejwt',
|
"rest_framework_simplejwt",
|
||||||
'django_celery_results',
|
"django_celery_results",
|
||||||
'django_celery_beat',
|
"django_celery_beat",
|
||||||
'simple_history',
|
"simple_history",
|
||||||
'djoser',
|
"djoser",
|
||||||
'corsheaders',
|
"corsheaders",
|
||||||
'drf_spectacular',
|
"drf_spectacular",
|
||||||
'drf_spectacular_sidecar',
|
"drf_spectacular_sidecar",
|
||||||
'webdriver',
|
"webdriver",
|
||||||
'accounts',
|
"accounts",
|
||||||
'user_groups',
|
"user_groups",
|
||||||
'subscriptions',
|
"subscriptions",
|
||||||
'payments',
|
"payments",
|
||||||
'billing',
|
"billing",
|
||||||
'emails',
|
"emails",
|
||||||
'notifications',
|
"notifications",
|
||||||
'search_results'
|
"search_results",
|
||||||
]
|
]
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
INSTALLED_APPS += ['silk']
|
INSTALLED_APPS += ["silk"]
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
'django.middleware.security.SecurityMiddleware',
|
"django.middleware.security.SecurityMiddleware",
|
||||||
"silk.middleware.SilkyMiddleware",
|
"silk.middleware.SilkyMiddleware",
|
||||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
"corsheaders.middleware.CorsMiddleware",
|
"corsheaders.middleware.CorsMiddleware",
|
||||||
'django.middleware.common.CommonMiddleware',
|
"django.middleware.common.CommonMiddleware",
|
||||||
'django.middleware.csrf.CsrfViewMiddleware',
|
"django.middleware.csrf.CsrfViewMiddleware",
|
||||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
'django.contrib.messages.middleware.MessageMiddleware',
|
"django.contrib.messages.middleware.MessageMiddleware",
|
||||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||||
]
|
]
|
||||||
DJANGO_LOG_LEVEL = "DEBUG"
|
DJANGO_LOG_LEVEL = "DEBUG"
|
||||||
# Enables VS Code debugger to break on raised exceptions
|
# Enables VS Code debugger to break on raised exceptions
|
||||||
|
@ -146,111 +145,101 @@ if DEBUG:
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
'django.middleware.security.SecurityMiddleware',
|
"django.middleware.security.SecurityMiddleware",
|
||||||
"whitenoise.middleware.WhiteNoiseMiddleware",
|
"whitenoise.middleware.WhiteNoiseMiddleware",
|
||||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
"corsheaders.middleware.CorsMiddleware",
|
"corsheaders.middleware.CorsMiddleware",
|
||||||
'django.middleware.common.CommonMiddleware',
|
"django.middleware.common.CommonMiddleware",
|
||||||
'django.middleware.csrf.CsrfViewMiddleware',
|
"django.middleware.csrf.CsrfViewMiddleware",
|
||||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
'django.contrib.messages.middleware.MessageMiddleware',
|
"django.contrib.messages.middleware.MessageMiddleware",
|
||||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Static files (CSS, JavaScript, Images)
|
# Static files (CSS, JavaScript, Images)
|
||||||
# https://docs.djangoproject.com/en/4.2/howto/static-files/
|
# https://docs.djangoproject.com/en/4.2/howto/static-files/
|
||||||
|
|
||||||
ROOT_URLCONF = 'config.urls'
|
ROOT_URLCONF = "config.urls"
|
||||||
if SERVE_MEDIA:
|
if SERVE_MEDIA:
|
||||||
# Cloud Storage Settings
|
# Cloud Storage Settings
|
||||||
# This is assuming you use the same bucket for media and static containers
|
# This is assuming you use the same bucket for media and static containers
|
||||||
CLOUD_BUCKET = get_secret('CLOUD_BUCKET')
|
CLOUD_BUCKET = get_secret("CLOUD_BUCKET")
|
||||||
MEDIA_CONTAINER = get_secret('MEDIA_CONTAINER')
|
MEDIA_CONTAINER = get_secret("MEDIA_CONTAINER")
|
||||||
STATIC_CONTAINER = get_secret('STATIC_CONTAINER')
|
STATIC_CONTAINER = get_secret("STATIC_CONTAINER")
|
||||||
|
|
||||||
MEDIA_URL = f'https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/'
|
MEDIA_URL = f"https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/"
|
||||||
MEDIA_ROOT = f'https://{CLOUD_BUCKET}/'
|
MEDIA_ROOT = f"https://{CLOUD_BUCKET}/"
|
||||||
|
|
||||||
STATIC_URL = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/'
|
STATIC_URL = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
|
||||||
STATIC_ROOT = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/'
|
STATIC_ROOT = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
|
||||||
|
|
||||||
# Consult django-storages documentation when filling in these values. This will vary depending on your cloud service provider
|
# Consult django-storages documentation when filling in these values. This will vary depending on your cloud service provider
|
||||||
STORAGES = {
|
STORAGES = {
|
||||||
'default': {
|
"default": {
|
||||||
# TODO: Set this up here if you're using cloud storage
|
# TODO: Set this up here if you're using cloud storage
|
||||||
'BACKEND': None,
|
"BACKEND": None,
|
||||||
'OPTIONS': {
|
"OPTIONS": {
|
||||||
# Optional parameters
|
# Optional parameters
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
'staticfiles': {
|
"staticfiles": {
|
||||||
# TODO: Set this up here if you're using cloud storage
|
# TODO: Set this up here if you're using cloud storage
|
||||||
'BACKEND': None,
|
"BACKEND": None,
|
||||||
'OPTIONS': {
|
"OPTIONS": {
|
||||||
# Optional parameters
|
# Optional parameters
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
STATIC_URL = 'static/'
|
STATIC_URL = "static/"
|
||||||
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
|
STATIC_ROOT = os.path.join(BASE_DIR, "static")
|
||||||
STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage"
|
STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage"
|
||||||
MEDIA_URL = 'api/v1/media/'
|
MEDIA_URL = "api/v1/media/"
|
||||||
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
|
MEDIA_ROOT = os.path.join(BASE_DIR, "media")
|
||||||
ROOT_URLCONF = 'config.urls'
|
ROOT_URLCONF = "config.urls"
|
||||||
|
|
||||||
TEMPLATES = [
|
TEMPLATES = [
|
||||||
{
|
{
|
||||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||||
'DIRS': [
|
"DIRS": [
|
||||||
BASE_DIR / 'emails/templates/',
|
BASE_DIR / "emails/templates/",
|
||||||
],
|
],
|
||||||
'APP_DIRS': True,
|
"APP_DIRS": True,
|
||||||
'OPTIONS': {
|
"OPTIONS": {
|
||||||
'context_processors': [
|
"context_processors": [
|
||||||
'django.template.context_processors.debug',
|
"django.template.context_processors.debug",
|
||||||
'django.template.context_processors.request',
|
"django.template.context_processors.request",
|
||||||
'django.contrib.auth.context_processors.auth',
|
"django.contrib.auth.context_processors.auth",
|
||||||
'django.contrib.messages.context_processors.messages',
|
"django.contrib.messages.context_processors.messages",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
REST_FRAMEWORK = {
|
REST_FRAMEWORK = {
|
||||||
'DEFAULT_AUTHENTICATION_CLASSES': (
|
"DEFAULT_AUTHENTICATION_CLASSES": (
|
||||||
'rest_framework_simplejwt.authentication.JWTAuthentication',
|
"rest_framework_simplejwt.authentication.JWTAuthentication",
|
||||||
),
|
),
|
||||||
'DEFAULT_THROTTLE_CLASSES': [
|
"DEFAULT_THROTTLE_CLASSES": [
|
||||||
|
"rest_framework.throttling.AnonRateThrottle",
|
||||||
'rest_framework.throttling.AnonRateThrottle',
|
"rest_framework.throttling.UserRateThrottle",
|
||||||
|
|
||||||
'rest_framework.throttling.UserRateThrottle'
|
|
||||||
|
|
||||||
],
|
],
|
||||||
|
"DEFAULT_THROTTLE_RATES": {"anon": "360/min", "user": "1440/min"},
|
||||||
'DEFAULT_THROTTLE_RATES': {
|
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||||
|
|
||||||
'anon': '360/min',
|
|
||||||
|
|
||||||
'user': '1440/min'
|
|
||||||
|
|
||||||
},
|
|
||||||
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# DRF-Spectacular
|
# DRF-Spectacular
|
||||||
SPECTACULAR_SETTINGS = {
|
SPECTACULAR_SETTINGS = {
|
||||||
'TITLE': 'DRF-Template',
|
"TITLE": "DRF-Template",
|
||||||
'DESCRIPTION': 'A Template Project by Keannu Bernasol',
|
"DESCRIPTION": "A Template Project by Keannu Bernasol",
|
||||||
'VERSION': '1.0.0',
|
"VERSION": "1.0.0",
|
||||||
'SERVE_INCLUDE_SCHEMA': False,
|
"SERVE_INCLUDE_SCHEMA": False,
|
||||||
'SWAGGER_UI_DIST': 'SIDECAR',
|
"SWAGGER_UI_DIST": "SIDECAR",
|
||||||
'SWAGGER_UI_FAVICON_HREF': 'SIDECAR',
|
"SWAGGER_UI_FAVICON_HREF": "SIDECAR",
|
||||||
'REDOC_DIST': 'SIDECAR',
|
"REDOC_DIST": "SIDECAR",
|
||||||
}
|
}
|
||||||
|
|
||||||
WSGI_APPLICATION = 'config.wsgi.application'
|
WSGI_APPLICATION = "config.wsgi.application"
|
||||||
|
|
||||||
# If you're using an external connection bouncer (eg. PgBouncer), server side cursors must be disabled to avoid any issues
|
# If you're using an external connection bouncer (eg. PgBouncer), server side cursors must be disabled to avoid any issues
|
||||||
USE_BOUNCER = get_secret("USE_BOUNCER")
|
USE_BOUNCER = get_secret("USE_BOUNCER")
|
||||||
|
@ -266,15 +255,13 @@ else:
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django.db.backends.postgresql",
|
"ENGINE": "django.db.backends.postgresql",
|
||||||
'DISABLE_SERVER_SIDE_CURSORS': DISABLE_SERVER_SIDE_CURSORS,
|
"DISABLE_SERVER_SIDE_CURSORS": DISABLE_SERVER_SIDE_CURSORS,
|
||||||
"NAME": get_secret("DB_DATABASE"),
|
"NAME": get_secret("DB_DATABASE"),
|
||||||
"USER": get_secret("DB_USERNAME"),
|
"USER": get_secret("DB_USERNAME"),
|
||||||
"PASSWORD": get_secret("DB_PASSWORD"),
|
"PASSWORD": get_secret("DB_PASSWORD"),
|
||||||
"HOST": DB_HOST,
|
"HOST": DB_HOST,
|
||||||
"PORT": DB_PORT,
|
"PORT": DB_PORT,
|
||||||
"OPTIONS": {
|
"OPTIONS": {"sslmode": get_secret("DB_SSL_MODE")},
|
||||||
"sslmode": get_secret("DB_SSL_MODE")
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Django Cache
|
# Django Cache
|
||||||
|
@ -284,34 +271,34 @@ CACHES = {
|
||||||
"LOCATION": f"redis://{get_secret('REDIS_HOST')}:{get_secret('REDIS_PORT')}/2",
|
"LOCATION": f"redis://{get_secret('REDIS_HOST')}:{get_secret('REDIS_PORT')}/2",
|
||||||
"OPTIONS": {
|
"OPTIONS": {
|
||||||
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AUTH_USER_MODEL = 'accounts.CustomUser'
|
AUTH_USER_MODEL = "accounts.CustomUser"
|
||||||
|
|
||||||
DJOSER = {
|
DJOSER = {
|
||||||
'SEND_ACTIVATION_EMAIL': True,
|
"SEND_ACTIVATION_EMAIL": True,
|
||||||
'SEND_CONFIRMATION_EMAIL': True,
|
"SEND_CONFIRMATION_EMAIL": True,
|
||||||
'PASSWORD_RESET_CONFIRM_URL': 'reset_password_confirm/{uid}/{token}',
|
"PASSWORD_RESET_CONFIRM_URL": "reset_password_confirm/{uid}/{token}",
|
||||||
'ACTIVATION_URL': 'activation/{uid}/{token}',
|
"ACTIVATION_URL": "activation/{uid}/{token}",
|
||||||
'USER_AUTHENTICATION_RULES': ['djoser.authentication.TokenAuthenticationRule'],
|
"USER_AUTHENTICATION_RULES": ["djoser.authentication.TokenAuthenticationRule"],
|
||||||
'EMAIL': {
|
"EMAIL": {
|
||||||
'activation': 'emails.templates.ActivationEmail',
|
"activation": "emails.templates.ActivationEmail",
|
||||||
'password_reset': 'emails.templates.PasswordResetEmail'
|
"password_reset": "emails.templates.PasswordResetEmail",
|
||||||
},
|
},
|
||||||
'SERIALIZERS': {
|
"SERIALIZERS": {
|
||||||
'user': 'accounts.serializers.CustomUserSerializer',
|
"user": "accounts.serializers.CustomUserSerializer",
|
||||||
'current_user': 'accounts.serializers.CustomUserSerializer',
|
"current_user": "accounts.serializers.CustomUserSerializer",
|
||||||
'user_create': 'accounts.serializers.UserRegistrationSerializer',
|
"user_create": "accounts.serializers.UserRegistrationSerializer",
|
||||||
},
|
},
|
||||||
'PERMISSIONS': {
|
"PERMISSIONS": {
|
||||||
# Disable some unneeded endpoints by setting them to admin only
|
# Disable some unneeded endpoints by setting them to admin only
|
||||||
'username_reset': ['rest_framework.permissions.IsAdminUser'],
|
"username_reset": ["rest_framework.permissions.IsAdminUser"],
|
||||||
'username_reset_confirm': ['rest_framework.permissions.IsAdminUser'],
|
"username_reset_confirm": ["rest_framework.permissions.IsAdminUser"],
|
||||||
'set_username': ['rest_framework.permissions.IsAdminUser'],
|
"set_username": ["rest_framework.permissions.IsAdminUser"],
|
||||||
'set_password': ['rest_framework.permissions.IsAdminUser'],
|
"set_password": ["rest_framework.permissions.IsAdminUser"],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Password validation
|
# Password validation
|
||||||
|
@ -319,32 +306,32 @@ DJOSER = {
|
||||||
|
|
||||||
AUTH_PASSWORD_VALIDATORS = [
|
AUTH_PASSWORD_VALIDATORS = [
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
|
||||||
"OPTIONS": {
|
"OPTIONS": {
|
||||||
"min_length": 8,
|
"min_length": 8,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
|
||||||
},
|
},
|
||||||
# Additional password validators
|
# Additional password validators
|
||||||
{
|
{
|
||||||
'NAME': 'accounts.validators.SpecialCharacterValidator',
|
"NAME": "accounts.validators.SpecialCharacterValidator",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'accounts.validators.LowercaseValidator',
|
"NAME": "accounts.validators.LowercaseValidator",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'accounts.validators.UppercaseValidator',
|
"NAME": "accounts.validators.UppercaseValidator",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'accounts.validators.NumberValidator',
|
"NAME": "accounts.validators.NumberValidator",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -352,9 +339,9 @@ AUTH_PASSWORD_VALIDATORS = [
|
||||||
# Internationalization
|
# Internationalization
|
||||||
# https://docs.djangoproject.com/en/4.2/topics/i18n/
|
# https://docs.djangoproject.com/en/4.2/topics/i18n/
|
||||||
|
|
||||||
LANGUAGE_CODE = 'en-us'
|
LANGUAGE_CODE = "en-us"
|
||||||
|
|
||||||
TIME_ZONE = get_secret('TIMEZONE')
|
TIME_ZONE = get_secret("TIMEZONE")
|
||||||
|
|
||||||
USE_I18N = True
|
USE_I18N = True
|
||||||
|
|
||||||
|
@ -364,14 +351,14 @@ USE_TZ = True
|
||||||
# Default primary key field type
|
# Default primary key field type
|
||||||
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
|
||||||
|
|
||||||
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
||||||
|
|
||||||
SITE_NAME = 'DRF-Template'
|
SITE_NAME = "DRF-Template"
|
||||||
|
|
||||||
# JWT Token Lifetimes
|
# JWT Token Lifetimes
|
||||||
SIMPLE_JWT = {
|
SIMPLE_JWT = {
|
||||||
"ACCESS_TOKEN_LIFETIME": timedelta(hours=1),
|
"ACCESS_TOKEN_LIFETIME": timedelta(hours=1),
|
||||||
"REFRESH_TOKEN_LIFETIME": timedelta(days=3)
|
"REFRESH_TOKEN_LIFETIME": timedelta(days=3),
|
||||||
}
|
}
|
||||||
|
|
||||||
CORS_ALLOW_ALL_ORIGINS = True
|
CORS_ALLOW_ALL_ORIGINS = True
|
||||||
|
@ -388,11 +375,19 @@ CELERY_RESULT_BACKEND = get_secret("CELERY_RESULT_BACKEND")
|
||||||
CELERY_RESULT_EXTENDED = True
|
CELERY_RESULT_EXTENDED = True
|
||||||
|
|
||||||
# Celery Beat Options
|
# Celery Beat Options
|
||||||
CELERY_BEAT_SCHEDULER = 'django_celery_beat.schedulers:DatabaseScheduler'
|
CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler"
|
||||||
|
|
||||||
# Maximum number of rows that can be updated within the Django admin panel
|
# Maximum number of rows that can be updated within the Django admin panel
|
||||||
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20480
|
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20480
|
||||||
|
|
||||||
GRAPH_MODELS = {
|
GRAPH_MODELS = {
|
||||||
'app_labels': ['accounts', 'user_groups', 'billing', 'emails', 'payments', 'subscriptions', 'search_results']
|
"app_labels": [
|
||||||
|
"accounts",
|
||||||
|
"user_groups",
|
||||||
|
"billing",
|
||||||
|
"emails",
|
||||||
|
"payments",
|
||||||
|
"subscriptions",
|
||||||
|
"search_results",
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from django.urls import path, include
|
from django.urls import path, include
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('api/v1/', include('api.urls')),
|
path("api/v1/", include("api.urls")),
|
||||||
]
|
]
|
||||||
|
|
|
@ -11,6 +11,6 @@ import os
|
||||||
|
|
||||||
from django.core.wsgi import get_wsgi_application
|
from django.core.wsgi import get_wsgi_application
|
||||||
|
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||||
|
|
||||||
application = get_wsgi_application()
|
application = get_wsgi_application()
|
||||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class EmailsConfig(AppConfig):
|
class EmailsConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'emails'
|
name = "emails"
|
||||||
|
|
|
@ -3,11 +3,11 @@ from django.utils import timezone
|
||||||
|
|
||||||
|
|
||||||
class ActivationEmail(email.ActivationEmail):
|
class ActivationEmail(email.ActivationEmail):
|
||||||
template_name = 'email_activation.html'
|
template_name = "email_activation.html"
|
||||||
|
|
||||||
|
|
||||||
class PasswordResetEmail(email.PasswordResetEmail):
|
class PasswordResetEmail(email.PasswordResetEmail):
|
||||||
template_name = 'password_change.html'
|
template_name = "password_change.html"
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionAvailedEmail(email.BaseEmailMessage):
|
class SubscriptionAvailedEmail(email.BaseEmailMessage):
|
||||||
|
@ -19,7 +19,7 @@ class SubscriptionAvailedEmail(email.BaseEmailMessage):
|
||||||
context["subscription_plan"] = context.get("subscription_plan")
|
context["subscription_plan"] = context.get("subscription_plan")
|
||||||
context["subscription"] = context.get("subscription")
|
context["subscription"] = context.get("subscription")
|
||||||
context["price_paid"] = context.get("price_paid")
|
context["price_paid"] = context.get("price_paid")
|
||||||
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
|
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||||
context.update(self.context)
|
context.update(self.context)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class SubscriptionRefundedEmail(email.BaseEmailMessage):
|
||||||
context["user"] = context.get("user")
|
context["user"] = context.get("user")
|
||||||
context["subscription_plan"] = context.get("subscription_plan")
|
context["subscription_plan"] = context.get("subscription_plan")
|
||||||
context["refund"] = context.get("refund")
|
context["refund"] = context.get("refund")
|
||||||
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
|
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||||
context.update(self.context)
|
context.update(self.context)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
@ -44,6 +44,6 @@ class SubscriptionCancelledEmail(email.BaseEmailMessage):
|
||||||
context = super().get_context_data()
|
context = super().get_context_data()
|
||||||
context["user"] = context.get("user")
|
context["user"] = context.get("user")
|
||||||
context["subscription_plan"] = context.get("subscription_plan")
|
context["subscription_plan"] = context.get("subscription_plan")
|
||||||
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
|
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||||
context.update(self.context)
|
context.update(self.context)
|
||||||
return context
|
return context
|
||||||
|
|
|
@ -6,7 +6,7 @@ import sys
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run administrative tasks."""
|
"""Run administrative tasks."""
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||||
try:
|
try:
|
||||||
from django.core.management import execute_from_command_line
|
from django.core.management import execute_from_command_line
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
|
@ -18,5 +18,5 @@ def main():
|
||||||
execute_from_command_line(sys.argv)
|
execute_from_command_line(sys.argv)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -6,5 +6,5 @@ from .models import Notification
|
||||||
@admin.register(Notification)
|
@admin.register(Notification)
|
||||||
class NotificationAdmin(ModelAdmin):
|
class NotificationAdmin(ModelAdmin):
|
||||||
model = Notification
|
model = Notification
|
||||||
search_fields = ('id', 'content')
|
search_fields = ("id", "content")
|
||||||
list_display = ['id', 'dismissed']
|
list_display = ["id", "dismissed"]
|
||||||
|
|
|
@ -2,8 +2,8 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class NotificationsConfig(AppConfig):
|
class NotificationsConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'notifications'
|
name = "notifications"
|
||||||
|
|
||||||
def ready(self):
|
def ready(self):
|
||||||
import notifications.signals
|
import notifications.signals
|
||||||
|
|
|
@ -15,13 +15,27 @@ class Migration(migrations.Migration):
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='Notification',
|
name="Notification",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('content', models.CharField(max_length=1000, null=True)),
|
"id",
|
||||||
('timestamp', models.DateTimeField(auto_now_add=True)),
|
models.BigAutoField(
|
||||||
('dismissed', models.BooleanField(default=False)),
|
auto_created=True,
|
||||||
('recipient', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("content", models.CharField(max_length=1000, null=True)),
|
||||||
|
("timestamp", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("dismissed", models.BooleanField(default=False)),
|
||||||
|
(
|
||||||
|
"recipient",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to=settings.AUTH_USER_MODEL,
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,8 +2,7 @@ from django.db import models
|
||||||
|
|
||||||
|
|
||||||
class Notification(models.Model):
|
class Notification(models.Model):
|
||||||
recipient = models.ForeignKey(
|
recipient = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
|
||||||
'accounts.CustomUser', on_delete=models.CASCADE)
|
|
||||||
content = models.CharField(max_length=1000, null=True)
|
content = models.CharField(max_length=1000, null=True)
|
||||||
timestamp = models.DateTimeField(auto_now_add=True, editable=False)
|
timestamp = models.DateTimeField(auto_now_add=True, editable=False)
|
||||||
dismissed = models.BooleanField(default=False)
|
dismissed = models.BooleanField(default=False)
|
||||||
|
|
|
@ -3,10 +3,9 @@ from notifications.models import Notification
|
||||||
|
|
||||||
|
|
||||||
class NotificationSerializer(serializers.ModelSerializer):
|
class NotificationSerializer(serializers.ModelSerializer):
|
||||||
timestamp = serializers.DateTimeField(
|
timestamp = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||||
format="%m-%d-%Y %I:%M %p", read_only=True)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Notification
|
model = Notification
|
||||||
fields = '__all__'
|
fields = "__all__"
|
||||||
read_only_fields = ('id', 'recipient', 'content', 'timestamp')
|
read_only_fields = ("id", "recipient", "content", "timestamp")
|
||||||
|
|
|
@ -9,5 +9,5 @@ from django.core.cache import cache
|
||||||
@receiver(post_save, sender=Notification)
|
@receiver(post_save, sender=Notification)
|
||||||
def clear_cache_after_notification_update(sender, instance, **kwargs):
|
def clear_cache_after_notification_update(sender, instance, **kwargs):
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete('notifications')
|
cache.delete("notifications")
|
||||||
cache.delete(f'notifications_user:{instance.recipient.id}')
|
cache.delete(f"notifications_user:{instance.recipient.id}")
|
||||||
|
|
|
@ -9,5 +9,4 @@ def cleanup_notifications():
|
||||||
three_days_ago = timezone.now() - timezone.timedelta(days=3)
|
three_days_ago = timezone.now() - timezone.timedelta(days=3)
|
||||||
|
|
||||||
# Delete notifications that are older than 3 days and dismissed
|
# Delete notifications that are older than 3 days and dismissed
|
||||||
Notification.objects.filter(
|
Notification.objects.filter(dismissed=True, timestamp__lte=three_days_ago).delete()
|
||||||
dismissed=True, timestamp__lte=three_days_ago).delete()
|
|
||||||
|
|
|
@ -3,8 +3,7 @@ from notifications.views import NotificationViewSet
|
||||||
from rest_framework.routers import DefaultRouter
|
from rest_framework.routers import DefaultRouter
|
||||||
|
|
||||||
router = DefaultRouter()
|
router = DefaultRouter()
|
||||||
router.register(r'', NotificationViewSet,
|
router.register(r"", NotificationViewSet, basename="Notifications")
|
||||||
basename="Notifications")
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('', include(router.urls)),
|
path("", include(router.urls)),
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,30 +6,33 @@ from django.core.cache import cache
|
||||||
|
|
||||||
|
|
||||||
class NotificationViewSet(viewsets.ModelViewSet):
|
class NotificationViewSet(viewsets.ModelViewSet):
|
||||||
http_method_names = ['get', 'patch', 'delete']
|
http_method_names = ["get", "patch", "delete"]
|
||||||
serializer_class = NotificationSerializer
|
serializer_class = NotificationSerializer
|
||||||
queryset = Notification.objects.all()
|
queryset = Notification.objects.all()
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
user = self.request.user
|
user = self.request.user
|
||||||
key = f'notifications_user:{user.id}'
|
key = f"notifications_user:{user.id}"
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
if not queryset:
|
if not queryset:
|
||||||
queryset = Notification.objects.filter(
|
queryset = Notification.objects.filter(recipient=user).order_by(
|
||||||
recipient=user).order_by('-timestamp')
|
"-timestamp"
|
||||||
cache.set(key, queryset, 60*60)
|
)
|
||||||
|
cache.set(key, queryset, 60 * 60)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
def update(self, request, *args, **kwargs):
|
def update(self, request, *args, **kwargs):
|
||||||
instance = self.get_object()
|
instance = self.get_object()
|
||||||
if instance.recipient != request.user:
|
if instance.recipient != request.user:
|
||||||
raise PermissionDenied(
|
raise PermissionDenied(
|
||||||
"You do not have permission to update this notification.")
|
"You do not have permission to update this notification."
|
||||||
|
)
|
||||||
return super().update(request, *args, **kwargs)
|
return super().update(request, *args, **kwargs)
|
||||||
|
|
||||||
def destroy(self, request, *args, **kwargs):
|
def destroy(self, request, *args, **kwargs):
|
||||||
instance = self.get_object()
|
instance = self.get_object()
|
||||||
if instance.recipient != request.user:
|
if instance.recipient != request.user:
|
||||||
raise PermissionDenied(
|
raise PermissionDenied(
|
||||||
"You do not have permission to delete this notification.")
|
"You do not have permission to delete this notification."
|
||||||
|
)
|
||||||
return super().destroy(request, *args, **kwargs)
|
return super().destroy(request, *args, **kwargs)
|
||||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class PaymentsConfig(AppConfig):
|
class PaymentsConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'payments'
|
name = "payments"
|
||||||
|
|
|
@ -3,6 +3,6 @@ from payments import views
|
||||||
|
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('checkout_session/', views.StripeCheckoutView.as_view()),
|
path("checkout_session/", views.StripeCheckoutView.as_view()),
|
||||||
path('webhook/', views.stripe_webhook_view, name='Stripe Webhook'),
|
path("webhook/", views.stripe_webhook_view, name="Stripe Webhook"),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,10 @@
|
||||||
from config.settings import STRIPE_SECRET_KEY, STRIPE_SECRET_WEBHOOK, URL_SCHEME, FRONTEND_ADDRESS, FRONTEND_PORT
|
from config.settings import (
|
||||||
|
STRIPE_SECRET_KEY,
|
||||||
|
STRIPE_SECRET_WEBHOOK,
|
||||||
|
URL_SCHEME,
|
||||||
|
FRONTEND_ADDRESS,
|
||||||
|
FRONTEND_PORT,
|
||||||
|
)
|
||||||
from rest_framework.permissions import IsAuthenticated
|
from rest_framework.permissions import IsAuthenticated
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
@ -12,16 +18,19 @@ from accounts.models import CustomUser
|
||||||
from rest_framework.decorators import api_view
|
from rest_framework.decorators import api_view
|
||||||
from subscriptions.tasks import get_user_subscription
|
from subscriptions.tasks import get_user_subscription
|
||||||
import json
|
import json
|
||||||
from emails.templates import SubscriptionAvailedEmail, SubscriptionRefundedEmail, SubscriptionCancelledEmail
|
from emails.templates import (
|
||||||
|
SubscriptionAvailedEmail,
|
||||||
|
SubscriptionRefundedEmail,
|
||||||
|
SubscriptionCancelledEmail,
|
||||||
|
)
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from payments.serializers import CheckoutSerializer
|
from payments.serializers import CheckoutSerializer
|
||||||
from drf_spectacular.utils import extend_schema
|
from drf_spectacular.utils import extend_schema
|
||||||
|
|
||||||
stripe.api_key = STRIPE_SECRET_KEY
|
stripe.api_key = STRIPE_SECRET_KEY
|
||||||
|
|
||||||
|
|
||||||
@extend_schema(
|
@extend_schema(request=CheckoutSerializer)
|
||||||
request=CheckoutSerializer
|
|
||||||
)
|
|
||||||
class StripeCheckoutView(APIView):
|
class StripeCheckoutView(APIView):
|
||||||
permission_classes = [IsAuthenticated]
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
|
@ -30,41 +39,46 @@ class StripeCheckoutView(APIView):
|
||||||
# Get subscription ID from POST
|
# Get subscription ID from POST
|
||||||
USER = CustomUser.objects.get(id=self.request.user.id)
|
USER = CustomUser.objects.get(id=self.request.user.id)
|
||||||
data = json.loads(request.body)
|
data = json.loads(request.body)
|
||||||
subscription_id = data.get('subscription_id')
|
subscription_id = data.get("subscription_id")
|
||||||
annual = data.get('annual')
|
annual = data.get("annual")
|
||||||
|
|
||||||
# Validation for subscription_id field
|
# Validation for subscription_id field
|
||||||
try:
|
try:
|
||||||
subscription_id = int(subscription_id)
|
subscription_id = int(subscription_id)
|
||||||
except:
|
except:
|
||||||
return Response({
|
return Response(
|
||||||
'error': 'Invalid value specified in subscription_id field'
|
{"error": "Invalid value specified in subscription_id field"},
|
||||||
}, status=status.HTTP_403_FORBIDDEN)
|
status=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
# Validation for annual field
|
# Validation for annual field
|
||||||
try:
|
try:
|
||||||
annual = bool(annual)
|
annual = bool(annual)
|
||||||
except:
|
except:
|
||||||
return Response({
|
return Response(
|
||||||
'error': 'Invalid value specified in annual field'
|
{"error": "Invalid value specified in annual field"},
|
||||||
}, status=status.HTTP_403_FORBIDDEN)
|
status=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
# Return an error if the user already has an active subscription
|
# Return an error if the user already has an active subscription
|
||||||
EXISTING_SUBSCRIPTION = get_user_subscription(USER.id)
|
EXISTING_SUBSCRIPTION = get_user_subscription(USER.id)
|
||||||
if EXISTING_SUBSCRIPTION:
|
if EXISTING_SUBSCRIPTION:
|
||||||
return Response({
|
return Response(
|
||||||
'error': f'User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}'
|
{
|
||||||
}, status=status.HTTP_403_FORBIDDEN)
|
"error": f"User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}"
|
||||||
|
},
|
||||||
|
status=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
# Attempt to query the subscription
|
# Attempt to query the subscription
|
||||||
SUBSCRIPTION = SubscriptionPlan.objects.filter(
|
SUBSCRIPTION = SubscriptionPlan.objects.filter(id=subscription_id).first()
|
||||||
id=subscription_id).first()
|
|
||||||
|
|
||||||
# Return an error if the plan does not exist
|
# Return an error if the plan does not exist
|
||||||
if not SUBSCRIPTION:
|
if not SUBSCRIPTION:
|
||||||
return Response({
|
return Response(
|
||||||
'error': 'Subscription plan not found'
|
{"error": "Subscription plan not found"},
|
||||||
}, status=status.HTTP_404_NOT_FOUND)
|
status=status.HTTP_404_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
# Get the stripe_price_id from the related StripePrice instances
|
# Get the stripe_price_id from the related StripePrice instances
|
||||||
if annual:
|
if annual:
|
||||||
|
@ -74,52 +88,58 @@ class StripeCheckoutView(APIView):
|
||||||
|
|
||||||
# Return 404 if no price is set
|
# Return 404 if no price is set
|
||||||
if not PRICE:
|
if not PRICE:
|
||||||
return Response({
|
return Response(
|
||||||
'error': 'Specified price does not exist for plan'
|
{"error": "Specified price does not exist for plan"},
|
||||||
}, status=status.HTTP_404_NOT_FOUND)
|
status=status.HTTP_404_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
PRICE_ID = PRICE.stripe_price_id
|
PRICE_ID = PRICE.stripe_price_id
|
||||||
prorated = PRICE.prorated
|
prorated = PRICE.prorated
|
||||||
|
|
||||||
# Return an error if a user is in a user_group and is availing pro-rated plans
|
# Return an error if a user is in a user_group and is availing pro-rated plans
|
||||||
if not USER.user_group and SUBSCRIPTION.group_exclusive:
|
if not USER.user_group and SUBSCRIPTION.group_exclusive:
|
||||||
return Response({
|
return Response(
|
||||||
'error': 'Regular users cannot avail prorated plans'
|
{"error": "Regular users cannot avail prorated plans"},
|
||||||
}, status=status.HTTP_403_FORBIDDEN)
|
status=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
success_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \
|
success_url = (
|
||||||
'/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}'
|
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
|
||||||
cancel_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \
|
+ "/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}"
|
||||||
'/user/subscription/payment?success=false&user_group=False'
|
)
|
||||||
|
cancel_url = (
|
||||||
|
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
|
||||||
|
+ "/user/subscription/payment?success=false&user_group=False"
|
||||||
|
)
|
||||||
|
|
||||||
checkout_session = stripe.checkout.Session.create(
|
checkout_session = stripe.checkout.Session.create(
|
||||||
line_items=[
|
line_items=[
|
||||||
{
|
(
|
||||||
'price': PRICE_ID,
|
{"price": PRICE_ID, "quantity": 1}
|
||||||
'quantity': 1
|
if not prorated
|
||||||
} if not prorated else
|
else {
|
||||||
{
|
"price": PRICE_ID,
|
||||||
'price': PRICE_ID,
|
}
|
||||||
}
|
)
|
||||||
],
|
],
|
||||||
mode='subscription',
|
mode="subscription",
|
||||||
payment_method_types=['card'],
|
payment_method_types=["card"],
|
||||||
success_url=success_url,
|
success_url=success_url,
|
||||||
cancel_url=cancel_url,
|
cancel_url=cancel_url,
|
||||||
)
|
)
|
||||||
return Response({"url": checkout_session.url})
|
return Response({"url": checkout_session.url})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
return Response({
|
return Response(
|
||||||
'error': str(e)
|
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ api_view(['POST'])
|
@api_view(["POST"])
|
||||||
@ csrf_exempt
|
@csrf_exempt
|
||||||
def stripe_webhook_view(request):
|
def stripe_webhook_view(request):
|
||||||
payload = request.body
|
payload = request.body
|
||||||
sig_header = request.META['HTTP_STRIPE_SIGNATURE']
|
sig_header = request.META["HTTP_STRIPE_SIGNATURE"]
|
||||||
event = None
|
event = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -133,12 +153,12 @@ def stripe_webhook_view(request):
|
||||||
# Invalid signature
|
# Invalid signature
|
||||||
return Response(status=401)
|
return Response(status=401)
|
||||||
|
|
||||||
if event['type'] == 'customer.subscription.created':
|
if event["type"] == "customer.subscription.created":
|
||||||
subscription = event['data']['object']
|
subscription = event["data"]["object"]
|
||||||
# Get the Invoice object from the Subscription object
|
# Get the Invoice object from the Subscription object
|
||||||
invoice = stripe.Invoice.retrieve(subscription['latest_invoice'])
|
invoice = stripe.Invoice.retrieve(subscription["latest_invoice"])
|
||||||
# Get the Charge object from the Invoice object
|
# Get the Charge object from the Invoice object
|
||||||
charge = stripe.Charge.retrieve(invoice['charge'])
|
charge = stripe.Charge.retrieve(invoice["charge"])
|
||||||
|
|
||||||
# Get paying user
|
# Get paying user
|
||||||
customer = stripe.Customer.retrieve(subscription["customer"])
|
customer = stripe.Customer.retrieve(subscription["customer"])
|
||||||
|
@ -146,18 +166,20 @@ def stripe_webhook_view(request):
|
||||||
|
|
||||||
product = subscription["items"]["data"][0]
|
product = subscription["items"]["data"][0]
|
||||||
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
|
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
|
||||||
stripe_product_id=product["plan"]["product"])
|
stripe_product_id=product["plan"]["product"]
|
||||||
|
)
|
||||||
SUBSCRIPTION = UserSubscription.objects.create(
|
SUBSCRIPTION = UserSubscription.objects.create(
|
||||||
subscription=SUBSCRIPTION_PLAN,
|
subscription=SUBSCRIPTION_PLAN,
|
||||||
annual=product["plan"]["interval"] == "year",
|
annual=product["plan"]["interval"] == "year",
|
||||||
valid=True,
|
valid=True,
|
||||||
user=USER,
|
user=USER,
|
||||||
stripe_id=subscription['id'])
|
stripe_id=subscription["id"],
|
||||||
|
)
|
||||||
email = SubscriptionAvailedEmail()
|
email = SubscriptionAvailedEmail()
|
||||||
|
|
||||||
paid = {
|
paid = {
|
||||||
"amount": charge['amount']/100,
|
"amount": charge["amount"] / 100,
|
||||||
"currency": str(charge['currency']).upper()
|
"currency": str(charge["currency"]).upper(),
|
||||||
}
|
}
|
||||||
|
|
||||||
email.context = {
|
email.context = {
|
||||||
|
@ -169,19 +191,20 @@ def stripe_webhook_view(request):
|
||||||
email.send(to=[customer.email])
|
email.send(to=[customer.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{USER.id}')
|
cache.delete(f"billing_user:{USER.id}")
|
||||||
cache.delete(f'subscriptions_user:{USER.id}')
|
cache.delete(f"subscriptions_user:{USER.id}")
|
||||||
|
|
||||||
# On chargebacks/refunds, invalidate the subscription
|
# On chargebacks/refunds, invalidate the subscription
|
||||||
elif event['type'] == 'charge.refunded':
|
elif event["type"] == "charge.refunded":
|
||||||
charge = event['data']['object']
|
charge = event["data"]["object"]
|
||||||
|
|
||||||
# Get the Invoice object from the Charge object
|
# Get the Invoice object from the Charge object
|
||||||
invoice = stripe.Invoice.retrieve(charge['invoice'])
|
invoice = stripe.Invoice.retrieve(charge["invoice"])
|
||||||
|
|
||||||
# Check if the subscription exists
|
# Check if the subscription exists
|
||||||
SUBSCRIPTION = UserSubscription.objects.filter(
|
SUBSCRIPTION = UserSubscription.objects.filter(
|
||||||
stripe_id=invoice['subscription']).first()
|
stripe_id=invoice["subscription"]
|
||||||
|
).first()
|
||||||
|
|
||||||
if not (SUBSCRIPTION):
|
if not (SUBSCRIPTION):
|
||||||
return HttpResponse(status=404)
|
return HttpResponse(status=404)
|
||||||
|
@ -196,8 +219,8 @@ def stripe_webhook_view(request):
|
||||||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||||
|
|
||||||
refund = {
|
refund = {
|
||||||
"amount": charge['amount_refunded']/100,
|
"amount": charge["amount_refunded"] / 100,
|
||||||
"currency": str(charge['currency']).upper()
|
"currency": str(charge["currency"]).upper(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Send an email
|
# Send an email
|
||||||
|
@ -206,13 +229,13 @@ def stripe_webhook_view(request):
|
||||||
email.context = {
|
email.context = {
|
||||||
"user": USER,
|
"user": USER,
|
||||||
"subscription_plan": SUBSCRIPTION_PLAN,
|
"subscription_plan": SUBSCRIPTION_PLAN,
|
||||||
"refund": refund
|
"refund": refund,
|
||||||
}
|
}
|
||||||
|
|
||||||
email.send(to=[USER.email])
|
email.send(to=[USER.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{USER.id}')
|
cache.delete(f"billing_user:{USER.id}")
|
||||||
|
|
||||||
elif SUBSCRIPTION.user_group:
|
elif SUBSCRIPTION.user_group:
|
||||||
OWNER = SUBSCRIPTION.user_group.owner
|
OWNER = SUBSCRIPTION.user_group.owner
|
||||||
|
@ -223,8 +246,8 @@ def stripe_webhook_view(request):
|
||||||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||||
|
|
||||||
refund = {
|
refund = {
|
||||||
"amount": charge['amount_refunded']/100,
|
"amount": charge["amount_refunded"] / 100,
|
||||||
"currency": str(charge['currency']).upper()
|
"currency": str(charge["currency"]).upper(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Send en email
|
# Send en email
|
||||||
|
@ -233,36 +256,38 @@ def stripe_webhook_view(request):
|
||||||
email.context = {
|
email.context = {
|
||||||
"user": OWNER,
|
"user": OWNER,
|
||||||
"subscription_plan": SUBSCRIPTION_PLAN,
|
"subscription_plan": SUBSCRIPTION_PLAN,
|
||||||
"refund": refund
|
"refund": refund,
|
||||||
}
|
}
|
||||||
email.send(to=[OWNER.email])
|
email.send(to=[OWNER.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{USER.id}')
|
cache.delete(f"billing_user:{USER.id}")
|
||||||
cache.delete(f'subscriptions_user:{USER.id}')
|
cache.delete(f"subscriptions_user:{USER.id}")
|
||||||
|
|
||||||
elif event['type'] == 'customer.subscription.updated':
|
elif event["type"] == "customer.subscription.updated":
|
||||||
subscription = event['data']['object']
|
subscription = event["data"]["object"]
|
||||||
|
|
||||||
# Check if the subscription exists
|
# Check if the subscription exists
|
||||||
SUBSCRIPTION = UserSubscription.objects.filter(
|
SUBSCRIPTION = UserSubscription.objects.filter(
|
||||||
stripe_id=subscription['id']).first()
|
stripe_id=subscription["id"]
|
||||||
|
).first()
|
||||||
|
|
||||||
if not (SUBSCRIPTION):
|
if not (SUBSCRIPTION):
|
||||||
return HttpResponse(status=404)
|
return HttpResponse(status=404)
|
||||||
|
|
||||||
# Check if a subscription has been upgraded/downgraded
|
# Check if a subscription has been upgraded/downgraded
|
||||||
new_stripe_product_id = subscription['items']['data'][0]['plan']['product']
|
new_stripe_product_id = subscription["items"]["data"][0]["plan"]["product"]
|
||||||
current_stripe_product_id = SUBSCRIPTION.subscription.stripe_product_id
|
current_stripe_product_id = SUBSCRIPTION.subscription.stripe_product_id
|
||||||
if new_stripe_product_id != current_stripe_product_id:
|
if new_stripe_product_id != current_stripe_product_id:
|
||||||
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
|
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
|
||||||
stripe_product_id=new_stripe_product_id)
|
stripe_product_id=new_stripe_product_id
|
||||||
|
)
|
||||||
SUBSCRIPTION.subscription = SUBSCRIPTION_PLAN
|
SUBSCRIPTION.subscription = SUBSCRIPTION_PLAN
|
||||||
SUBSCRIPTION.save()
|
SUBSCRIPTION.save()
|
||||||
# TODO: Add a plan upgraded email message here
|
# TODO: Add a plan upgraded email message here
|
||||||
|
|
||||||
# Subscription activation/reactivation
|
# Subscription activation/reactivation
|
||||||
if subscription['status'] == 'active':
|
if subscription["status"] == "active":
|
||||||
SUBSCRIPTION.valid = True
|
SUBSCRIPTION.valid = True
|
||||||
SUBSCRIPTION.save()
|
SUBSCRIPTION.save()
|
||||||
|
|
||||||
|
@ -270,26 +295,24 @@ def stripe_webhook_view(request):
|
||||||
USER = SUBSCRIPTION.user
|
USER = SUBSCRIPTION.user
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{USER.id}')
|
cache.delete(f"billing_user:{USER.id}")
|
||||||
cache.delete(
|
cache.delete(f"subscriptions_user:{USER.id}")
|
||||||
f'subscriptions_user:{USER.id}')
|
|
||||||
|
|
||||||
elif SUBSCRIPTION.user_group:
|
elif SUBSCRIPTION.user_group:
|
||||||
OWNER = SUBSCRIPTION.user_group.owner
|
OWNER = SUBSCRIPTION.user_group.owner
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{OWNER.id}')
|
cache.delete(f"billing_user:{OWNER.id}")
|
||||||
cache.delete(
|
cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
|
||||||
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
|
|
||||||
|
|
||||||
# TODO: Add notification here to inform users if their plan has been reactivated
|
# TODO: Add notification here to inform users if their plan has been reactivated
|
||||||
|
|
||||||
elif subscription['status'] == 'past_due':
|
elif subscription["status"] == "past_due":
|
||||||
# TODO: Add notification here to inform users if their payment method for an existing subscription payment is failing
|
# TODO: Add notification here to inform users if their payment method for an existing subscription payment is failing
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# If subscriptions get cancelled due to non-payment, invalidate the UserSubscription
|
# If subscriptions get cancelled due to non-payment, invalidate the UserSubscription
|
||||||
elif subscription['status'] == 'cancelled':
|
elif subscription["status"] == "cancelled":
|
||||||
if SUBSCRIPTION.user:
|
if SUBSCRIPTION.user:
|
||||||
USER = SUBSCRIPTION.user
|
USER = SUBSCRIPTION.user
|
||||||
|
|
||||||
|
@ -310,8 +333,8 @@ def stripe_webhook_view(request):
|
||||||
email.send(to=[USER.email])
|
email.send(to=[USER.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{USER.id}')
|
cache.delete(f"billing_user:{USER.id}")
|
||||||
cache.delete(f'subscriptions_user:{USER.id}')
|
cache.delete(f"subscriptions_user:{USER.id}")
|
||||||
|
|
||||||
elif SUBSCRIPTION.user_group:
|
elif SUBSCRIPTION.user_group:
|
||||||
OWNER = SUBSCRIPTION.user_group.owner
|
OWNER = SUBSCRIPTION.user_group.owner
|
||||||
|
@ -325,24 +348,21 @@ def stripe_webhook_view(request):
|
||||||
|
|
||||||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||||
|
|
||||||
email.context = {
|
email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
|
||||||
"user": OWNER,
|
|
||||||
"subscription_plan": SUBSCRIPTION_PLAN
|
|
||||||
}
|
|
||||||
email.send(to=[OWNER.email])
|
email.send(to=[OWNER.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{OWNER.id}')
|
cache.delete(f"billing_user:{OWNER.id}")
|
||||||
cache.delete(
|
cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
|
||||||
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
|
|
||||||
|
|
||||||
# If a subscription gets cancelled, invalidate it
|
# If a subscription gets cancelled, invalidate it
|
||||||
elif event['type'] == 'customer.subscription.deleted':
|
elif event["type"] == "customer.subscription.deleted":
|
||||||
subscription = event['data']['object']
|
subscription = event["data"]["object"]
|
||||||
|
|
||||||
# Check if the subscription exists
|
# Check if the subscription exists
|
||||||
SUBSCRIPTION = UserSubscription.objects.filter(
|
SUBSCRIPTION = UserSubscription.objects.filter(
|
||||||
stripe_id=subscription['id']).first()
|
stripe_id=subscription["id"]
|
||||||
|
).first()
|
||||||
|
|
||||||
if not (SUBSCRIPTION):
|
if not (SUBSCRIPTION):
|
||||||
return HttpResponse(status=404)
|
return HttpResponse(status=404)
|
||||||
|
@ -367,7 +387,7 @@ def stripe_webhook_view(request):
|
||||||
email.send(to=[USER.email])
|
email.send(to=[USER.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{USER.id}')
|
cache.delete(f"billing_user:{USER.id}")
|
||||||
|
|
||||||
elif SUBSCRIPTION.user_group:
|
elif SUBSCRIPTION.user_group:
|
||||||
OWNER = SUBSCRIPTION.user_group.owner
|
OWNER = SUBSCRIPTION.user_group.owner
|
||||||
|
@ -381,14 +401,11 @@ def stripe_webhook_view(request):
|
||||||
|
|
||||||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||||
|
|
||||||
email.context = {
|
email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
|
||||||
"user": OWNER,
|
|
||||||
"subscription_plan": SUBSCRIPTION_PLAN
|
|
||||||
}
|
|
||||||
email.send(to=[OWNER.email])
|
email.send(to=[OWNER.email])
|
||||||
|
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete(f'billing_user:{OWNER.id}')
|
cache.delete(f"billing_user:{OWNER.id}")
|
||||||
|
|
||||||
# Passed signature verification
|
# Passed signature verification
|
||||||
return HttpResponse(status=200)
|
return HttpResponse(status=200)
|
||||||
|
|
|
@ -805,6 +805,9 @@ components:
|
||||||
first_name:
|
first_name:
|
||||||
type: string
|
type: string
|
||||||
maxLength: 150
|
maxLength: 150
|
||||||
|
is_new:
|
||||||
|
type: string
|
||||||
|
readOnly: true
|
||||||
last_name:
|
last_name:
|
||||||
type: string
|
type: string
|
||||||
maxLength: 150
|
maxLength: 150
|
||||||
|
@ -824,6 +827,7 @@ components:
|
||||||
- group_member
|
- group_member
|
||||||
- group_owner
|
- group_owner
|
||||||
- id
|
- id
|
||||||
|
- is_new
|
||||||
- user_group
|
- user_group
|
||||||
- username
|
- username
|
||||||
Notification:
|
Notification:
|
||||||
|
@ -885,6 +889,9 @@ components:
|
||||||
first_name:
|
first_name:
|
||||||
type: string
|
type: string
|
||||||
maxLength: 150
|
maxLength: 150
|
||||||
|
is_new:
|
||||||
|
type: string
|
||||||
|
readOnly: true
|
||||||
last_name:
|
last_name:
|
||||||
type: string
|
type: string
|
||||||
maxLength: 150
|
maxLength: 150
|
||||||
|
|
|
@ -7,12 +7,11 @@ from unfold.contrib.filters.admin import RangeDateFilter
|
||||||
@admin.register(SearchResult)
|
@admin.register(SearchResult)
|
||||||
class SearchResultAdmin(ModelAdmin):
|
class SearchResultAdmin(ModelAdmin):
|
||||||
model = SearchResult
|
model = SearchResult
|
||||||
search_fields = ('id', 'title', 'link')
|
search_fields = ("id", "title", "link")
|
||||||
list_display = ['id', 'title', 'timestamp']
|
list_display = ["id", "title", "timestamp"]
|
||||||
|
|
||||||
list_filter_submit = True
|
list_filter_submit = True
|
||||||
list_filter = ((
|
list_filter = (
|
||||||
"timestamp", RangeDateFilter
|
("timestamp", RangeDateFilter),
|
||||||
), (
|
("timestamp", RangeDateFilter),
|
||||||
"timestamp", RangeDateFilter
|
)
|
||||||
),)
|
|
||||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class SearchResultsConfig(AppConfig):
|
class SearchResultsConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'search_results'
|
name = "search_results"
|
||||||
|
|
|
@ -7,17 +7,24 @@ class Migration(migrations.Migration):
|
||||||
|
|
||||||
initial = True
|
initial = True
|
||||||
|
|
||||||
dependencies = [
|
dependencies = []
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='SearchResult',
|
name="SearchResult",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('title', models.CharField(max_length=1000)),
|
"id",
|
||||||
('link', models.CharField(max_length=1000)),
|
models.BigAutoField(
|
||||||
('timestamp', models.DateTimeField(auto_now_add=True)),
|
auto_created=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("title", models.CharField(max_length=1000)),
|
||||||
|
("link", models.CharField(max_length=1000)),
|
||||||
|
("timestamp", models.DateTimeField(auto_now_add=True)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,16 +1,13 @@
|
||||||
|
|
||||||
|
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from .models import SearchResult
|
from .models import SearchResult
|
||||||
|
|
||||||
|
|
||||||
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 0, 'countdown': 5})
|
@shared_task(
|
||||||
|
autoretry_for=(Exception,), retry_kwargs={"max_retries": 0, "countdown": 5}
|
||||||
|
)
|
||||||
def create_search_result(title, link):
|
def create_search_result(title, link):
|
||||||
if SearchResult.objects.filter(title=title, link=link).exists():
|
if SearchResult.objects.filter(title=title, link=link).exists():
|
||||||
return ("SearchResult entry already exists")
|
return "SearchResult entry already exists"
|
||||||
else:
|
else:
|
||||||
SearchResult.objects.create(
|
SearchResult.objects.create(title=title, link=link)
|
||||||
title=title,
|
|
||||||
link=link
|
|
||||||
)
|
|
||||||
return f"Created new SearchResult entry titled: {title}"
|
return f"Created new SearchResult entry titled: {title}"
|
||||||
|
|
|
@ -6,10 +6,24 @@ from unfold.contrib.filters.admin import RangeDateFilter
|
||||||
|
|
||||||
@admin.register(StripePrice)
|
@admin.register(StripePrice)
|
||||||
class StripePriceAdmin(ModelAdmin):
|
class StripePriceAdmin(ModelAdmin):
|
||||||
search_fields = ["id", "lookup_key",
|
search_fields = [
|
||||||
"stripe_price_id","price","currency", "prorated", "annual"]
|
"id",
|
||||||
list_display = ["id", "lookup_key",
|
"lookup_key",
|
||||||
"stripe_price_id", "price", "currency", "prorated", "annual"]
|
"stripe_price_id",
|
||||||
|
"price",
|
||||||
|
"currency",
|
||||||
|
"prorated",
|
||||||
|
"annual",
|
||||||
|
]
|
||||||
|
list_display = [
|
||||||
|
"id",
|
||||||
|
"lookup_key",
|
||||||
|
"stripe_price_id",
|
||||||
|
"price",
|
||||||
|
"currency",
|
||||||
|
"prorated",
|
||||||
|
"annual",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@admin.register(SubscriptionPlan)
|
@admin.register(SubscriptionPlan)
|
||||||
|
@ -21,9 +35,6 @@ class SubscriptionPlanAdmin(ModelAdmin):
|
||||||
@admin.register(UserSubscription)
|
@admin.register(UserSubscription)
|
||||||
class UserSubscriptionAdmin(ModelAdmin):
|
class UserSubscriptionAdmin(ModelAdmin):
|
||||||
list_filter_submit = True
|
list_filter_submit = True
|
||||||
list_filter = ((
|
list_filter = (("date", RangeDateFilter),)
|
||||||
"date", RangeDateFilter
|
list_display = ["id", "__str__", "valid", "annual", "date"]
|
||||||
),)
|
|
||||||
list_display = ["id", "__str__", "valid", "annual",
|
|
||||||
"date"]
|
|
||||||
search_fields = ["id", "date"]
|
search_fields = ["id", "date"]
|
||||||
|
|
|
@ -2,8 +2,8 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionConfig(AppConfig):
|
class SubscriptionConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'subscriptions'
|
name = "subscriptions"
|
||||||
|
|
||||||
def ready(self):
|
def ready(self):
|
||||||
import subscriptions.signals
|
import subscriptions.signals
|
||||||
|
|
|
@ -11,46 +11,118 @@ class Migration(migrations.Migration):
|
||||||
initial = True
|
initial = True
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
('user_groups', '0001_initial'),
|
("user_groups", "0001_initial"),
|
||||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||||
]
|
]
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='StripePrice',
|
name="StripePrice",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('annual', models.BooleanField(default=False)),
|
"id",
|
||||||
('stripe_price_id', models.CharField(max_length=100)),
|
models.BigAutoField(
|
||||||
('price', models.DecimalField(decimal_places=2, default=0.0, max_digits=10)),
|
auto_created=True,
|
||||||
('currency', models.CharField(max_length=20)),
|
primary_key=True,
|
||||||
('lookup_key', models.CharField(blank=True, max_length=100, null=True)),
|
serialize=False,
|
||||||
('prorated', models.BooleanField(default=False)),
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("annual", models.BooleanField(default=False)),
|
||||||
|
("stripe_price_id", models.CharField(max_length=100)),
|
||||||
|
(
|
||||||
|
"price",
|
||||||
|
models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
|
||||||
|
),
|
||||||
|
("currency", models.CharField(max_length=20)),
|
||||||
|
("lookup_key", models.CharField(blank=True, max_length=100, null=True)),
|
||||||
|
("prorated", models.BooleanField(default=False)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='SubscriptionPlan',
|
name="SubscriptionPlan",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('name', models.CharField(max_length=100)),
|
"id",
|
||||||
('description', models.TextField(max_length=1024, null=True)),
|
models.BigAutoField(
|
||||||
('stripe_product_id', models.CharField(max_length=100)),
|
auto_created=True,
|
||||||
('group_exclusive', models.BooleanField(default=False)),
|
primary_key=True,
|
||||||
('annual_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='annual_plan', to='subscriptions.stripeprice')),
|
serialize=False,
|
||||||
('monthly_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='monthly_plan', to='subscriptions.stripeprice')),
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("name", models.CharField(max_length=100)),
|
||||||
|
("description", models.TextField(max_length=1024, null=True)),
|
||||||
|
("stripe_product_id", models.CharField(max_length=100)),
|
||||||
|
("group_exclusive", models.BooleanField(default=False)),
|
||||||
|
(
|
||||||
|
"annual_price",
|
||||||
|
models.ForeignKey(
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
related_name="annual_plan",
|
||||||
|
to="subscriptions.stripeprice",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"monthly_price",
|
||||||
|
models.ForeignKey(
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
related_name="monthly_plan",
|
||||||
|
to="subscriptions.stripeprice",
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='UserSubscription',
|
name="UserSubscription",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('stripe_id', models.CharField(max_length=100)),
|
"id",
|
||||||
('date', models.DateTimeField(default=django.utils.timezone.now, editable=False)),
|
models.BigAutoField(
|
||||||
('valid', models.BooleanField()),
|
auto_created=True,
|
||||||
('annual', models.BooleanField()),
|
primary_key=True,
|
||||||
('subscription', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='subscriptions.subscriptionplan')),
|
serialize=False,
|
||||||
('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
verbose_name="ID",
|
||||||
('user_group', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='user_groups.usergroup')),
|
),
|
||||||
|
),
|
||||||
|
("stripe_id", models.CharField(max_length=100)),
|
||||||
|
(
|
||||||
|
"date",
|
||||||
|
models.DateTimeField(
|
||||||
|
default=django.utils.timezone.now, editable=False
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("valid", models.BooleanField()),
|
||||||
|
("annual", models.BooleanField()),
|
||||||
|
(
|
||||||
|
"subscription",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
to="subscriptions.subscriptionplan",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to=settings.AUTH_USER_MODEL,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user_group",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="user_groups.usergroup",
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from accounts.models import CustomUser
|
from accounts.models import CustomUser
|
||||||
from user_groups.models import UserGroup
|
from user_groups.models import UserGroup
|
||||||
|
@ -25,9 +24,11 @@ class SubscriptionPlan(models.Model):
|
||||||
description = models.TextField(max_length=1024, null=True)
|
description = models.TextField(max_length=1024, null=True)
|
||||||
stripe_product_id = models.CharField(max_length=100)
|
stripe_product_id = models.CharField(max_length=100)
|
||||||
annual_price = models.ForeignKey(
|
annual_price = models.ForeignKey(
|
||||||
StripePrice, on_delete=models.SET_NULL, related_name='annual_plan', null=True)
|
StripePrice, on_delete=models.SET_NULL, related_name="annual_plan", null=True
|
||||||
|
)
|
||||||
monthly_price = models.ForeignKey(
|
monthly_price = models.ForeignKey(
|
||||||
StripePrice, on_delete=models.SET_NULL, related_name='monthly_plan', null=True)
|
StripePrice, on_delete=models.SET_NULL, related_name="monthly_plan", null=True
|
||||||
|
)
|
||||||
group_exclusive = models.BooleanField(default=False)
|
group_exclusive = models.BooleanField(default=False)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -39,11 +40,14 @@ class SubscriptionPlan(models.Model):
|
||||||
|
|
||||||
class UserSubscription(models.Model):
|
class UserSubscription(models.Model):
|
||||||
user = models.ForeignKey(
|
user = models.ForeignKey(
|
||||||
CustomUser, on_delete=models.CASCADE, blank=True, null=True)
|
CustomUser, on_delete=models.CASCADE, blank=True, null=True
|
||||||
|
)
|
||||||
user_group = models.ForeignKey(
|
user_group = models.ForeignKey(
|
||||||
UserGroup, on_delete=models.CASCADE, blank=True, null=True)
|
UserGroup, on_delete=models.CASCADE, blank=True, null=True
|
||||||
|
)
|
||||||
subscription = models.ForeignKey(
|
subscription = models.ForeignKey(
|
||||||
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True)
|
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True
|
||||||
|
)
|
||||||
stripe_id = models.CharField(max_length=100)
|
stripe_id = models.CharField(max_length=100)
|
||||||
date = models.DateTimeField(default=now, editable=False)
|
date = models.DateTimeField(default=now, editable=False)
|
||||||
valid = models.BooleanField()
|
valid = models.BooleanField()
|
||||||
|
@ -51,6 +55,6 @@ class UserSubscription(models.Model):
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.user:
|
if self.user:
|
||||||
return f'Subscription {self.subscription.name} for {self.user}'
|
return f"Subscription {self.subscription.name} for {self.user}"
|
||||||
else:
|
else:
|
||||||
return f'Subscription {self.subscription.name} for {self.user_group}'
|
return f"Subscription {self.subscription.name} for {self.user_group}"
|
||||||
|
|
|
@ -7,38 +7,46 @@ class SimpleStripePriceSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = StripePrice
|
model = StripePrice
|
||||||
fields = ['price', 'currency', 'prorated']
|
fields = ["price", "currency", "prorated"]
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionPlanSerializer(serializers.ModelSerializer):
|
class SubscriptionPlanSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = SubscriptionPlan
|
model = SubscriptionPlan
|
||||||
fields = ['id', 'name', 'description',
|
fields = [
|
||||||
'annual_price', 'monthly_price', 'group_exclusive']
|
"id",
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"annual_price",
|
||||||
|
"monthly_price",
|
||||||
|
"group_exclusive",
|
||||||
|
]
|
||||||
|
|
||||||
def to_representation(self, instance):
|
def to_representation(self, instance):
|
||||||
representation = super().to_representation(instance)
|
representation = super().to_representation(instance)
|
||||||
representation['annual_price'] = SimpleStripePriceSerializer(
|
representation["annual_price"] = SimpleStripePriceSerializer(
|
||||||
instance.annual_price, many=False).data
|
instance.annual_price, many=False
|
||||||
representation['monthly_price'] = SimpleStripePriceSerializer(
|
).data
|
||||||
instance.monthly_price, many=False).data
|
representation["monthly_price"] = SimpleStripePriceSerializer(
|
||||||
|
instance.monthly_price, many=False
|
||||||
|
).data
|
||||||
return representation
|
return representation
|
||||||
|
|
||||||
|
|
||||||
class UserSubscriptionSerializer(serializers.ModelSerializer):
|
class UserSubscriptionSerializer(serializers.ModelSerializer):
|
||||||
date = serializers.DateTimeField(
|
date = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||||
format="%m-%d-%Y %I:%M %p", read_only=True)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = UserSubscription
|
model = UserSubscription
|
||||||
fields = ['id', 'user', 'user_group', 'subscription',
|
fields = ["id", "user", "user_group", "subscription", "date", "valid", "annual"]
|
||||||
'date', 'valid', 'annual']
|
|
||||||
|
|
||||||
def to_representation(self, instance):
|
def to_representation(self, instance):
|
||||||
representation = super().to_representation(instance)
|
representation = super().to_representation(instance)
|
||||||
representation['user'] = SimpleCustomUserSerializer(
|
representation["user"] = SimpleCustomUserSerializer(
|
||||||
instance.user, many=False).data
|
instance.user, many=False
|
||||||
representation['subscription'] = SubscriptionPlanSerializer(
|
).data
|
||||||
instance.subscription, many=False).data
|
representation["subscription"] = SubscriptionPlanSerializer(
|
||||||
|
instance.subscription, many=False
|
||||||
|
).data
|
||||||
return representation
|
return representation
|
||||||
|
|
|
@ -4,6 +4,7 @@ from .models import UserSubscription, StripePrice, SubscriptionPlan
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from config.settings import STRIPE_SECRET_KEY
|
from config.settings import STRIPE_SECRET_KEY
|
||||||
import stripe
|
import stripe
|
||||||
|
|
||||||
stripe.api_key = STRIPE_SECRET_KEY
|
stripe.api_key = STRIPE_SECRET_KEY
|
||||||
|
|
||||||
# Template for running actions after user have paid for a subscription
|
# Template for running actions after user have paid for a subscription
|
||||||
|
@ -12,7 +13,7 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||||
@receiver(post_save, sender=SubscriptionPlan)
|
@receiver(post_save, sender=SubscriptionPlan)
|
||||||
def clear_cache_after_plan_updates(sender, instance, **kwargs):
|
def clear_cache_after_plan_updates(sender, instance, **kwargs):
|
||||||
# Clear cache
|
# Clear cache
|
||||||
cache.delete('subscriptionplans')
|
cache.delete("subscriptionplans")
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_save, sender=UserSubscription)
|
@receiver(post_save, sender=UserSubscription)
|
||||||
|
@ -25,8 +26,8 @@ def scan_after_payment(sender, instance, **kwargs):
|
||||||
|
|
||||||
@receiver(post_migrate)
|
@receiver(post_migrate)
|
||||||
def create_subscriptions(sender, **kwargs):
|
def create_subscriptions(sender, **kwargs):
|
||||||
if sender.name == 'subscriptions':
|
if sender.name == "subscriptions":
|
||||||
print('Importing data from Stripe')
|
print("Importing data from Stripe")
|
||||||
created_prices = 0
|
created_prices = 0
|
||||||
created_plans = 0
|
created_plans = 0
|
||||||
skipped_prices = 0
|
skipped_prices = 0
|
||||||
|
@ -35,16 +36,19 @@ def create_subscriptions(sender, **kwargs):
|
||||||
prices = stripe.Price.list(expand=["data.tiers"], active=True)
|
prices = stripe.Price.list(expand=["data.tiers"], active=True)
|
||||||
|
|
||||||
# Create the StripePrice
|
# Create the StripePrice
|
||||||
for price in prices['data']:
|
for price in prices["data"]:
|
||||||
annual = (price['recurring']['interval'] ==
|
annual = (
|
||||||
'year') if price['recurring'] else False
|
(price["recurring"]["interval"] == "year")
|
||||||
|
if price["recurring"]
|
||||||
|
else False
|
||||||
|
)
|
||||||
STRIPE_PRICE, CREATED = StripePrice.objects.get_or_create(
|
STRIPE_PRICE, CREATED = StripePrice.objects.get_or_create(
|
||||||
stripe_price_id=price['id'],
|
stripe_price_id=price["id"],
|
||||||
price=price['unit_amount'] / 100,
|
price=price["unit_amount"] / 100,
|
||||||
annual=annual,
|
annual=annual,
|
||||||
lookup_key=price['lookup_key'],
|
lookup_key=price["lookup_key"],
|
||||||
prorated=price['recurring']['usage_type'] == 'metered',
|
prorated=price["recurring"]["usage_type"] == "metered",
|
||||||
currency=price['currency']
|
currency=price["currency"],
|
||||||
)
|
)
|
||||||
if CREATED:
|
if CREATED:
|
||||||
created_prices += 1
|
created_prices += 1
|
||||||
|
@ -52,13 +56,13 @@ def create_subscriptions(sender, **kwargs):
|
||||||
skipped_prices += 1
|
skipped_prices += 1
|
||||||
|
|
||||||
# Create the SubscriptionPlan
|
# Create the SubscriptionPlan
|
||||||
for product in products['data']:
|
for product in products["data"]:
|
||||||
ANNUAL_PRICE = None
|
ANNUAL_PRICE = None
|
||||||
MONTHLY_PRICE = None
|
MONTHLY_PRICE = None
|
||||||
for price in prices['data']:
|
for price in prices["data"]:
|
||||||
if price['product'] == product['id']:
|
if price["product"] == product["id"]:
|
||||||
STRIPE_PRICE = StripePrice.objects.get(
|
STRIPE_PRICE = StripePrice.objects.get(
|
||||||
stripe_price_id=price['id'],
|
stripe_price_id=price["id"],
|
||||||
)
|
)
|
||||||
if STRIPE_PRICE.annual:
|
if STRIPE_PRICE.annual:
|
||||||
ANNUAL_PRICE = STRIPE_PRICE
|
ANNUAL_PRICE = STRIPE_PRICE
|
||||||
|
@ -66,12 +70,12 @@ def create_subscriptions(sender, **kwargs):
|
||||||
MONTHLY_PRICE = STRIPE_PRICE
|
MONTHLY_PRICE = STRIPE_PRICE
|
||||||
if ANNUAL_PRICE or MONTHLY_PRICE:
|
if ANNUAL_PRICE or MONTHLY_PRICE:
|
||||||
SUBSCRIPTION_PLAN, CREATED = SubscriptionPlan.objects.get_or_create(
|
SUBSCRIPTION_PLAN, CREATED = SubscriptionPlan.objects.get_or_create(
|
||||||
name=product['name'],
|
name=product["name"],
|
||||||
description=product['description'],
|
description=product["description"],
|
||||||
stripe_product_id=product['id'],
|
stripe_product_id=product["id"],
|
||||||
annual_price=ANNUAL_PRICE,
|
annual_price=ANNUAL_PRICE,
|
||||||
monthly_price=MONTHLY_PRICE,
|
monthly_price=MONTHLY_PRICE,
|
||||||
group_exclusive=product['metadata']['group_exclusive'] == 'True'
|
group_exclusive=product["metadata"]["group_exclusive"] == "True",
|
||||||
)
|
)
|
||||||
if CREATED:
|
if CREATED:
|
||||||
created_plans += 1
|
created_plans += 1
|
||||||
|
@ -79,13 +83,12 @@ def create_subscriptions(sender, **kwargs):
|
||||||
skipped_plans += 1
|
skipped_plans += 1
|
||||||
# Skip over plans with missing pricing rates
|
# Skip over plans with missing pricing rates
|
||||||
else:
|
else:
|
||||||
print('Skipping plan' +
|
print("Skipping plan" + product["name"] + "with missing pricing data")
|
||||||
product['name'] + 'with missing pricing data')
|
|
||||||
|
|
||||||
# Assign the StripePrice to the SubscriptionPlan
|
# Assign the StripePrice to the SubscriptionPlan
|
||||||
SUBSCRIPTION_PLAN.save()
|
SUBSCRIPTION_PLAN.save()
|
||||||
|
|
||||||
print('Created', created_plans, 'new plans')
|
print("Created", created_plans, "new plans")
|
||||||
print('Skipped', skipped_plans, 'existing plans')
|
print("Skipped", skipped_plans, "existing plans")
|
||||||
print('Created', created_prices, 'new prices')
|
print("Created", created_prices, "new prices")
|
||||||
print('Skipped', skipped_prices, 'existing prices')
|
print("Skipped", skipped_prices, "existing prices")
|
||||||
|
|
|
@ -12,10 +12,10 @@ def get_user_subscription(user_id):
|
||||||
active_subscriptions = None
|
active_subscriptions = None
|
||||||
if USER.user_group:
|
if USER.user_group:
|
||||||
active_subscriptions = UserSubscription.objects.filter(
|
active_subscriptions = UserSubscription.objects.filter(
|
||||||
user_group=USER.user_group, valid=True)
|
user_group=USER.user_group, valid=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
active_subscriptions = UserSubscription.objects.filter(
|
active_subscriptions = UserSubscription.objects.filter(user=USER, valid=True)
|
||||||
user=USER, valid=True)
|
|
||||||
|
|
||||||
# Return first valid subscription if there is one
|
# Return first valid subscription if there is one
|
||||||
if len(active_subscriptions) > 0:
|
if len(active_subscriptions) > 0:
|
||||||
|
@ -33,7 +33,8 @@ def get_user_group_subscription(user_group):
|
||||||
# Get a list of subscriptions for the specified user
|
# Get a list of subscriptions for the specified user
|
||||||
active_subscriptions = None
|
active_subscriptions = None
|
||||||
active_subscriptions = UserSubscription.objects.filter(
|
active_subscriptions = UserSubscription.objects.filter(
|
||||||
user_group=USER_GROUP, valid=True)
|
user_group=USER_GROUP, valid=True
|
||||||
|
)
|
||||||
|
|
||||||
# Return first valid subscription if there is one
|
# Return first valid subscription if there is one
|
||||||
if len(active_subscriptions) > 0:
|
if len(active_subscriptions) > 0:
|
||||||
|
|
|
@ -3,12 +3,11 @@ from subscriptions import views
|
||||||
from rest_framework.routers import DefaultRouter
|
from rest_framework.routers import DefaultRouter
|
||||||
|
|
||||||
router = DefaultRouter()
|
router = DefaultRouter()
|
||||||
router.register(r'plans', views.SubscriptionPlanViewset,
|
router.register(r"plans", views.SubscriptionPlanViewset, basename="Subscription Plans")
|
||||||
basename="Subscription Plans")
|
router.register(r"self", views.UserSubscriptionViewset, basename="Self Subscriptions")
|
||||||
router.register(r'self', views.UserSubscriptionViewset,
|
router.register(
|
||||||
basename="Self Subscriptions")
|
r"user_group", views.UserGroupSubscriptionViewet, basename="Group Subscriptions"
|
||||||
router.register(r'user_group', views.UserGroupSubscriptionViewet,
|
)
|
||||||
basename="Group Subscriptions")
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('', include(router.urls)),
|
path("", include(router.urls)),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
from subscriptions.serializers import SubscriptionPlanSerializer, UserSubscriptionSerializer
|
from subscriptions.serializers import (
|
||||||
|
SubscriptionPlanSerializer,
|
||||||
|
UserSubscriptionSerializer,
|
||||||
|
)
|
||||||
from subscriptions.models import SubscriptionPlan, UserSubscription
|
from subscriptions.models import SubscriptionPlan, UserSubscription
|
||||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||||
from rest_framework import viewsets
|
from rest_framework import viewsets
|
||||||
|
@ -6,38 +9,38 @@ from django.core.cache import cache
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionPlanViewset(viewsets.ModelViewSet):
|
class SubscriptionPlanViewset(viewsets.ModelViewSet):
|
||||||
http_method_names = ['get']
|
http_method_names = ["get"]
|
||||||
serializer_class = SubscriptionPlanSerializer
|
serializer_class = SubscriptionPlanSerializer
|
||||||
permission_classes = [AllowAny]
|
permission_classes = [AllowAny]
|
||||||
queryset = SubscriptionPlan.objects.all()
|
queryset = SubscriptionPlan.objects.all()
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
key = 'subscriptionplans'
|
key = "subscriptionplans"
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
if not queryset:
|
if not queryset:
|
||||||
queryset = super().get_queryset()
|
queryset = super().get_queryset()
|
||||||
cache.set(key, queryset, 60*60)
|
cache.set(key, queryset, 60 * 60)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
class UserSubscriptionViewset(viewsets.ModelViewSet):
|
class UserSubscriptionViewset(viewsets.ModelViewSet):
|
||||||
http_method_names = ['get']
|
http_method_names = ["get"]
|
||||||
serializer_class = UserSubscriptionSerializer
|
serializer_class = UserSubscriptionSerializer
|
||||||
permission_classes = [IsAuthenticated]
|
permission_classes = [IsAuthenticated]
|
||||||
queryset = UserSubscription.objects.all()
|
queryset = UserSubscription.objects.all()
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
user = self.request.user
|
user = self.request.user
|
||||||
key = f'subscriptions_user:{user.id}'
|
key = f"subscriptions_user:{user.id}"
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
if not queryset:
|
if not queryset:
|
||||||
queryset = UserSubscription.objects.filter(user=user)
|
queryset = UserSubscription.objects.filter(user=user)
|
||||||
cache.set(key, queryset, 60*60)
|
cache.set(key, queryset, 60 * 60)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
|
class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
|
||||||
http_method_names = ['get']
|
http_method_names = ["get"]
|
||||||
serializer_class = UserSubscriptionSerializer
|
serializer_class = UserSubscriptionSerializer
|
||||||
permission_classes = [IsAuthenticated]
|
permission_classes = [IsAuthenticated]
|
||||||
queryset = UserSubscription.objects.all()
|
queryset = UserSubscription.objects.all()
|
||||||
|
@ -47,10 +50,9 @@ class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
|
||||||
if not user.user_group:
|
if not user.user_group:
|
||||||
return UserSubscription.objects.none()
|
return UserSubscription.objects.none()
|
||||||
else:
|
else:
|
||||||
key = f'subscriptions_usergroup:{user.user_group.id}'
|
key = f"subscriptions_usergroup:{user.user_group.id}"
|
||||||
queryset = cache.get(key)
|
queryset = cache.get(key)
|
||||||
if not cache:
|
if not cache:
|
||||||
queryset = UserSubscription.objects.filter(
|
queryset = UserSubscription.objects.filter(user_group=user.user_group)
|
||||||
user_group=user.user_group)
|
cache.set(key, queryset, 60 * 60)
|
||||||
cache.set(key, queryset, 60*60)
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
|
@ -7,9 +7,7 @@ from unfold.contrib.filters.admin import RangeDateFilter
|
||||||
@admin.register(UserGroup)
|
@admin.register(UserGroup)
|
||||||
class UserGroupAdmin(ModelAdmin):
|
class UserGroupAdmin(ModelAdmin):
|
||||||
list_filter_submit = True
|
list_filter_submit = True
|
||||||
list_filter = ((
|
list_filter = (("date_created", RangeDateFilter),)
|
||||||
"date_created", RangeDateFilter
|
|
||||||
),)
|
|
||||||
|
|
||||||
list_display = ['id', 'name']
|
list_display = ["id", "name"]
|
||||||
search_fields = ['id', 'name']
|
search_fields = ["id", "name"]
|
||||||
|
|
|
@ -8,16 +8,28 @@ class Migration(migrations.Migration):
|
||||||
|
|
||||||
initial = True
|
initial = True
|
||||||
|
|
||||||
dependencies = [
|
dependencies = []
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='UserGroup',
|
name="UserGroup",
|
||||||
fields=[
|
fields=[
|
||||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
(
|
||||||
('name', models.CharField(max_length=128)),
|
"id",
|
||||||
('date_created', models.DateTimeField(default=django.utils.timezone.now, editable=False)),
|
models.BigAutoField(
|
||||||
|
auto_created=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("name", models.CharField(max_length=128)),
|
||||||
|
(
|
||||||
|
"date_created",
|
||||||
|
models.DateTimeField(
|
||||||
|
default=django.utils.timezone.now, editable=False
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -8,24 +8,33 @@ from django.db import migrations, models
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
('user_groups', '0001_initial'),
|
("user_groups", "0001_initial"),
|
||||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||||
]
|
]
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name='usergroup',
|
model_name="usergroup",
|
||||||
name='managers',
|
name="managers",
|
||||||
field=models.ManyToManyField(related_name='usergroup_managers', to=settings.AUTH_USER_MODEL),
|
field=models.ManyToManyField(
|
||||||
|
related_name="usergroup_managers", to=settings.AUTH_USER_MODEL
|
||||||
|
),
|
||||||
),
|
),
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name='usergroup',
|
model_name="usergroup",
|
||||||
name='members',
|
name="members",
|
||||||
field=models.ManyToManyField(related_name='usergroup_members', to=settings.AUTH_USER_MODEL),
|
field=models.ManyToManyField(
|
||||||
|
related_name="usergroup_members", to=settings.AUTH_USER_MODEL
|
||||||
|
),
|
||||||
),
|
),
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name='usergroup',
|
model_name="usergroup",
|
||||||
name='owner',
|
name="owner",
|
||||||
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='usergroup_owner', to=settings.AUTH_USER_MODEL),
|
field=models.ForeignKey(
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
related_name="usergroup_owner",
|
||||||
|
to=settings.AUTH_USER_MODEL,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,17 +2,24 @@ from django.db import models
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from config.settings import STRIPE_SECRET_KEY
|
from config.settings import STRIPE_SECRET_KEY
|
||||||
import stripe
|
import stripe
|
||||||
|
|
||||||
stripe.api_key = STRIPE_SECRET_KEY
|
stripe.api_key = STRIPE_SECRET_KEY
|
||||||
|
|
||||||
|
|
||||||
class UserGroup(models.Model):
|
class UserGroup(models.Model):
|
||||||
name = models.CharField(max_length=128, null=False)
|
name = models.CharField(max_length=128, null=False)
|
||||||
owner = models.ForeignKey(
|
owner = models.ForeignKey(
|
||||||
'accounts.CustomUser', on_delete=models.SET_NULL, null=True, related_name='usergroup_owner')
|
"accounts.CustomUser",
|
||||||
|
on_delete=models.SET_NULL,
|
||||||
|
null=True,
|
||||||
|
related_name="usergroup_owner",
|
||||||
|
)
|
||||||
managers = models.ManyToManyField(
|
managers = models.ManyToManyField(
|
||||||
'accounts.CustomUser', related_name='usergroup_managers')
|
"accounts.CustomUser", related_name="usergroup_managers"
|
||||||
|
)
|
||||||
members = models.ManyToManyField(
|
members = models.ManyToManyField(
|
||||||
'accounts.CustomUser', related_name='usergroup_members')
|
"accounts.CustomUser", related_name="usergroup_members"
|
||||||
|
)
|
||||||
date_created = models.DateTimeField(default=now, editable=False)
|
date_created = models.DateTimeField(default=now, editable=False)
|
||||||
|
|
||||||
# Derived from email of owner, may be used for billing
|
# Derived from email of owner, may be used for billing
|
||||||
|
|
|
@ -3,10 +3,9 @@ from .models import UserGroup
|
||||||
|
|
||||||
|
|
||||||
class SimpleUserGroupSerializer(serializers.ModelSerializer):
|
class SimpleUserGroupSerializer(serializers.ModelSerializer):
|
||||||
date_created = serializers.DateTimeField(
|
date_created = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||||
format="%m-%d-%Y %I:%M %p", read_only=True)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = UserGroup
|
model = UserGroup
|
||||||
fields = ['id', 'name', 'date_created']
|
fields = ["id", "name", "date_created"]
|
||||||
read_only_fields = ['id', 'name', 'date_created']
|
read_only_fields = ["id", "name", "date_created"]
|
||||||
|
|
|
@ -8,15 +8,16 @@ from config.settings import STRIPE_SECRET_KEY, ROOT_DIR
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import stripe
|
import stripe
|
||||||
|
|
||||||
stripe.api_key = STRIPE_SECRET_KEY
|
stripe.api_key = STRIPE_SECRET_KEY
|
||||||
|
|
||||||
|
|
||||||
@receiver(m2m_changed, sender=UserGroup.managers.through)
|
@receiver(m2m_changed, sender=UserGroup.managers.through)
|
||||||
def update_group_managers(sender, instance, action, **kwargs):
|
def update_group_managers(sender, instance, action, **kwargs):
|
||||||
# When adding new managers to a UserGroup, associate them with it
|
# When adding new managers to a UserGroup, associate them with it
|
||||||
if action == 'post_add':
|
if action == "post_add":
|
||||||
# Get the newly added managers
|
# Get the newly added managers
|
||||||
new_managers = kwargs.get('pk_set', set())
|
new_managers = kwargs.get("pk_set", set())
|
||||||
for manager in new_managers:
|
for manager in new_managers:
|
||||||
# Retrieve the member
|
# Retrieve the member
|
||||||
USER = CustomUser.objects.get(pk=manager)
|
USER = CustomUser.objects.get(pk=manager)
|
||||||
|
@ -27,8 +28,8 @@ def update_group_managers(sender, instance, action, **kwargs):
|
||||||
if USER not in instance.members.all():
|
if USER not in instance.members.all():
|
||||||
instance.members.add(USER)
|
instance.members.add(USER)
|
||||||
# When removing managers from a UserGroup, remove their association with it
|
# When removing managers from a UserGroup, remove their association with it
|
||||||
elif action == 'post_remove':
|
elif action == "post_remove":
|
||||||
for manager in kwargs['pk_set']:
|
for manager in kwargs["pk_set"]:
|
||||||
# Retrieve the manager
|
# Retrieve the manager
|
||||||
USER = CustomUser.objects.get(pk=manager)
|
USER = CustomUser.objects.get(pk=manager)
|
||||||
if USER not in instance.members.all():
|
if USER not in instance.members.all():
|
||||||
|
@ -39,9 +40,9 @@ def update_group_managers(sender, instance, action, **kwargs):
|
||||||
@receiver(m2m_changed, sender=UserGroup.members.through)
|
@receiver(m2m_changed, sender=UserGroup.members.through)
|
||||||
def update_group_members(sender, instance, action, **kwargs):
|
def update_group_members(sender, instance, action, **kwargs):
|
||||||
# When adding new members to a UserGroup, associate them with it
|
# When adding new members to a UserGroup, associate them with it
|
||||||
if action == 'post_add':
|
if action == "post_add":
|
||||||
# Get the newly added members
|
# Get the newly added members
|
||||||
new_members = kwargs.get('pk_set', set())
|
new_members = kwargs.get("pk_set", set())
|
||||||
for member in new_members:
|
for member in new_members:
|
||||||
# Retrieve the member
|
# Retrieve the member
|
||||||
USER = CustomUser.objects.get(pk=member)
|
USER = CustomUser.objects.get(pk=member)
|
||||||
|
@ -50,10 +51,13 @@ def update_group_members(sender, instance, action, **kwargs):
|
||||||
USER.user_group = instance
|
USER.user_group = instance
|
||||||
USER.save()
|
USER.save()
|
||||||
# When removing members from a UserGroup, remove their association with it
|
# When removing members from a UserGroup, remove their association with it
|
||||||
elif action == 'post_remove':
|
elif action == "post_remove":
|
||||||
for client in kwargs['pk_set']:
|
for client in kwargs["pk_set"]:
|
||||||
USER = CustomUser.objects.get(pk=client)
|
USER = CustomUser.objects.get(pk=client)
|
||||||
if USER not in instance.members.all() and USER not in instance.managers.all():
|
if (
|
||||||
|
USER not in instance.members.all()
|
||||||
|
and USER not in instance.managers.all()
|
||||||
|
):
|
||||||
USER.user_group = None
|
USER.user_group = None
|
||||||
USER.save()
|
USER.save()
|
||||||
# Update usage records
|
# Update usage records
|
||||||
|
@ -66,42 +70,42 @@ def update_group_members(sender, instance, action, **kwargs):
|
||||||
stripe.SubscriptionItem.create_usage_record(
|
stripe.SubscriptionItem.create_usage_record(
|
||||||
SUBSCRIPTION_ITEM.stripe_id,
|
SUBSCRIPTION_ITEM.stripe_id,
|
||||||
quantity=len(instance.members.all()),
|
quantity=len(instance.members.all()),
|
||||||
action="set"
|
action="set",
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(
|
print(
|
||||||
f'Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}')
|
f"Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_migrate)
|
@receiver(post_migrate)
|
||||||
def create_groups(sender, **kwargs):
|
def create_groups(sender, **kwargs):
|
||||||
if sender.name == "agencies":
|
if sender.name == "agencies":
|
||||||
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
|
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
|
||||||
seed_data = json.loads(f.read())
|
seed_data = json.loads(f.read())
|
||||||
for user_group in seed_data['user_groups']:
|
for user_group in seed_data["user_groups"]:
|
||||||
OWNER = CustomUser.objects.filter(
|
OWNER = CustomUser.objects.filter(email=user_group["owner"]).first()
|
||||||
email=user_group['owner']).first()
|
|
||||||
USER_GROUP, CREATED = UserGroup.objects.get_or_create(
|
USER_GROUP, CREATED = UserGroup.objects.get_or_create(
|
||||||
owner=OWNER,
|
owner=OWNER,
|
||||||
agency_name=user_group['name'],
|
agency_name=user_group["name"],
|
||||||
)
|
)
|
||||||
if CREATED:
|
if CREATED:
|
||||||
print(f"Created UserGroup {USER_GROUP.agency_name}")
|
print(f"Created UserGroup {USER_GROUP.agency_name}")
|
||||||
|
|
||||||
# Add managers
|
# Add managers
|
||||||
USERS = CustomUser.objects.filter(
|
USERS = CustomUser.objects.filter(email__in=user_group["managers"])
|
||||||
email__in=user_group['managers'])
|
|
||||||
for USER in USERS:
|
for USER in USERS:
|
||||||
if USER not in USER_GROUP.managers.all():
|
if USER not in USER_GROUP.managers.all():
|
||||||
print(
|
print(
|
||||||
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}")
|
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}"
|
||||||
|
)
|
||||||
USER_GROUP.managers.add(USER)
|
USER_GROUP.managers.add(USER)
|
||||||
# Add members
|
# Add members
|
||||||
USERS = CustomUser.objects.filter(
|
USERS = CustomUser.objects.filter(email__in=user_group["members"])
|
||||||
email__in=user_group['members'])
|
|
||||||
for USER in USERS:
|
for USER in USERS:
|
||||||
if USER not in USER_GROUP.members.all():
|
if USER not in USER_GROUP.members.all():
|
||||||
print(
|
print(
|
||||||
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}")
|
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}"
|
||||||
|
)
|
||||||
USER_GROUP.clients.add(USER)
|
USER_GROUP.clients.add(USER)
|
||||||
USER_GROUP.save()
|
USER_GROUP.save()
|
||||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class EmailsConfig(AppConfig):
|
class EmailsConfig(AppConfig):
|
||||||
default_auto_field = 'django.db.models.BigAutoField'
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = 'webdriver'
|
name = "webdriver"
|
||||||
|
|
|
@ -1,11 +1,19 @@
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from webdriver.utils import setup_webdriver, selenium_action_template, google_search, get_element, get_elements
|
from webdriver.utils import (
|
||||||
|
setup_webdriver,
|
||||||
|
selenium_action_template,
|
||||||
|
google_search,
|
||||||
|
get_element,
|
||||||
|
get_elements,
|
||||||
|
)
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from search_results.tasks import create_search_result
|
from search_results.tasks import create_search_result
|
||||||
|
|
||||||
|
|
||||||
# Task template
|
# Task template
|
||||||
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5})
|
@shared_task(
|
||||||
|
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
|
||||||
|
)
|
||||||
def sample_selenium_task():
|
def sample_selenium_task():
|
||||||
|
|
||||||
driver = setup_webdriver(use_proxy=False, use_saved_session=False)
|
driver = setup_webdriver(use_proxy=False, use_saved_session=False)
|
||||||
|
@ -18,27 +26,29 @@ def sample_selenium_task():
|
||||||
driver.close()
|
driver.close()
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
||||||
|
|
||||||
# Sample task to scrape Google for search results based on a keyword
|
# Sample task to scrape Google for search results based on a keyword
|
||||||
|
|
||||||
|
|
||||||
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5})
|
@shared_task(
|
||||||
|
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
|
||||||
|
)
|
||||||
def simple_google_search():
|
def simple_google_search():
|
||||||
driver = setup_webdriver(driver_type="firefox",
|
driver = setup_webdriver(
|
||||||
use_proxy=False, use_saved_session=False)
|
driver_type="firefox", use_proxy=False, use_saved_session=False
|
||||||
|
)
|
||||||
driver.get(f"https://google.com/")
|
driver.get(f"https://google.com/")
|
||||||
|
|
||||||
google_search(driver, search_term="cat blog posts")
|
google_search(driver, search_term="cat blog posts")
|
||||||
|
|
||||||
# Count number of Google search results
|
# Count number of Google search results
|
||||||
search_items = get_elements(
|
search_items = get_elements(driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
|
||||||
driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
|
|
||||||
|
|
||||||
for item in search_items:
|
for item in search_items:
|
||||||
title = item.find_element(By.TAG_NAME, 'h3').text
|
title = item.find_element(By.TAG_NAME, "h3").text
|
||||||
link = item.find_element(By.TAG_NAME, 'a').get_attribute('href')
|
link = item.find_element(By.TAG_NAME, "a").get_attribute("href")
|
||||||
|
|
||||||
create_search_result.apply_async(
|
create_search_result.apply_async(kwargs={"title": title, "link": link})
|
||||||
kwargs={"title": title, "link": link})
|
|
||||||
|
|
||||||
driver.close()
|
driver.close()
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Settings file to hold constants and functions
|
Settings file to hold constants and functions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.common.keys import Keys
|
from selenium.webdriver.common.keys import Keys
|
||||||
from config.settings import get_secret
|
from config.settings import get_secret
|
||||||
|
@ -18,24 +19,26 @@ import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
def take_snapshot(driver, filename='dump.png'):
|
def take_snapshot(driver, filename="dump.png"):
|
||||||
# Set window size
|
# Set window size
|
||||||
required_width = driver.execute_script(
|
required_width = driver.execute_script(
|
||||||
'return document.body.parentNode.scrollWidth')
|
"return document.body.parentNode.scrollWidth"
|
||||||
|
)
|
||||||
required_height = driver.execute_script(
|
required_height = driver.execute_script(
|
||||||
'return document.body.parentNode.scrollHeight')
|
"return document.body.parentNode.scrollHeight"
|
||||||
driver.set_window_size(
|
)
|
||||||
required_width, required_height+(required_height*0.05))
|
driver.set_window_size(required_width, required_height + (required_height * 0.05))
|
||||||
|
|
||||||
# Take the snapshot
|
# Take the snapshot
|
||||||
driver.find_element(By.TAG_NAME,
|
driver.find_element(By.TAG_NAME, "body").screenshot(
|
||||||
'body').screenshot('/dumps/'+filename) # avoids any scrollbars
|
"/dumps/" + filename
|
||||||
print('Snapshot saved')
|
) # avoids any scrollbars
|
||||||
|
print("Snapshot saved")
|
||||||
|
|
||||||
|
|
||||||
def dump_html(driver, filename='dump.html'):
|
def dump_html(driver, filename="dump.html"):
|
||||||
# Save the page source to error.html
|
# Save the page source to error.html
|
||||||
with open(('/dumps/'+filename), 'w', encoding='utf-8') as file:
|
with open(("/dumps/" + filename), "w", encoding="utf-8") as file:
|
||||||
file.write(driver.page_source)
|
file.write(driver.page_source)
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,83 +47,83 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
|
||||||
if not USE_PROXY:
|
if not USE_PROXY:
|
||||||
use_proxy = False
|
use_proxy = False
|
||||||
if use_proxy:
|
if use_proxy:
|
||||||
print('Running driver with proxy enabled')
|
print("Running driver with proxy enabled")
|
||||||
else:
|
else:
|
||||||
print('Running driver with proxy disabled')
|
print("Running driver with proxy disabled")
|
||||||
|
|
||||||
if use_saved_session:
|
if use_saved_session:
|
||||||
print('Running with saved session')
|
print("Running with saved session")
|
||||||
else:
|
else:
|
||||||
print('Running without using saved session')
|
print("Running without using saved session")
|
||||||
|
|
||||||
if driver_type == "chrome":
|
if driver_type == "chrome":
|
||||||
print('Using Chrome driver')
|
print("Using Chrome driver")
|
||||||
opts = uc.ChromeOptions()
|
opts = uc.ChromeOptions()
|
||||||
|
|
||||||
if use_saved_session:
|
if use_saved_session:
|
||||||
if os.path.exists("/tmp_chrome_profile"):
|
if os.path.exists("/tmp_chrome_profile"):
|
||||||
print('Existing Chrome ephemeral profile found')
|
print("Existing Chrome ephemeral profile found")
|
||||||
else:
|
else:
|
||||||
print('No existing Chrome ephemeral profile found')
|
print("No existing Chrome ephemeral profile found")
|
||||||
os.system("mkdir /tmp_chrome_profile")
|
os.system("mkdir /tmp_chrome_profile")
|
||||||
if os.path.exists('/chrome'):
|
if os.path.exists("/chrome"):
|
||||||
print('Copying Chrome Profile to ephemeral directory')
|
print("Copying Chrome Profile to ephemeral directory")
|
||||||
# Flush any non-essential cache directories from the existing profile as they may balloon in size overtime
|
# Flush any non-essential cache directories from the existing profile as they may balloon in size overtime
|
||||||
os.system(
|
os.system('rm -rf "/chrome/Selenium Profile/Code Cache/*"')
|
||||||
'rm -rf "/chrome/Selenium Profile/Code Cache/*"')
|
|
||||||
# Create a copy of the Chrome Profile
|
# Create a copy of the Chrome Profile
|
||||||
os.system("cp -r /chrome/* /tmp_chrome_profile")
|
os.system("cp -r /chrome/* /tmp_chrome_profile")
|
||||||
try:
|
try:
|
||||||
# Remove some items related to file locks
|
# Remove some items related to file locks
|
||||||
os.remove('/tmp_chrome_profile/SingletonLock')
|
os.remove("/tmp_chrome_profile/SingletonLock")
|
||||||
os.remove('/tmp_chrome_profile/SingletonSocket')
|
os.remove("/tmp_chrome_profile/SingletonSocket")
|
||||||
os.remove('/tmp_chrome_profile/SingletonLock')
|
os.remove("/tmp_chrome_profile/SingletonLock")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print('No existing Chrome Profile found. Creating one from scratch')
|
print("No existing Chrome Profile found. Creating one from scratch")
|
||||||
|
|
||||||
if use_saved_session:
|
if use_saved_session:
|
||||||
# Specify the user data directory
|
# Specify the user data directory
|
||||||
opts.add_argument(f'--user-data-dir=/tmp_chrome_profile')
|
opts.add_argument(f"--user-data-dir=/tmp_chrome_profile")
|
||||||
opts.add_argument('--profile-directory=Selenium Profile')
|
opts.add_argument("--profile-directory=Selenium Profile")
|
||||||
|
|
||||||
# Set proxy
|
# Set proxy
|
||||||
if use_proxy:
|
if use_proxy:
|
||||||
opts.add_argument(
|
opts.add_argument(
|
||||||
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}')
|
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}'
|
||||||
|
)
|
||||||
|
|
||||||
opts.add_argument("--disable-extensions")
|
opts.add_argument("--disable-extensions")
|
||||||
opts.add_argument('--disable-application-cache')
|
opts.add_argument("--disable-application-cache")
|
||||||
opts.add_argument("--disable-setuid-sandbox")
|
opts.add_argument("--disable-setuid-sandbox")
|
||||||
opts.add_argument('--disable-dev-shm-usage')
|
opts.add_argument("--disable-dev-shm-usage")
|
||||||
opts.add_argument("--disable-gpu")
|
opts.add_argument("--disable-gpu")
|
||||||
opts.add_argument("--no-sandbox")
|
opts.add_argument("--no-sandbox")
|
||||||
opts.add_argument("--headless=new")
|
opts.add_argument("--headless=new")
|
||||||
driver = uc.Chrome(options=opts)
|
driver = uc.Chrome(options=opts)
|
||||||
|
|
||||||
elif driver_type == "firefox":
|
elif driver_type == "firefox":
|
||||||
print('Using firefox driver')
|
print("Using firefox driver")
|
||||||
opts = FirefoxOptions()
|
opts = FirefoxOptions()
|
||||||
if use_saved_session:
|
if use_saved_session:
|
||||||
if not os.path.exists("/firefox"):
|
if not os.path.exists("/firefox"):
|
||||||
print('No profile found')
|
print("No profile found")
|
||||||
os.makedirs("/firefox")
|
os.makedirs("/firefox")
|
||||||
else:
|
else:
|
||||||
print('Existing profile found')
|
print("Existing profile found")
|
||||||
# Specify a profile if it exists
|
# Specify a profile if it exists
|
||||||
opts.profile = "/firefox"
|
opts.profile = "/firefox"
|
||||||
|
|
||||||
# Set proxy
|
# Set proxy
|
||||||
if use_proxy:
|
if use_proxy:
|
||||||
opts.set_preference('network.proxy.type', 1)
|
opts.set_preference("network.proxy.type", 1)
|
||||||
opts.set_preference('network.proxy.socks',
|
opts.set_preference("network.proxy.socks", get_secret("PROXY_IP"))
|
||||||
get_secret('PROXY_IP'))
|
opts.set_preference(
|
||||||
opts.set_preference('network.proxy.socks_port',
|
"network.proxy.socks_port", int(get_secret("PROXY_PORT_IP_AUTH"))
|
||||||
int(get_secret('PROXY_PORT_IP_AUTH')))
|
)
|
||||||
opts.set_preference('network.proxy.socks_remote_dns', False)
|
opts.set_preference("network.proxy.socks_remote_dns", False)
|
||||||
|
|
||||||
opts.add_argument('--disable-dev-shm-usage')
|
opts.add_argument("--disable-dev-shm-usage")
|
||||||
opts.add_argument("--headless")
|
opts.add_argument("--headless")
|
||||||
opts.add_argument("--disable-gpu")
|
opts.add_argument("--disable-gpu")
|
||||||
driver = webdriver.Firefox(options=opts)
|
driver = webdriver.Firefox(options=opts)
|
||||||
|
@ -128,13 +131,15 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
|
||||||
driver.maximize_window()
|
driver.maximize_window()
|
||||||
|
|
||||||
# Check if proxy is working
|
# Check if proxy is working
|
||||||
driver.get('https://api.ipify.org/')
|
driver.get("https://api.ipify.org/")
|
||||||
body = WebDriverWait(driver, 10).until(
|
body = WebDriverWait(driver, 10).until(
|
||||||
EC.presence_of_element_located((By.TAG_NAME, "body")))
|
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||||
|
)
|
||||||
ip_address = body.text
|
ip_address = body.text
|
||||||
print(f'External IP: {ip_address}')
|
print(f"External IP: {ip_address}")
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
|
|
||||||
# These are wrapper function for quickly automating multiple steps in webscraping (logins, button presses, text inputs, etc.)
|
# These are wrapper function for quickly automating multiple steps in webscraping (logins, button presses, text inputs, etc.)
|
||||||
# Depending on your use case, you may have to opt out of using this
|
# Depending on your use case, you may have to opt out of using this
|
||||||
|
|
||||||
|
@ -151,10 +156,11 @@ def get_element(driver, by, key, hidden_element=False, timeout=8):
|
||||||
wait = WebDriverWait(driver, timeout=timeout)
|
wait = WebDriverWait(driver, timeout=timeout)
|
||||||
if not hidden_element:
|
if not hidden_element:
|
||||||
element = wait.until(
|
element = wait.until(
|
||||||
EC.element_to_be_clickable((by, key)) and EC.visibility_of_element_located((by, key)))
|
EC.element_to_be_clickable((by, key))
|
||||||
|
and EC.visibility_of_element_located((by, key))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
element = wait.until(EC.presence_of_element_located(
|
element = wait.until(EC.presence_of_element_located((by, key)))
|
||||||
(by, key)))
|
|
||||||
return element
|
return element
|
||||||
except Exception:
|
except Exception:
|
||||||
dump_html(driver)
|
dump_html(driver)
|
||||||
|
@ -173,13 +179,12 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
|
||||||
wait = WebDriverWait(driver, timeout=timeout)
|
wait = WebDriverWait(driver, timeout=timeout)
|
||||||
|
|
||||||
if hidden_element:
|
if hidden_element:
|
||||||
elements = wait.until(
|
elements = wait.until(EC.presence_of_all_elements_located((by, key)))
|
||||||
EC.presence_of_all_elements_located((by, key)))
|
|
||||||
else:
|
else:
|
||||||
visible_elements = wait.until(
|
visible_elements = wait.until(
|
||||||
EC.visibility_of_any_elements_located((by, key)))
|
EC.visibility_of_any_elements_located((by, key))
|
||||||
elements = [
|
)
|
||||||
element for element in visible_elements if element.is_enabled()]
|
elements = [element for element in visible_elements if element.is_enabled()]
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -193,17 +198,22 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
|
||||||
def execute_selenium_elements(driver, timeout, elements):
|
def execute_selenium_elements(driver, timeout, elements):
|
||||||
try:
|
try:
|
||||||
for index, element in enumerate(elements):
|
for index, element in enumerate(elements):
|
||||||
print('Waiting...')
|
print("Waiting...")
|
||||||
# Element may have a keyword specified, check if that exists before running any actions
|
# Element may have a keyword specified, check if that exists before running any actions
|
||||||
if "keyword" in element:
|
if "keyword" in element:
|
||||||
# Skip a step if the keyword does not exist
|
# Skip a step if the keyword does not exist
|
||||||
if element['keyword'] not in driver.page_source:
|
if element["keyword"] not in driver.page_source:
|
||||||
print(
|
print(
|
||||||
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}')
|
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}'
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
elif element['keyword'] in driver.page_source and element['type'] == 'skip':
|
elif (
|
||||||
|
element["keyword"] in driver.page_source
|
||||||
|
and element["type"] == "skip"
|
||||||
|
):
|
||||||
print(
|
print(
|
||||||
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}')
|
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}'
|
||||||
|
)
|
||||||
break
|
break
|
||||||
print(f'Step: {index+1} - {element["name"]}')
|
print(f'Step: {index+1} - {element["name"]}')
|
||||||
# Revert to default iframe action
|
# Revert to default iframe action
|
||||||
|
@ -217,31 +227,47 @@ def execute_selenium_elements(driver, timeout, elements):
|
||||||
else:
|
else:
|
||||||
values = element["input"]
|
values = element["input"]
|
||||||
if type(values) is list:
|
if type(values) is list:
|
||||||
raise Exception(
|
raise Exception('Invalid input value specified for "callback" type')
|
||||||
'Invalid input value specified for "callback" type')
|
|
||||||
else:
|
else:
|
||||||
# For single input values
|
# For single input values
|
||||||
driver.execute_script(
|
driver.execute_script(f'onRecaptcha("{values}");')
|
||||||
f'onRecaptcha("{values}");')
|
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
# Try to get default element
|
# Try to get default element
|
||||||
if "hidden" in element:
|
if "hidden" in element:
|
||||||
site_element = get_element(
|
site_element = get_element(
|
||||||
driver, element["default"]["type"], element["default"]["key"], hidden_element=True, timeout=timeout)
|
driver,
|
||||||
|
element["default"]["type"],
|
||||||
|
element["default"]["key"],
|
||||||
|
hidden_element=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
site_element = get_element(
|
site_element = get_element(
|
||||||
driver, element["default"]["type"], element["default"]["key"], timeout=timeout)
|
driver,
|
||||||
|
element["default"]["type"],
|
||||||
|
element["default"]["key"],
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to find primary element')
|
print(f"Failed to find primary element")
|
||||||
# If that fails, try to get the failover one
|
# If that fails, try to get the failover one
|
||||||
print('Trying to find legacy element')
|
print("Trying to find legacy element")
|
||||||
if "hidden" in element:
|
if "hidden" in element:
|
||||||
site_element = get_element(
|
site_element = get_element(
|
||||||
driver, element["failover"]["type"], element["failover"]["key"], hidden_element=True, timeout=timeout)
|
driver,
|
||||||
|
element["failover"]["type"],
|
||||||
|
element["failover"]["key"],
|
||||||
|
hidden_element=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
site_element = get_element(
|
site_element = get_element(
|
||||||
driver, element["failover"]["type"], element["failover"]["key"], timeout=timeout)
|
driver,
|
||||||
|
element["failover"]["type"],
|
||||||
|
element["failover"]["key"],
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
# Clicking an element
|
# Clicking an element
|
||||||
if element["type"] == "click":
|
if element["type"] == "click":
|
||||||
site_element.click()
|
site_element.click()
|
||||||
|
@ -272,11 +298,13 @@ def execute_selenium_elements(driver, timeout, elements):
|
||||||
values = element["input"]
|
values = element["input"]
|
||||||
if type(values) is list:
|
if type(values) is list:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Invalid input value specified for "input_replace" type')
|
'Invalid input value specified for "input_replace" type'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# For single input values
|
# For single input values
|
||||||
driver.execute_script(
|
driver.execute_script(
|
||||||
f'arguments[0].value = "{values}";', site_element)
|
f'arguments[0].value = "{values}";', site_element
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
take_snapshot(driver)
|
take_snapshot(driver)
|
||||||
dump_html(driver)
|
dump_html(driver)
|
||||||
|
@ -285,30 +313,33 @@ def execute_selenium_elements(driver, timeout, elements):
|
||||||
raise Exception(e)
|
raise Exception(e)
|
||||||
|
|
||||||
|
|
||||||
def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=False, use_proxy=True):
|
def solve_captcha(
|
||||||
|
site_key, url, retry_attempts=3, version="v2", enterprise=False, use_proxy=True
|
||||||
|
):
|
||||||
# Manual proxy override set via $ENV
|
# Manual proxy override set via $ENV
|
||||||
if not USE_PROXY:
|
if not USE_PROXY:
|
||||||
use_proxy = False
|
use_proxy = False
|
||||||
if CAPTCHA_TESTING:
|
if CAPTCHA_TESTING:
|
||||||
print('Initializing CAPTCHA solver in dummy mode')
|
print("Initializing CAPTCHA solver in dummy mode")
|
||||||
code = random.randint()
|
code = random.randint()
|
||||||
print("CAPTCHA Successful")
|
print("CAPTCHA Successful")
|
||||||
return code
|
return code
|
||||||
|
|
||||||
elif use_proxy:
|
elif use_proxy:
|
||||||
print('Using CAPTCHA solver with proxy')
|
print("Using CAPTCHA solver with proxy")
|
||||||
else:
|
else:
|
||||||
print('Using CAPTCHA solver without proxy')
|
print("Using CAPTCHA solver without proxy")
|
||||||
|
|
||||||
captcha_params = {
|
captcha_params = {
|
||||||
"url": url,
|
"url": url,
|
||||||
"sitekey": site_key,
|
"sitekey": site_key,
|
||||||
"version": version,
|
"version": version,
|
||||||
"enterprise": 1 if enterprise else 0,
|
"enterprise": 1 if enterprise else 0,
|
||||||
"proxy": {
|
"proxy": (
|
||||||
'type': 'socks5',
|
{"type": "socks5", "uri": get_secret("PROXY_USER_AUTH")}
|
||||||
'uri': get_secret('PROXY_USER_AUTH')
|
if use_proxy
|
||||||
} if use_proxy else None
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Keep retrying until max attempts is reached
|
# Keep retrying until max attempts is reached
|
||||||
|
@ -316,12 +347,12 @@ def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=Fals
|
||||||
# Solver uses 2CAPTCHA by default
|
# Solver uses 2CAPTCHA by default
|
||||||
solver = TwoCaptcha(get_secret("CAPTCHA_API_KEY"))
|
solver = TwoCaptcha(get_secret("CAPTCHA_API_KEY"))
|
||||||
try:
|
try:
|
||||||
print('Waiting for CAPTCHA code...')
|
print("Waiting for CAPTCHA code...")
|
||||||
code = solver.recaptcha(**captcha_params)["code"]
|
code = solver.recaptcha(**captcha_params)["code"]
|
||||||
print("CAPTCHA Successful")
|
print("CAPTCHA Successful")
|
||||||
return code
|
return code
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'CAPTCHA Failed! {e}')
|
print(f"CAPTCHA Failed! {e}")
|
||||||
|
|
||||||
raise Exception(f"CAPTCHA API Failed!")
|
raise Exception(f"CAPTCHA API Failed!")
|
||||||
|
|
||||||
|
@ -339,13 +370,12 @@ def save_browser_session(driver):
|
||||||
# Copy over the profile once we finish logging in
|
# Copy over the profile once we finish logging in
|
||||||
if isinstance(driver, webdriver.Firefox):
|
if isinstance(driver, webdriver.Firefox):
|
||||||
# Copy process for Firefox
|
# Copy process for Firefox
|
||||||
print('Updating saved Firefox profile')
|
print("Updating saved Firefox profile")
|
||||||
# Get the current profile directory from about:support page
|
# Get the current profile directory from about:support page
|
||||||
driver.get("about:support")
|
driver.get("about:support")
|
||||||
box = get_element(
|
box = get_element(driver, "id", "profile-dir-box", timeout=4)
|
||||||
driver, "id", "profile-dir-box", timeout=4)
|
|
||||||
temp_profile_path = os.path.join(os.getcwd(), box.text)
|
temp_profile_path = os.path.join(os.getcwd(), box.text)
|
||||||
profile_path = '/firefox'
|
profile_path = "/firefox"
|
||||||
# Create the command
|
# Create the command
|
||||||
copy_command = "cp -r " + temp_profile_path + "/* " + profile_path
|
copy_command = "cp -r " + temp_profile_path + "/* " + profile_path
|
||||||
# Copy over the Firefox profile
|
# Copy over the Firefox profile
|
||||||
|
@ -353,13 +383,13 @@ def save_browser_session(driver):
|
||||||
print("Firefox profile saved")
|
print("Firefox profile saved")
|
||||||
elif isinstance(driver, uc.Chrome):
|
elif isinstance(driver, uc.Chrome):
|
||||||
# Copy the Chrome profile
|
# Copy the Chrome profile
|
||||||
print('Updating non-ephemeral Chrome profile')
|
print("Updating non-ephemeral Chrome profile")
|
||||||
# Flush Code Cache again to speed up copy
|
# Flush Code Cache again to speed up copy
|
||||||
os.system(
|
os.system('rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
|
||||||
'rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
|
|
||||||
if os.system("cp -r /tmp_chrome_profile/* /chrome"):
|
if os.system("cp -r /tmp_chrome_profile/* /chrome"):
|
||||||
print("Chrome profile saved")
|
print("Chrome profile saved")
|
||||||
|
|
||||||
|
|
||||||
# Sample function
|
# Sample function
|
||||||
# Call this within a Celery task
|
# Call this within a Celery task
|
||||||
# TODO: Modify as needed to your needs
|
# TODO: Modify as needed to your needs
|
||||||
|
@ -370,7 +400,7 @@ def selenium_action_template(driver):
|
||||||
info = {
|
info = {
|
||||||
"sample_field1": "sample_data",
|
"sample_field1": "sample_data",
|
||||||
"sample_field2": "sample_data",
|
"sample_field2": "sample_data",
|
||||||
"captcha_code": lambda: solve_captcha('SITE_KEY', 'SITE_URL')
|
"captcha_code": lambda: solve_captcha("SITE_KEY", "SITE_URL"),
|
||||||
}
|
}
|
||||||
|
|
||||||
elements = [
|
elements = [
|
||||||
|
@ -382,13 +412,10 @@ def selenium_action_template(driver):
|
||||||
"default": {
|
"default": {
|
||||||
# See get_element() for possible selector types
|
# See get_element() for possible selector types
|
||||||
"type": "xpath",
|
"type": "xpath",
|
||||||
"key": ''
|
"key": "",
|
||||||
},
|
},
|
||||||
# If a site implements canary design releases, you can place the ID for the element in the old design here
|
# If a site implements canary design releases, you can place the ID for the element in the old design here
|
||||||
"failover": {
|
"failover": {"type": "xpath", "key": ""},
|
||||||
"type": "xpath",
|
|
||||||
"key": ''
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -398,8 +425,8 @@ def selenium_action_template(driver):
|
||||||
|
|
||||||
# Fill in final fstring values in elements
|
# Fill in final fstring values in elements
|
||||||
for element in elements:
|
for element in elements:
|
||||||
if 'input' in element and '{' in element['input']:
|
if "input" in element and "{" in element["input"]:
|
||||||
a = element['input'].strip('{}')
|
a = element["input"].strip("{}")
|
||||||
if a in info:
|
if a in info:
|
||||||
value = info[a]
|
value = info[a]
|
||||||
# Check if the value is a callable (a lambda function) and call it if so
|
# Check if the value is a callable (a lambda function) and call it if so
|
||||||
|
@ -411,11 +438,12 @@ def selenium_action_template(driver):
|
||||||
# Use the stored value
|
# Use the stored value
|
||||||
value = site_form_values[a]
|
value = site_form_values[a]
|
||||||
# Replace the placeholder with the actual value
|
# Replace the placeholder with the actual value
|
||||||
element['input'] = str(value)
|
element["input"] = str(value)
|
||||||
|
|
||||||
# Execute the selenium actions
|
# Execute the selenium actions
|
||||||
execute_selenium_elements(driver, 8, elements)
|
execute_selenium_elements(driver, 8, elements)
|
||||||
|
|
||||||
|
|
||||||
# Sample task for Google search
|
# Sample task for Google search
|
||||||
|
|
||||||
|
|
||||||
|
@ -429,40 +457,28 @@ def google_search(driver, search_term):
|
||||||
"name": "Type in search term",
|
"name": "Type in search term",
|
||||||
"type": "input",
|
"type": "input",
|
||||||
"input": "{search_term}",
|
"input": "{search_term}",
|
||||||
"default": {
|
"default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||||
"type": "xpath",
|
"failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||||
"key": '//*[@id="APjFqb"]'
|
|
||||||
},
|
|
||||||
"failover": {
|
|
||||||
"type": "xpath",
|
|
||||||
"key": '//*[@id="APjFqb"]'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Press enter",
|
"name": "Press enter",
|
||||||
"type": "input_enter",
|
"type": "input_enter",
|
||||||
"default": {
|
"default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||||
"type": "xpath",
|
"failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||||
"key": '//*[@id="APjFqb"]'
|
|
||||||
},
|
|
||||||
"failover": {
|
|
||||||
"type": "xpath",
|
|
||||||
"key": '//*[@id="APjFqb"]'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
site_form_values = {}
|
site_form_values = {}
|
||||||
|
|
||||||
for element in elements:
|
for element in elements:
|
||||||
if 'input' in element and '{' in element['input']:
|
if "input" in element and "{" in element["input"]:
|
||||||
a = element['input'].strip('{}')
|
a = element["input"].strip("{}")
|
||||||
if a in info:
|
if a in info:
|
||||||
value = info[a]
|
value = info[a]
|
||||||
if callable(value):
|
if callable(value):
|
||||||
if a not in site_form_values:
|
if a not in site_form_values:
|
||||||
site_form_values[a] = value()
|
site_form_values[a] = value()
|
||||||
value = site_form_values[a]
|
value = site_form_values[a]
|
||||||
element['input'] = str(value)
|
element["input"] = str(value)
|
||||||
|
|
||||||
execute_selenium_elements(driver, 8, elements)
|
execute_selenium_elements(driver, 8, elements)
|
||||||
|
|
|
@ -5,7 +5,7 @@ services:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
image: drf_template:latest
|
image: drf_template
|
||||||
ports:
|
ports:
|
||||||
- "${BACKEND_PORT}:${BACKEND_PORT}"
|
- "${BACKEND_PORT}:${BACKEND_PORT}"
|
||||||
environment:
|
environment:
|
||||||
|
@ -23,7 +23,7 @@ services:
|
||||||
env_file: .env
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- RUN_TYPE=worker
|
- RUN_TYPE=worker
|
||||||
image: drf_template:latest
|
image: drf_template
|
||||||
volumes:
|
volumes:
|
||||||
- .:/code
|
- .:/code
|
||||||
- ./chrome:/chrome
|
- ./chrome:/chrome
|
||||||
|
@ -42,7 +42,7 @@ services:
|
||||||
env_file: .env
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- RUN_TYPE=beat
|
- RUN_TYPE=beat
|
||||||
image: drf_template:latest
|
image: drf_template
|
||||||
volumes:
|
volumes:
|
||||||
- .:/code
|
- .:/code
|
||||||
depends_on:
|
depends_on:
|
||||||
|
@ -58,7 +58,7 @@ services:
|
||||||
env_file: .env
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- RUN_TYPE=monitor
|
- RUN_TYPE=monitor
|
||||||
image: drf_template:latest
|
image: drf_template
|
||||||
ports:
|
ports:
|
||||||
- "${CELERY_FLOWER_PORT}:5555"
|
- "${CELERY_FLOWER_PORT}:5555"
|
||||||
volumes:
|
volumes:
|
||||||
|
|
Loading…
Reference in a new issue